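This example shows how to do image classification from scratch, starting from JPEG image files on disk, without relying on pre-trained weights or a pre-made Keras Application model. It demonstrates the workflow on the Kaggle Cats vs Dogs binary classification dataset.
We use the image_dataset_from_directory utility to generate the datasets, and Keras image preprocessing layers for image standardization and data augmentation.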
# Download the data
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
# Unzip the data
unzip -q kagglecatsanddogs_3367a.zip
import os
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
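When working with large amounts of real-world image data, corrupted images are a common occurrence. Here we filter out badly encoded images, namely those whose header does not contain the string "JFIF".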
num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in tqdm(os.listdir(folder_path)):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)
        finally:
            fobj.close()
        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)
print("Deleted %d images" % num_skipped)
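The header check above can also be written as a small helper that uses a context manager, so the file is closed even if reading fails. This is a minimal sketch rather than the tutorial's own code: the name has_jfif_header is hypothetical, and it reads exactly the first 10 bytes with read(), whereas peek() may return more bytes than requested. It relies on standard JFIF files placing the "JFIF" identifier at byte offset 6.

```python
import os
import tempfile


def has_jfif_header(path, num_bytes=10):
    """Return True if b"JFIF" appears in the first `num_bytes` bytes of the file."""
    with open(path, "rb") as fobj:  # closed automatically, even on error
        return b"JFIF" in fobj.read(num_bytes)


# Demonstration on two throwaway files (synthetic bytes, not real images).
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.jpg")
    bad = os.path.join(tmp, "bad.jpg")
    with open(good, "wb") as f:
        f.write(b"\xff\xd8\xff\xe0\x00\x10JFIF\x00")  # JFIF marker at offset 6
    with open(bad, "wb") as f:
        f.write(b"not a jpeg at all")
    results = {"good": has_jfif_header(good), "bad": has_jfif_header(bad)}

print(results)  # {'good': True, 'bad': False}
```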
image_size = (180, 180)
batch_size = 64
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
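Both calls pass the same seed and validation_split so that the training and validation subsets do not overlap. The pure-Python sketch below illustrates only that idea. split_files is a hypothetical helper with made-up file names, and it is not how Keras implements the split internally.

```python
import random


def split_files(paths, validation_split, subset, seed):
    """Shuffle deterministically, then return the 'training' or 'validation' slice."""
    shuffled = sorted(paths)
    random.Random(seed).shuffle(shuffled)  # same seed -> same order on every call
    num_val = int(len(shuffled) * validation_split)
    cut = len(shuffled) - num_val
    if subset == "validation":
        return shuffled[cut:]
    return shuffled[:cut]


paths = ["img_%03d.jpg" % i for i in range(10)]  # made-up file names
train = split_files(paths, 0.2, "training", seed=1337)
val = split_files(paths, 0.2, "validation", seed=1337)
print(len(train), len(val))   # 8 2
print(set(train) & set(val))  # set()
```

With different seeds in the two calls, the two shuffles would generally differ and some files could end up in both subsets.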
Here are the first 9 images in the training dataset. Label 1 is "dog" and label 0 is "cat".
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")
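When you don't have a large image dataset, it's good practice to artificially introduce sample diversity by applying random yet realistic transformations to the training images, such as random horizontal flipping or small random rotations.
This helps expose the model to different aspects of the training data while slowing down overfitting.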
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)
Let's visualize what an image from the dataset looks like after passing through data_augmentation.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")
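Our images are already a standard size (180x180). However, their RGB channel values are in the [0, 255] range, which is not ideal for a neural network;
in general, you should seek to make your input values small. Here, we standardize the inputs to the [0, 1] range with a Rescaling layer at the start of the model.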
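As a quick illustration of the arithmetic, multiplying each channel value by 1/255 maps the [0, 255] range onto [0, 1]. The sketch uses plain Python lists instead of tensors, and rescale is a hypothetical stand-in for what a rescaling step computes per element.

```python
def rescale(pixel, scale=1.0 / 255):
    """Multiply every channel value by `scale`."""
    return [channel * scale for channel in pixel]


pixel = [0, 51, 255]  # one RGB pixel with values in [0, 255]
scaled = rescale(pixel)
print(scaled)  # roughly [0.0, 0.2, 1.0], up to floating-point rounding
```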
There are two ways to use the data_augmentation preprocessor:
Option 1: Make it part of the model, like this:
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.experimental.preprocessing.Rescaling(1./255)(x)
... # Rest of the model
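With this option, your data augmentation happens on device, synchronously with the rest of the model execution, so it benefits from GPU acceleration.
Note that data augmentation is inactive at test time, so the input samples are only augmented during fit(), not when calling evaluate() or predict().
If you're training on GPU, this is the better option.
Option 2: Apply it to the dataset, so as to obtain a dataset that yields batches of augmented images, like this: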
augmented_train_ds = train_ds.map(
lambda x, y: (data_augmentation(x, training=True), y))
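With this option, your data augmentation happens on CPU, asynchronously, and is buffered before going into the model.
If you're training on CPU, this is the better option, since it makes data augmentation asynchronous and non-blocking.
In this example, we use the first option.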
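The difference between training and inference behaviour can be sketched with a toy layer. ToyRandomFlip below is hypothetical and works on nested lists rather than tensors. It only mimics the idea that a random augmentation is applied when training=True and becomes a pass-through otherwise.

```python
import random


class ToyRandomFlip:
    """Flip each image left-to-right with probability 0.5, but only in training mode."""

    def __init__(self, seed=None):
        self.rng = random.Random(seed)

    def __call__(self, images, training=False):
        if not training:
            return images  # inactive at inference: inputs pass through unchanged
        return [
            [row[::-1] for row in image] if self.rng.random() < 0.5 else image
            for image in images
        ]


flip = ToyRandomFlip(seed=0)
batch = [[[1, 2, 3], [4, 5, 6]]]  # one 2x3 single-channel "image"
inference_out = flip(batch, training=False)
training_out = flip(batch, training=True)
print(inference_out)  # [[[1, 2, 3], [4, 5, 6]]]
print(training_out)   # either the same image or its mirror
```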
Let's make sure to use buffered prefetching so we can yield data from disk without having I/O become blocking:
train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
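To make the idea of buffered prefetching concrete, here is a minimal pure-Python sketch: a background thread fills a bounded queue while the consumer works on earlier items. The prefetch function is a hypothetical illustration and not how tf.data implements it.

```python
import queue
import threading


def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, loading up to `buffer_size` items ahead."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def producer():
        for item in iterable:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()  # blocks only if the producer has fallen behind
        if item is sentinel:
            return
        yield item


batches = list(prefetch(range(5), buffer_size=2))
print(batches)  # [0, 1, 2, 3, 4]
```

The buffer size bounds how far ahead the producer can run, which caps memory use while still hiding loading latency.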
We'll build a small version of the Xception network. We haven't particularly tried to optimize the architecture;
if you want to do a systematic search for the best model configuration, consider using Keras Tuner.
Note: the model includes a Dropout layer before the final classification layer.
def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    x = data_augmentation(inputs)
    # Entry block
    x = layers.experimental.preprocessing.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous_block_activation = x  # Set aside residual
    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual
    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)
model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 180, 180, 3) 0
__________________________________________________________________________________________________
sequential (Sequential) (None, 180, 180, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
rescaling (Rescaling) (None, 180, 180, 3) 0 sequential[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 90, 90, 32) 896 rescaling[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 90, 90, 32) 128 conv2d[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 90, 90, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 90, 90, 64) 18496 activation[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 90, 90, 64) 256 conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 90, 90, 64) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 90, 90, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
separable_conv2d (SeparableConv (None, 90, 90, 128) 8896 activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 90, 90, 128) 512 separable_conv2d[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 90, 90, 128) 0 batch_normalization_2[0][0]
__________________________________________________________________________________________________
separable_conv2d_1 (SeparableCo (None, 90, 90, 128) 17664 activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 90, 90, 128) 512 separable_conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 45, 45, 128) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 45, 45, 128) 8320 activation_1[0][0]
__________________________________________________________________________________________________
add (Add) (None, 45, 45, 128) 0 max_pooling2d[0][0]
conv2d_2[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 45, 45, 128) 0 add[0][0]
__________________________________________________________________________________________________
separable_conv2d_2 (SeparableCo (None, 45, 45, 256) 34176 activation_4[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 45, 45, 256) 1024 separable_conv2d_2[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 45, 45, 256) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
separable_conv2d_3 (SeparableCo (None, 45, 45, 256) 68096 activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 45, 45, 256) 1024 separable_conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 23, 23, 256) 0 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 23, 23, 256) 33024 add[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 23, 23, 256) 0 max_pooling2d_1[0][0]
conv2d_3[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 23, 23, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
separable_conv2d_4 (SeparableCo (None, 23, 23, 512) 133888 activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 23, 23, 512) 2048 separable_conv2d_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 23, 23, 512) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
separable_conv2d_5 (SeparableCo (None, 23, 23, 512) 267264 activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 23, 23, 512) 2048 separable_conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 12, 12, 512) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 12, 12, 512) 131584 add_1[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 12, 12, 512) 0 max_pooling2d_2[0][0]
conv2d_4[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 12, 12, 512) 0 add_2[0][0]
__________________________________________________________________________________________________
separable_conv2d_6 (SeparableCo (None, 12, 12, 728) 378072 activation_8[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 12, 12, 728) 2912 separable_conv2d_6[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 12, 12, 728) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
separable_conv2d_7 (SeparableCo (None, 12, 12, 728) 537264 activation_9[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 12, 12, 728) 2912 separable_conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 6, 6, 728) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 6, 6, 728) 373464 add_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 6, 6, 728) 0 max_pooling2d_3[0][0]
conv2d_5[0][0]
__________________________________________________________________________________________________
separable_conv2d_8 (SeparableCo (None, 6, 6, 1024) 753048 add_3[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 6, 6, 1024) 4096 separable_conv2d_8[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 6, 6, 1024) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1024) 0 activation_10[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 1024) 0 global_average_pooling2d[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1) 1025 dropout[0][0]
==================================================================================================
Total params: 2,782,649
Trainable params: 2,773,913
Non-trainable params: 8,736
__________________________________________________________________________________________________
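The numbers in the summary can be checked by hand. The sketch below recomputes a few parameter counts and the spatial sizes, assuming the usual formulas: a Conv2D layer has k*k*c_in*c_out weights plus c_out biases, a SeparableConv2D layer has k*k*c_in depthwise weights plus c_in*c_out pointwise weights plus c_out biases, and a stride-2 layer with "same" padding produces ceil(n / 2) outputs. The helper names are hypothetical.

```python
import math


def conv2d_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out


def separable_conv2d_params(k, c_in, c_out):
    return k * k * c_in + c_in * c_out + c_out


def batch_norm_params(channels):
    # gamma, beta, moving mean, moving variance
    return 4 * channels


print(conv2d_params(3, 3, 32))               # 896   (conv2d)
print(conv2d_params(3, 32, 64))              # 18496 (conv2d_1)
print(separable_conv2d_params(3, 64, 128))   # 8896  (separable_conv2d)
print(separable_conv2d_params(3, 128, 128))  # 17664 (separable_conv2d_1)
print(conv2d_params(1, 64, 128))             # 8320  (conv2d_2, residual projection)
print(batch_norm_params(32))                 # 128   (batch_normalization)

# One stride-2 convolution plus four stride-2 max-pooling layers.
sizes = [180]
for _ in range(5):
    sizes.append(math.ceil(sizes[-1] / 2))
print(sizes)  # [180, 90, 45, 23, 12, 6]
```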
epochs = 50
callbacks = [
keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
optimizer=keras.optimizers.Adam(1e-3),
loss="binary_crossentropy",
metrics=["accuracy"],
)
model.fit(
train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)
Epoch 1/50
586/586 [==============================] - 130s 206ms/step - loss: 0.7009 - accuracy: 0.5990 - val_loss: 0.7257 - val_accuracy: 0.5974
Epoch 2/50
586/586 [==============================] - 121s 206ms/step - loss: 0.5602 - accuracy: 0.7131 - val_loss: 0.4627 - val_accuracy: 0.7783
Epoch 3/50
586/586 [==============================] - 121s 206ms/step - loss: 0.4647 - accuracy: 0.7813 - val_loss: 0.4323 - val_accuracy: 0.7969
Epoch 4/50
586/586 [==============================] - 121s 206ms/step - loss: 0.3861 - accuracy: 0.8257 - val_loss: 0.5142 - val_accuracy: 0.7679
Epoch 5/50
586/586 [==============================] - 121s 206ms/step - loss: 0.3148 - accuracy: 0.8674 - val_loss: 0.4155 - val_accuracy: 0.8157
Epoch 6/50
586/586 [==============================] - 121s 205ms/step - loss: 0.2691 - accuracy: 0.8867 - val_loss: 0.2976 - val_accuracy: 0.8786
Epoch 7/50
586/586 [==============================] - 126s 213ms/step - loss: 0.2379 - accuracy: 0.9028 - val_loss: 0.6268 - val_accuracy: 0.7954
Epoch 8/50
586/586 [==============================] - 122s 206ms/step - loss: 0.2097 - accuracy: 0.9105 - val_loss: 0.2147 - val_accuracy: 0.9100
Epoch 9/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1931 - accuracy: 0.9230 - val_loss: 0.2223 - val_accuracy: 0.9072
Epoch 10/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1859 - accuracy: 0.9239 - val_loss: 0.2247 - val_accuracy: 0.9087
Epoch 11/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1706 - accuracy: 0.9322 - val_loss: 0.1547 - val_accuracy: 0.9424
Epoch 12/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1586 - accuracy: 0.9339 - val_loss: 0.1441 - val_accuracy: 0.9409
Epoch 13/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1533 - accuracy: 0.9361 - val_loss: 0.1377 - val_accuracy: 0.9441
Epoch 14/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1415 - accuracy: 0.9418 - val_loss: 0.1682 - val_accuracy: 0.9260
Epoch 15/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1399 - accuracy: 0.9418 - val_loss: 0.1225 - val_accuracy: 0.9509
Epoch 16/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1346 - accuracy: 0.9492 - val_loss: 0.1149 - val_accuracy: 0.9524
Epoch 17/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1326 - accuracy: 0.9473 - val_loss: 0.2272 - val_accuracy: 0.9087
Epoch 18/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1227 - accuracy: 0.9503 - val_loss: 0.1505 - val_accuracy: 0.9394
Epoch 19/50
586/586 [==============================] - 121s 206ms/step - loss: 0.1234 - accuracy: 0.9509 - val_loss: 0.4335 - val_accuracy: 0.8605
Epoch 20/50
586/586 [==============================] - 122s 206ms/step - loss: 0.1115 - accuracy: 0.9549 - val_loss: 0.1253 - val_accuracy: 0.9499
Epoch 21/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1079 - accuracy: 0.9537 - val_loss: 0.1266 - val_accuracy: 0.9548
Epoch 22/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1066 - accuracy: 0.9564 - val_loss: 0.1529 - val_accuracy: 0.9388
Epoch 23/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0999 - accuracy: 0.9611 - val_loss: 0.2258 - val_accuracy: 0.9076
Epoch 24/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0998 - accuracy: 0.9597 - val_loss: 0.1218 - val_accuracy: 0.9509
Epoch 25/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0959 - accuracy: 0.9641 - val_loss: 0.1314 - val_accuracy: 0.9541
Epoch 26/50
586/586 [==============================] - 121s 205ms/step - loss: 0.1017 - accuracy: 0.9569 - val_loss: 0.1304 - val_accuracy: 0.9496
Epoch 27/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0900 - accuracy: 0.9641 - val_loss: 0.1225 - val_accuracy: 0.9560
Epoch 28/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0833 - accuracy: 0.9661 - val_loss: 0.1020 - val_accuracy: 0.9610
Epoch 29/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0893 - accuracy: 0.9652 - val_loss: 0.1026 - val_accuracy: 0.9592
Epoch 30/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0834 - accuracy: 0.9669 - val_loss: 0.1700 - val_accuracy: 0.9328
Epoch 31/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0856 - accuracy: 0.9656 - val_loss: 0.1025 - val_accuracy: 0.9586
Epoch 32/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0829 - accuracy: 0.9652 - val_loss: 0.1293 - val_accuracy: 0.9482
Epoch 33/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0767 - accuracy: 0.9694 - val_loss: 0.1117 - val_accuracy: 0.9560
Epoch 34/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0714 - accuracy: 0.9716 - val_loss: 0.1461 - val_accuracy: 0.9450
Epoch 35/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0790 - accuracy: 0.9684 - val_loss: 0.0989 - val_accuracy: 0.9620
Epoch 36/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0740 - accuracy: 0.9699 - val_loss: 0.1001 - val_accuracy: 0.9635
Epoch 37/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0704 - accuracy: 0.9731 - val_loss: 0.0964 - val_accuracy: 0.9659
Epoch 38/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0671 - accuracy: 0.9740 - val_loss: 0.0875 - val_accuracy: 0.9650
Epoch 39/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0659 - accuracy: 0.9757 - val_loss: 0.1233 - val_accuracy: 0.9528
Epoch 40/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0637 - accuracy: 0.9742 - val_loss: 0.1158 - val_accuracy: 0.9554
Epoch 41/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0638 - accuracy: 0.9734 - val_loss: 0.2093 - val_accuracy: 0.9149
Epoch 42/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0752 - accuracy: 0.9702 - val_loss: 0.1277 - val_accuracy: 0.9584
Epoch 43/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0668 - accuracy: 0.9748 - val_loss: 0.1082 - val_accuracy: 0.9603
Epoch 44/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0546 - accuracy: 0.9795 - val_loss: 0.1102 - val_accuracy: 0.9524
Epoch 45/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0643 - accuracy: 0.9753 - val_loss: 0.1130 - val_accuracy: 0.9648
Epoch 46/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0576 - accuracy: 0.9769 - val_loss: 0.1002 - val_accuracy: 0.9663
Epoch 47/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0587 - accuracy: 0.9769 - val_loss: 0.0936 - val_accuracy: 0.9633
Epoch 48/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0570 - accuracy: 0.9772 - val_loss: 0.1222 - val_accuracy: 0.9563
Epoch 49/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0620 - accuracy: 0.9758 - val_loss: 0.0851 - val_accuracy: 0.9682
Epoch 50/50
586/586 [==============================] - 121s 206ms/step - loss: 0.0538 - accuracy: 0.9784 - val_loss: 0.1015 - val_accuracy: 0.9599
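Because ModelCheckpoint writes one file per epoch, one reasonable way to choose which file to load is to pick the epoch with the highest validation accuracy. The sketch below parses the last three epochs copied from the log above. best_epoch is a hypothetical helper, and it assumes the log has been captured as text. In the full log, epoch 49 also has the highest val_accuracy (0.9682).

```python
import re

log = """\
Epoch 48/50
586/586 [==============================] - 120s 204ms/step - loss: 0.0570 - accuracy: 0.9772 - val_loss: 0.1222 - val_accuracy: 0.9563
Epoch 49/50
586/586 [==============================] - 121s 205ms/step - loss: 0.0620 - accuracy: 0.9758 - val_loss: 0.0851 - val_accuracy: 0.9682
Epoch 50/50
586/586 [==============================] - 121s 206ms/step - loss: 0.0538 - accuracy: 0.9784 - val_loss: 0.1015 - val_accuracy: 0.9599
"""


def best_epoch(log_text):
    """Return (epoch, val_accuracy) for the epoch with the highest val_accuracy."""
    epochs = [int(m) for m in re.findall(r"^Epoch (\d+)/\d+", log_text, flags=re.M)]
    val_accs = [float(m) for m in re.findall(r"val_accuracy: ([0-9.]+)", log_text)]
    return max(zip(epochs, val_accs), key=lambda pair: pair[1])


best = best_epoch(log)
print(best)  # (49, 0.9682)
```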
Note that data augmentation and dropout are inactive at inference time.
#model.load_weights("save_at_49.h5")
model = keras.models.load_model("save_at_49.h5")
img = keras.preprocessing.image.load_img(
"PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create batch axis
predictions = model.predict(img_array)
score = predictions[0]
print(
"This image is %.2f percent cat and %.2f percent dog."
% (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.
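The model ends in a single sigmoid unit, so its output is read as the probability of class 1 (dog), and the probability of class 0 (cat) is one minus that value. The sketch below reproduces the formatting step in isolation. describe is a hypothetical helper, and the score 0.0006 is an assumed value chosen to be consistent with the printed output, not a number taken from the model.

```python
def describe(score):
    """Format a sigmoid output, read as the probability of class 1 (dog)."""
    return "This image is %.2f percent cat and %.2f percent dog." % (
        100 * (1 - score),
        100 * score,
    )


message = describe(0.0006)  # assumed score, for illustration only
print(message)
# This image is 99.94 percent cat and 0.06 percent dog.
```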
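# A model trained with Keras can be saved either in the HDF5 (.h5) format or as a SavedModel (.pb): https://www.tensorflow.org/guide/keras/save_and_serialize?hl=zh-cn
# A model already saved as .h5 can be loaded and re-saved as a SavedModel. The model architecture and training configuration (including the optimizer, losses, and metrics) are stored in saved_model.pb; the weights are saved in the variables/ directory.
model.save('resave')
# tf.keras.models.save_model(model, "resave")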
reconstructed_model = keras.models.load_model("resave")
img = keras.preprocessing.image.load_img(
"PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create batch axis
predictions = reconstructed_model.predict(img_array)
score = predictions[0]
print(
"This image is %.2f percent cat and %.2f percent dog."
% (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.
https://www.tensorflow.org/lite/convert/python_api?hl=zh-cn
https://www.pythonf.cn/read/105151
loaded_keras_model = keras.models.load_model("save_at_49.h5")
# Convert the Keras model to TFLite
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
# Quantize the TFLite model; this shrinks the model with little or no loss of accuracy
# (recent TensorFlow releases document OPTIMIZE_FOR_SIZE as deprecated and equivalent to tf.lite.Optimize.DEFAULT)
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
keras_tflite = keras_to_tflite_converter.convert()
# with open('./keras_tflite', 'wb') as f:
#     f.write(keras_tflite)
# Save the quantized TFLite model
with open('./quantized_keras_tflite', 'wb') as f:
    f.write(keras_tflite)
# Convert the SavedModel to TFLite
saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('resave')
# Quantize the TFLite model; this shrinks the model with little or no loss of accuracy
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./saved_model_tflite', 'wb') as f:
    f.write(saved_model_tflite)
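To give a feel for why quantization shrinks the model while costing little accuracy, here is a simplified sketch of symmetric 8-bit weight quantization in plain Python. It illustrates the general idea only and is not TFLite's actual scheme. The function names and the weight values are made up.

```python
def quantize_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale factor."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [round(w / scale) for w in weights], scale


def dequantize(quantized, scale):
    """Recover approximate float values from the stored integers."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.003, 1.0]  # made-up float weights
quantized, scale = quantize_int8(weights)
restored = dequantize(quantized, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(quantized)                       # [50, -127, 0, 100]
print(max_error <= scale / 2 + 1e-12)  # True
```

Each weight is stored in one byte instead of four, and the rounding error per weight is bounded by half the scale factor.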
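# As with the earlier concrete_func_tflite example, the model can also be loaded directly from a file path
interpreter = tf.lite.Interpreter(model_path='saved_model_tflite')
interpreter.allocate_tensors()  # Allocate memory for all tensors
# Get the input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Print the detail dictionaries; their contents are needed to set up the input and output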
print(input_details)
print(output_details)
import numpy as np
image_size = (180, 180)
input_shape = input_details[0]['shape']  # Input shape expected by the model
input_data = tf.constant(np.ones(input_shape, dtype=np.float32))
img = keras.preprocessing.image.load_img(
"PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create batch axis
interpreter.set_tensor(input_details[0]['index'],img_array)
# Run inference
interpreter.invoke()
output_results = interpreter.get_tensor(output_details[0]['index'])
score = output_results[0]
print(
"This image is %.2f percent cat and %.2f percent dog."
% (100 * (1 - score), 100 * score)
)
This image is 99.94 percent cat and 0.06 percent dog.
The 32 MB .h5 model shrinks to about 3 MB after conversion to TFLite, with almost no drop in accuracy.
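These sizes are roughly what a back-of-the-envelope estimate predicts from the parameter counts in the model summary. The sketch rests on assumptions that are not stated in the text: that the .h5 checkpoint stores float32 weights together with the Adam optimizer state (two extra float32 values per trainable weight), that the quantized model stores about one byte per weight, and that file format overhead is negligible.

```python
total_params = 2_782_649      # "Total params" from the model summary
trainable_params = 2_773_913  # "Trainable params" from the model summary

MIB = 1024 * 1024
float32_weights = total_params * 4     # 4 bytes per float32 weight
adam_slots = trainable_params * 2 * 4  # assumed: two float32 moments per trainable weight
h5_estimate = (float32_weights + adam_slots) / MIB
int8_estimate = total_params * 1 / MIB  # assumed: 1 byte per quantized weight

print("h5 estimate: %.1f MiB" % h5_estimate)        # h5 estimate: 31.8 MiB
print("quantized estimate: %.1f MiB" % int8_estimate)  # quantized estimate: 2.7 MiB
```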