- 导入所需的包
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import tensorflow as tf
from tensorflow import keras
print("TensorFlow version is ", tf.__version__)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
- 下载数据
# Download the filtered cats-vs-dogs archive (Keras reuses a previously
# downloaded copy if one is cached) and extract it.
zip_file = tf.keras.utils.get_file(
    fname="cats_and_dogs_filtered.zip",
    origin="https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip",
    extract=True,
)
# The dataset root is taken to be the archive path with its extension removed.
# NOTE(review): this relies on get_file returning the archive path and on the
# extracted folder sitting next to it -- confirm for the TensorFlow version in use.
base_dir, _ = os.path.splitext(zip_file)
- 构建batch数据生成器
# Locate the train/validation splits and their per-class sub-folders.
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

# Report how many directory entries each class folder contains.
for caption, folder in (
        ('Total training cat images:', train_cats_dir),
        ('Total training dog images:', train_dogs_dir),
        ('Total validation cat images:', validation_cats_dir),
        ('Total validation dog images:', validation_dogs_dir)):
    print(caption, len(os.listdir(folder)))

image_size = 160  # All images will be resized to 160x160
batch_size = 32

# Both generators only multiply pixel values by 1/255; no augmentation is
# configured here.
rescale_factor = 1. / 255
train_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=rescale_factor)
validation_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=rescale_factor)

# Options shared by both directory iterators. Labels are binary because the
# model is trained with binary_crossentropy loss.
flow_options = dict(
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='binary',
)

# Stream images from disk in batches of `batch_size`.
train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
validation_generator = validation_datagen.flow_from_directory(validation_dir, **flow_options)
- 构建预训练的base模型
# Input shape fed to the network: image_size x image_size pixels, 3 channels.
IMG_SHAPE = (image_size, image_size, 3)

# Pre-trained MobileNetV2 feature extractor: ImageNet weights, built without
# its top (classification) layers so a new classifier can be attached.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE,
    include_top=False,
    weights='imagenet',
)
- 冻结base模型的参数
# Freeze the pre-trained base so its weights are not updated while the new
# classifier on top of it is trained.
base_model.trainable = False
- 查看模型结构
# Print the layer-by-layer architecture of the (frozen) base model.
base_model.summary()
- 在base模型上增加模型分类器
# Classifier head: pool the base model's feature maps, then a single sigmoid
# unit for the binary cat/dog output.
feature_pooling = keras.layers.GlobalAveragePooling2D()
binary_classifier = keras.layers.Dense(1, activation='sigmoid')

# Full model = pre-trained base followed by the new head.
model = tf.keras.Sequential([base_model, feature_pooling, binary_classifier])
- 编译模型
# Compile for binary classification. `learning_rate` is the supported keyword;
# `lr` is a deprecated alias that newer Keras optimizers reject.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
- 查看模型结构和参数(可训练参数的数量)
# Summarize the assembled model (base + classifier head) rather than the base
# alone, so the trainable vs. non-trainable parameter counts reflect the
# frozen base.
model.summary()
- 训练模型
# First training phase: only the classifier head is trainable.
epochs = 10

# Number of full batches per pass over each split (any partial final batch is
# not counted).
steps_per_epoch = train_generator.n // batch_size
validation_steps = validation_generator.n // batch_size

training_options = dict(
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    workers=4,
    validation_data=validation_generator,
    validation_steps=validation_steps,
)
history = model.fit_generator(train_generator, **training_options)
- 学习曲线
# tf.keras records the 'accuracy' metric under the key 'acc' in older releases
# and 'accuracy' in newer ones; use whichever key this run produced instead of
# failing with a KeyError.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'
acc = history.history[acc_key]
val_acc = history.history['val_' + acc_key]
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))

# Top panel: accuracy per epoch, with the y-axis capped at 1.
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

# Bottom panel: loss per epoch, with the y-axis starting at 0.
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, max(plt.ylim())])
plt.title('Training and Validation Loss')
plt.show()
- Fine tuning: 将base模型较高层次layer部分解冻进行训练
# Unfreeze the base model as a whole, then re-freeze its lower layers so that
# only the layers from index `fine_tune_at` onwards are fine-tuned.
base_model.trainable = True

# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
- 重新编译
# Recompile so the changed `trainable` flags take effect, using a smaller
# learning rate for fine-tuning. `learning_rate` is the supported keyword;
# `lr` is a deprecated alias that newer Keras optimizers reject.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=2e-5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
- 再次查看模型
# Inspect the model again after unfreezing part of the base.
model.summary()

# A bare expression is discarded when this runs as a script, so print the
# count explicitly.
print("Number of trainable variables: ", len(model.trainable_variables))
- 继续训练模型
# Second training phase (fine-tuning): same data, step counts and epoch count
# as the first run, now with part of the base model trainable.
fine_tune_options = dict(
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    workers=4,
    validation_data=validation_generator,
    validation_steps=validation_steps,
)
history_fine = model.fit_generator(train_generator, **fine_tune_options)
- 将两次训练的学习曲线连起来作图
# tf.keras records the 'accuracy' metric under the key 'acc' in older releases
# and 'accuracy' in newer ones; use whichever key this run produced.
acc_key = 'acc' if 'acc' in history_fine.history else 'accuracy'
val_acc_key = 'val_' + acc_key

# Join the curves of both training runs. Plain concatenation builds new lists,
# so the lists stored inside `history.history` are not extended in place and
# re-running this cell does not append the fine-tuning curve twice.
acc = history.history[acc_key] + history_fine.history[acc_key]
val_acc = history.history[val_acc_key] + history_fine.history[val_acc_key]
loss = history.history['loss'] + history_fine.history['loss']
val_loss = history.history['val_loss'] + history_fine.history['val_loss']

plt.figure(figsize=(8, 8))

# Top panel: accuracy across both runs. The vertical line at x = epochs - 1
# marks the last epoch of the first run, i.e. where fine-tuning starts.
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.9, 1])
plt.plot([epochs - 1, epochs - 1], plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Bottom panel: loss across both runs, with the same marker.
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 0.2])
plt.plot([epochs - 1, epochs - 1], plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
【参考资料】
Transfer Learning Using Pretrained ConvNets