[Classification] Training an Image Classification Model from Scratch -- TF2 Keras

    • Preface
    • Data preparation
      • Filtering out corrupted images
    • Generating the datasets
    • Visualizing the data
    • Using data augmentation
    • Standardizing the data
    • Two ways to preprocess the data
    • Configuring the dataset for performance
    • Building the model
    • Training the model
    • Inference with the h5 model
      • Inference on new data
    • Saving as a SavedModel and running inference
    • Inference with the pb model
    • Converting a Keras model to TFLite
    • Converting a SavedModel to TFLite

Preface

This example shows how to do image classification from scratch, starting from JPEG image files on disk, without relying on pre-trained weights or a pre-built Keras Applications model. It demonstrates the workflow on the Kaggle Cats vs. Dogs binary classification dataset.
We generate the datasets with the image_dataset_from_directory utility, and we use Keras image preprocessing layers for image standardization and data augmentation.

Data preparation

# Download the data
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
# Unzip the data
unzip -q kagglecatsanddogs_3367a.zip


import os
from tqdm import tqdm

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Filtering out corrupted images

When working with lots of real-world image data, corrupted images are a common occurrence. Here we filter out badly encoded images whose headers do not contain the string "JFIF".

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in tqdm(os.listdir(folder_path)):
        fpath = os.path.join(folder_path, fname)
        # Well-formed JPEGs carry the string "JFIF" in their header bytes
        with open(fpath, "rb") as fobj:
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print("Deleted %d images" % num_skipped)

Generating the datasets

image_size = (180, 180)
batch_size = 64

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
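
As a quick sanity check of how labels map onto the folders, you can look at the class_names attribute right after creating the dataset (image_dataset_from_directory infers class names from the directory names, sorted alphabetically):

print(train_ds.class_names)  # ['Cat', 'Dog'], i.e. Cat = 0, Dog = 1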


Visualizing the data

Here are the first nine images of the training set; dogs are labeled 1 and cats are labeled 0.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")


Using data augmentation

When you don't have a large image dataset, it is good practice to artificially introduce sample diversity by applying random yet realistic transformations to the training images, such as random horizontal flipping or small random rotations.
This helps expose the model to different aspects of the training data while slowing down overfitting.

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

Let's visualize what the first image of a batch looks like after several passes through data_augmentation.

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")


Standardizing the data

Our images are already in a standard size (180x180), but their RGB channel values are in the [0, 255] range. That is not ideal for a neural network; in general you should try to make your input values small. Here we will standardize the inputs to the [0, 1] range by placing a Rescaling layer at the start of the model.
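
As a small illustration (a sketch; in this post the layer is actually added inside the model in the "Building the model" section), Rescaling simply multiplies every pixel value by the given factor:

import numpy as np
from tensorflow.keras import layers

rescale = layers.experimental.preprocessing.Rescaling(1.0 / 255)
pixels = np.array([[0.0, 127.5, 255.0]], dtype=np.float32)
print(rescale(pixels).numpy())  # -> [[0.  0.5 1. ]]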

Two ways to preprocess the data

There are two ways you can use the data_augmentation preprocessor.
Option 1: make it part of the model, like this:

inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.experimental.preprocessing.Rescaling(1./255)(x)
...  # Rest of the model

With this option, your data augmentation happens on device, synchronously with the rest of the model execution, which means it benefits from GPU acceleration.
Note that data augmentation is inactive at test time, so input samples are only augmented during fit(), not when calling evaluate() or predict().
If you are training on a GPU, this is the better option.
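
You can check this behavior directly: outside of training, the random preprocessing layers act as pass-throughs. A minimal sketch (the sample batch here is hypothetical, created just for illustration):

import tensorflow as tf

sample = tf.random.uniform((4, 180, 180, 3), maxval=255)   # hypothetical image batch
augmented = data_augmentation(sample, training=True)       # random flips/rotations applied
passthrough = data_augmentation(sample, training=False)    # layers do nothing
print(bool(tf.reduce_all(passthrough == sample)))          # True: inputs unchanged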

Option 2: apply it to the dataset, so that the dataset yields batches of augmented images, like this:

augmented_train_ds = train_ds.map(
  lambda x, y: (data_augmentation(x, training=True), y))

With this option, your data augmentation happens asynchronously on the CPU and is buffered before going into the model.
If you are training on a CPU, this is the better option, since it makes data augmentation asynchronous and non-blocking.
In this example, we go with the first option.

Configuring the dataset for performance

Make sure to use buffered prefetching so we can yield data from disk without I/O becoming blocking:

train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
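
Alternatively, you can let tf.data pick the buffer size dynamically (assuming a TF version that exposes tf.data.AUTOTUNE, i.e. 2.4+; older TF 2.x releases use tf.data.experimental.AUTOTUNE instead):

AUTOTUNE = tf.data.AUTOTUNE  # tf.data.experimental.AUTOTUNE on older TF 2.x

train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)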

Building the model

We'll build a small version of the Xception network. We haven't particularly tried to optimize the architecture;
if you want to do a systematic search for the best model configuration, consider using Keras Tuner.
Note that:

  • We start the model with the data_augmentation preprocessor, followed by a Rescaling layer.
  • We include a Dropout layer before the final classification layer.

def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    x = data_augmentation(inputs)

    # Entry block
    x = layers.experimental.preprocessing.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)


model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 180, 180, 3) 0                                            
__________________________________________________________________________________________________
sequential (Sequential)         (None, 180, 180, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 180, 180, 3)  0           sequential[0][0]                 
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 90, 90, 32)   896         rescaling[0][0]                  
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 90, 90, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 90, 90, 32)   0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 90, 90, 64)   18496       activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 90, 90, 64)   256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 90, 90, 64)   0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 90, 90, 64)   0           activation_1[0][0]               
__________________________________________________________________________________________________
separable_conv2d (SeparableConv (None, 90, 90, 128)  8896        activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 90, 90, 128)  512         separable_conv2d[0][0]           
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 90, 90, 128)  0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
separable_conv2d_1 (SeparableCo (None, 90, 90, 128)  17664       activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 90, 90, 128)  512         separable_conv2d_1[0][0]         
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 45, 45, 128)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 45, 45, 128)  8320        activation_1[0][0]               
__________________________________________________________________________________________________
add (Add)                       (None, 45, 45, 128)  0           max_pooling2d[0][0]              
                                                                 conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 45, 45, 128)  0           add[0][0]                        
__________________________________________________________________________________________________
separable_conv2d_2 (SeparableCo (None, 45, 45, 256)  34176       activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 45, 45, 256)  1024        separable_conv2d_2[0][0]         
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 45, 45, 256)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
separable_conv2d_3 (SeparableCo (None, 45, 45, 256)  68096       activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 45, 45, 256)  1024        separable_conv2d_3[0][0]         
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 23, 23, 256)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 23, 23, 256)  33024       add[0][0]                        
__________________________________________________________________________________________________
add_1 (Add)                     (None, 23, 23, 256)  0           max_pooling2d_1[0][0]            
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 23, 23, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
separable_conv2d_4 (SeparableCo (None, 23, 23, 512)  133888      activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 23, 23, 512)  2048        separable_conv2d_4[0][0]         
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 23, 23, 512)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
separable_conv2d_5 (SeparableCo (None, 23, 23, 512)  267264      activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 23, 23, 512)  2048        separable_conv2d_5[0][0]         
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 12, 12, 512)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 12, 12, 512)  131584      add_1[0][0]                      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 12, 12, 512)  0           max_pooling2d_2[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 12, 12, 512)  0           add_2[0][0]                      
__________________________________________________________________________________________________
separable_conv2d_6 (SeparableCo (None, 12, 12, 728)  378072      activation_8[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 12, 12, 728)  2912        separable_conv2d_6[0][0]         
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 12, 12, 728)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
separable_conv2d_7 (SeparableCo (None, 12, 12, 728)  537264      activation_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 12, 12, 728)  2912        separable_conv2d_7[0][0]         
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 6, 6, 728)    0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 6, 6, 728)    373464      add_2[0][0]                      
__________________________________________________________________________________________________
add_3 (Add)                     (None, 6, 6, 728)    0           max_pooling2d_3[0][0]            
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
separable_conv2d_8 (SeparableCo (None, 6, 6, 1024)   753048      add_3[0][0]                      
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 6, 6, 1024)   4096        separable_conv2d_8[0][0]         
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 6, 6, 1024)   0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1024)         0           activation_10[0][0]              
__________________________________________________________________________________________________
dropout (Dropout)               (None, 1024)         0           global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense (Dense)                   (None, 1)            1025        dropout[0][0]                    
==================================================================================================
Total params: 2,782,649
Trainable params: 2,773,913
Non-trainable params: 8,736
__________________________________________________________________________________________________

Training the model

epochs = 50

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)
Epoch 1/50
586/586 [==============================] - 130s 206ms/step - loss: 0.7009 - accuracy: 0.5990 - val_loss: 0.7257 - val_accuracy: 0.5974
Epoch 2/50
586/586 [==============================] - 121s 206ms/step - loss: 0.5602 - accuracy: 0.7131 - val_loss: 0.4627 - val_accuracy: 0.7783
Epoch 3/50
586/586 [==============================] - 121s 206ms/step - loss: 0.4647 - accuracy: 0.7813 - val_loss: 0.4323 - val_accuracy: 0.7969
Epoch 4/50
586/586 [==============================] - 121s 206ms/step - loss: 0.3861 - accuracy: 0.8257 - val_loss: 0.5142 - val_accuracy: 0.7679
Epoch 5/50
586/586 [==============================] - 121s 206ms/step - loss: 0.3148 - accuracy: 0.8674 - val_loss: 0.4155 - val_accuracy: 0.8157
Epoch 6/50
586/586 [==============================] - 121s 205ms/step - loss: 0.2691 - accuracy: 0.8867 - val_loss: 0.2976 - val_accuracy: 0.8786
Epoch 7/50
586/586 [==============================] - 126s 213ms/step - loss: 0.2379 - accuracy: 0.9028 - val_loss: 0.6268 - val_accuracy: 0.7954
Epoch 8/50
586/586 [==============================] - 122s 206ms/step - loss: 0.2097 - accuracy: 0.9105 - val_loss: 0.2147 - val_accuracy: 0.9100
Epoch 9/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1931 - accuracy: 0.9230 - val_loss: 0.2223 - val_accuracy: 0.9072
Epoch 10/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1859 - accuracy: 0.9239 - val_loss: 0.2247 - val_accuracy: 0.9087
Epoch 11/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1706 - accuracy: 0.9322 - val_loss: 0.1547 - val_accuracy: 0.9424
Epoch 12/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1586 - accuracy: 0.9339 - val_loss: 0.1441 - val_accuracy: 0.9409
Epoch 13/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1533 - accuracy: 0.9361 - val_loss: 0.1377 - val_accuracy: 0.9441
Epoch 14/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1415 - accuracy: 0.9418 - val_loss: 0.1682 - val_accuracy: 0.9260
Epoch 15/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1399 - accuracy: 0.9418 - val_loss: 0.1225 - val_accuracy: 0.9509
Epoch 16/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1346 - accuracy: 0.9492 - val_loss: 0.1149 - val_accuracy: 0.9524
Epoch 17/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1326 - accuracy: 0.9473 - val_loss: 0.2272 - val_accuracy: 0.9087
Epoch 18/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1227 - accuracy: 0.9503 - val_loss: 0.1505 - val_accuracy: 0.9394
Epoch 19/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1234 - accuracy: 0.9509 - val_loss: 0.4335 - val_accuracy: 0.8605
Epoch 20/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1115 - accuracy: 0.9549 - val_loss: 0.1253 - val_accuracy: 0.9499
Epoch 21/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1079 - accuracy: 0.9537 - val_loss: 0.1266 - val_accuracy: 0.9548
Epoch 22/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1066 - accuracy: 0.9564 - val_loss: 0.1529 - val_accuracy: 0.9388
Epoch 23/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0999 - accuracy: 0.9611 - val_loss: 0.2258 - val_accuracy: 0.9076
Epoch 24/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0998 - accuracy: 0.9597 - val_loss: 0.1218 - val_accuracy: 0.9509
Epoch 25/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0959 - accuracy: 0.9641 - val_loss: 0.1314 - val_accuracy: 0.9541
Epoch 26/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1017 - accuracy: 0.9569 - val_loss: 0.1304 - val_accuracy: 0.9496
Epoch 27/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0900 - accuracy: 0.9641 - val_loss: 0.1225 - val_accuracy: 0.9560
Epoch 28/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0833 - accuracy: 0.9661 - val_loss: 0.1020 - val_accuracy: 0.9610
Epoch 29/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0893 - accuracy: 0.9652 - val_loss: 0.1026 - val_accuracy: 0.9592
Epoch 30/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0834 - accuracy: 0.9669 - val_loss: 0.1700 - val_accuracy: 0.9328
Epoch 31/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0856 - accuracy: 0.9656 - val_loss: 0.1025 - val_accuracy: 0.9586
Epoch 32/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0829 - accuracy: 0.9652 - val_loss: 0.1293 - val_accuracy: 0.9482
Epoch 33/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0767 - accuracy: 0.9694 - val_loss: 0.1117 - val_accuracy: 0.9560
Epoch 34/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0714 - accuracy: 0.9716 - val_loss: 0.1461 - val_accuracy: 0.9450
Epoch 35/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0790 - accuracy: 0.9684 - val_loss: 0.0989 - val_accuracy: 0.9620
Epoch 36/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0740 - accuracy: 0.9699 - val_loss: 0.1001 - val_accuracy: 0.9635
Epoch 37/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0704 - accuracy: 0.9731 - val_loss: 0.0964 - val_accuracy: 0.9659
Epoch 38/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0671 - accuracy: 0.9740 - val_loss: 0.0875 - val_accuracy: 0.9650
Epoch 39/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0659 - accuracy: 0.9757 - val_loss: 0.1233 - val_accuracy: 0.9528
Epoch 40/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0637 - accuracy: 0.9742 - val_loss: 0.1158 - val_accuracy: 0.9554
Epoch 41/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0638 - accuracy: 0.9734 - val_loss: 0.2093 - val_accuracy: 0.9149
Epoch 42/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0752 - accuracy: 0.9702 - val_loss: 0.1277 - val_accuracy: 0.9584
Epoch 43/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0668 - accuracy: 0.9748 - val_loss: 0.1082 - val_accuracy: 0.9603
Epoch 44/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0546 - accuracy: 0.9795 - val_loss: 0.1102 - val_accuracy: 0.9524
Epoch 45/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0643 - accuracy: 0.9753 - val_loss: 0.1130 - val_accuracy: 0.9648
Epoch 46/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0576 - accuracy: 0.9769 - val_loss: 0.1002 - val_accuracy: 0.9663
Epoch 47/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0587 - accuracy: 0.9769 - val_loss: 0.0936 - val_accuracy: 0.9633
Epoch 48/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0570 - accuracy: 0.9772 - val_loss: 0.1222 - val_accuracy: 0.9563
Epoch 49/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0620 - accuracy: 0.9758 - val_loss: 0.0851 - val_accuracy: 0.9682
Epoch 50/50
586/586 [==============================] - 121s 206ms/step - loss: 0.0538 - accuracy: 0.9784 - val_loss: 0.1015 - val_accuracy: 0.9599
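
Note that the ModelCheckpoint callback above writes one .h5 file per epoch. A common variant (a sketch, not what was used to produce the log above) keeps only the best weights and stops training once validation stops improving:

callbacks = [
    keras.callbacks.ModelCheckpoint(
        "best_model.h5",
        monitor="val_accuracy",
        save_best_only=True,  # overwrite the file only when val_accuracy improves
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    ),
]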

Inference with the h5 model

Inference on new data

Note that data augmentation and dropout are inactive at inference time.

#model.load_weights("save_at_49.h5")
model = keras.models.load_model("save_at_49.h5")
img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = predictions[0]
print(
    "This image is %.2f percent cat and %.2f percent dog."
     % (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.
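
Since the model ends in a single sigmoid unit, predictions[0] is the probability of class 1 (Dog). To turn it into a hard label you can threshold it (0.5 here is a common default, an assumption rather than something tuned on this dataset):

label = "Dog" if float(score) > 0.5 else "Cat"
print(label)  # Cat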

Saving as a SavedModel and running inference

# A Keras model can be saved either as h5 or as a SavedModel (pb):
# https://www.tensorflow.org/guide/keras/save_and_serialize?hl=zh-cn
# A model saved as h5 can also be loaded and re-saved as a SavedModel.
# The architecture and training config (including the optimizer, losses and
# metrics) are stored in saved_model.pb; the weights go under the variables/ directory.
model.save('resave')
#tf.keras.models.save_model(model, "resave")
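
You can inspect the resulting directory to see this layout (a quick sketch, assuming 'resave' was written as above; the exact file list can vary slightly between TF versions):

import os

print(os.listdir('resave'))  # e.g. ['saved_model.pb', 'variables', 'assets']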

Inference with the pb model

reconstructed_model = keras.models.load_model("resave")
img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create batch axis

predictions = reconstructed_model.predict(img_array)
score = predictions[0]
print(
    "This image is %.2f percent cat and %.2f percent dog."
     % (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.

Converting a Keras model to TFLite

https://www.tensorflow.org/lite/convert/python_api?hl=zh-cn

https://www.pythonf.cn/read/105151

loaded_keras_model = keras.models.load_model("save_at_49.h5")

#Convert the Keras model to TFLite
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
#Quantizing the TFLite model shrinks it with little or no loss in accuracy
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
keras_tflite = keras_to_tflite_converter.convert()


#with open('./keras_tflite', 'wb') as f:
#    f.write(keras_tflite)
#Save the quantized TFLite model
with open('./quantized_keras_tflite', 'wb') as f:
    f.write(keras_tflite)

Converting a SavedModel to TFLite

#Convert the SavedModel to TFLite
saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('resave')
#Quantizing the TFLite model shrinks it with little or no loss in accuracy
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()

with open('./saved_model_tflite', 'wb') as f:
    f.write(saved_model_tflite)

#Load the TFLite model written above; Interpreter accepts a model file path directly
interpreter = tf.lite.Interpreter(model_path='saved_model_tflite')
interpreter.allocate_tensors()  # allocate memory for all tensors

#Get the input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

#These dicts describe the tensor indices, shapes and dtypes that we need
#in order to feed the input and read the output
print(input_details)
print(output_details)

import numpy as np

image_size = (180, 180)
img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = np.expand_dims(img_array, 0).astype(np.float32)  # create batch axis; set_tensor expects a numpy array
interpreter.set_tensor(input_details[0]['index'], img_array)

#Run inference
interpreter.invoke()

output_results = interpreter.get_tensor(output_details[0]['index'])
score = output_results[0]
print(
    "This image is %.2f percent cat and %.2f percent dog."
     % (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.

The 32 MB h5 model shrinks to about 3 MB after conversion to TFLite, with almost no drop in accuracy.
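
To verify the size reduction yourself (a sketch, assuming the files were written to the paths used above):

import os

for path in ("save_at_49.h5", "quantized_keras_tflite", "saved_model_tflite"):
    print(path, "%.1f MB" % (os.path.getsize(path) / 1e6))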
