TensorFlow 2.1 Transfer Learning with VGG

Environment

TensorFlow 2.1
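
To double-check the environment, you can print the installed TensorFlow version and the visible GPUs (a quick sanity check, not part of the training script):

import tensorflow as tf

print(tf.__version__)  # expected: 2.1.x
print(tf.config.experimental.list_physical_devices('GPU'))  # empty list means CPU-only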

Preparation

Download the VGG weights: Keras can fetch them automatically the first time, or you can download them offline in advance.
Download the images to train on. The dataset contains five types of flowers ('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'):

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
Then extract the archive into a flower_photos directory under your project.
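
If you prefer to fetch and unpack the dataset from Python instead of downloading it manually, a sketch like the following works; copying into ./flower_photos is an assumption made to match the path used by the training script:

import pathlib
import shutil
import tensorflow as tf

# Downloads and extracts flower_photos.tgz into the Keras cache (~/.keras/datasets)
data_dir = tf.keras.utils.get_file(
    'flower_photos',
    origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Copy the extracted folder into the project so the script finds ./flower_photos
if not pathlib.Path('./flower_photos').exists():
    shutil.copytree(data_dir, './flower_photos')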

Brief Overview

In transfer learning based on VGG, the VGG weights themselves are not trained again, since they are already pretrained.
We only strip off VGG's fully connected layers and attach our own fully connected head; a short training run of this new head is all that is needed.
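
The same idea as a minimal sketch using the functional API (an alternative to the Sequential construction in the full script below; the 5-unit output matches the five flower classes):

import tensorflow as tf
from tensorflow.keras.applications import VGG16

base = VGG16(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base.trainable = False  # freeze every pretrained convolutional layer

inputs = tf.keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)  # run the frozen base in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(5, activation='softmax')(x)  # one unit per flower class
model = tf.keras.Model(inputs, outputs)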

Results on the training and validation sets

Validation accuracy is around 80%:

87/290 [============================>.] - ETA: 0s - loss: 0.4844 - categorical_accuracy: 0.8358
288/290 [============================>.] - ETA: 0s - loss: 0.4836 - categorical_accuracy: 0.8364
289/290 [============================>.] - ETA: 0s - loss: 0.4831 - categorical_accuracy: 0.8366
save_weight 36 0.5442695867802415

290/290 [==============================] - 35s 119ms/step - loss: 0.4835 - categorical_accuracy: 0.8365 - val_loss: 0.5443 - val_categorical_accuracy: 0.8071
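
To reproduce these numbers after training, reload the best checkpoint and evaluate it on the validation generator (a sketch that reuses new_model, weight_file and validation_generator from the full script below):

new_model.load_weights(weight_file)  # best weights saved by the callback during training
val_loss, val_acc = new_model.evaluate(validation_generator, steps=70)
print('val_loss:', val_loss, 'val_categorical_accuracy:', val_acc)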

Full training code

from tensorflow.keras.applications import VGG16
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os

# Load VGG16 pretrained on ImageNet (the default weights) without its fully connected top layers
vgg16 = VGG16(input_shape=(224, 224, 3), include_top=False)

best_model = vgg16

# Freeze the whole pretrained convolutional base so its weights are not updated during training
best_model.trainable = False

# Wrap the frozen base in a new Sequential model and stack our own classification head on top
new_model = keras.Sequential([best_model])

# New head: global average pooling followed by a 5-way softmax (one unit per flower class)
global_average_layer = keras.layers.GlobalAveragePooling2D()
new_output = keras.layers.Dense(5, activation=tf.nn.softmax,
                                kernel_initializer=tf.initializers.Constant(0.001))
new_model.add(global_average_layer)
new_model.add(new_output)

new_model.compile(optimizer=keras.optimizers.Adam(),
                   loss=keras.losses.categorical_crossentropy,
                   # metrics=['accuracy'])
                   metrics=[keras.metrics.categorical_accuracy])

# In the summary, only the new head's parameters should appear as trainable
new_model.summary()


# daisy, dandelion, roses, sunflowers, tulips -- class indices as assigned by flow_from_directory (alphabetical order)
label_names={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
label_key=['daisy','dandelion','roses','sunflowers','tulips']

train_datagen = image.ImageDataGenerator(
    rescale=1 / 255,
    rotation_range=40,       # degrees (0-180): range for random rotations
    width_shift_range=0.2,   # fraction of total width for random horizontal shifts
    height_shift_range=0.2,  # fraction of total height for random vertical shifts
    shear_range=0.2,         # shear intensity (angle) for random shear transforms
    zoom_range=0.2,          # range for random zoom
    horizontal_flip=True,    # randomly flip half of the images horizontally
    validation_split=0.2,    # reserve 20% of the images for validation
    fill_mode='nearest'      # how newly created pixels are filled in
)

IMG_SIZE = 224
BATCH_SIZE = 32
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
pic_folder = './flower_photos'

train_generator = train_datagen.flow_from_directory(
    directory=pic_folder,
    target_size=IMG_SHAPE[:-1],
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    subset='training',
    shuffle=True)

# Note: the same generator (with validation_split) is reused, so the validation subset also gets the random augmentations
validation_generator = train_datagen.flow_from_directory(
    directory=pic_folder,
    target_size=IMG_SHAPE[:-1],
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    subset='validation',
    shuffle=True)

current_max_loss = 9999  # best (lowest) validation loss seen so far
weight_file = './weightsf/model.h5'
os.makedirs(os.path.dirname(weight_file), exist_ok=True)  # make sure the weights directory exists

# Resume from previously saved weights if a checkpoint already exists
if os.path.isfile(weight_file):
    print('load weight')
    new_model.load_weights(weight_file)

# Save the weights whenever the validation loss improves
def save_weight(epoch, logs):
    global current_max_loss
    if logs['val_loss'] is not None and logs['val_loss'] < current_max_loss:
        current_max_loss = logs['val_loss']
        print('save_weight', epoch, current_max_loss)
        new_model.save_weights(weight_file)

batch_print_callback = keras.callbacks.LambdaCallback(
    on_epoch_end=save_weight
)

callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=4, monitor='val_loss'),
    batch_print_callback,
    # keras.callbacks.ModelCheckpoint('./weights/model.h5', save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir='logsf')
]

# In TF 2.x, Model.fit accepts generators directly (fit_generator is deprecated)
history = new_model.fit(train_generator, steps_per_epoch=290, epochs=40, callbacks=callbacks,
                        validation_data=validation_generator, validation_steps=70)
print(history.history)

def show_result(history):
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.plot(history.history['categorical_accuracy'])
    plt.plot(history.history['val_categorical_accuracy'])
    plt.legend(['loss', 'val_loss', 'categorical_accuracy', 'val_categorical_accuracy'],
               loc='upper left')
    plt.show()
    print(history)

show_result(history)
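
Once trained, the model can classify a single image. A minimal sketch; my_flower.jpg is a placeholder path and predict_flower is a hypothetical helper, not part of the original script:

import numpy as np

def predict_flower(img_path):
    # Preprocess one image the same way the generators do: resize to 224x224 and scale to [0, 1]
    img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
    x = image.img_to_array(img) / 255.0
    x = np.expand_dims(x, axis=0)  # add the batch dimension: (1, 224, 224, 3)
    probs = new_model.predict(x)[0]
    return label_key[int(np.argmax(probs))], float(np.max(probs))

print(predict_flower('./my_flower.jpg'))  # returns (predicted_label, confidence); the path is a placeholder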
