Keras Fine Tuning(微调)(3)

目录

Keras Fine Tuning(微调)(1)

Keras Fine Tuning(微调)(2)

Keras Fine Tuning(微调)(3)

数据集下载:https://download.csdn.net/download/github_39611196/10940372


接上一篇博客: Keras Fine Tuning(微调)(2) ,本文主要介绍Keras中的fine tuning(微调),通过对西瓜、南瓜、番茄数据集进行分类来进行实例说明。

与上一篇博客冻结不做数据增强不同,本篇博客冻结VGG 16网络最后四层,并对进行数据增强(Data Augmentation)的数据进行训练。

代码:

# 冻结网络除最后4层外的所有层,并对做了数据增强的数据进行训练
from keras.applications import VGG16

# 加载VGG 模型
vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))

# 冻结模型除最后4层外的所有层
for layer in vgg_conv.layers[:-4]:
    layer.trainable = False

# 检查所有层trainable属性的状态
for layer in vgg_conv.layers:
    print(layer, layer.trainable)

from keras import models
from keras import layers
from keras import optimizers

# 创建模型
model = models.Sequential()

# 添加vgg卷积基础模型
model.add(vgg_conv)

# 添加新的层
model.add(layers.Flatten())
model.add(layers.Dense(1024, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))  # softmax激活函数用于分类的时候使用

# summary
model.summary()

# 训练模型,使用imageDataGenerator实现数据增强(Data Augmentation)
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True, fill_mode='nearest')

validation_datagen = ImageDataGenerator(rescale=1./ 255)

train_batchsize = 50
val_batchsize = 10

# 训练数据的数据生成器
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(image_size, image_size), batch_size=train_batchsize, class_mode='categorical')

# 验证数据集的数据生成器
validation_generator = validation_datagen.flow_from_directory(validation_dir, target_size=(image_size, image_size), batch_size=val_batchsize, class_mode='categorical', shuffle=False)

# 编译模型
model.compile(loss='categorical_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4), metrics=['acc'])

# 训练模型
# 注意:对 steps_per_epoch 倍增模式,这是因为使用了数据增强。
history = model.fit_generator(
      train_generator,
      steps_per_epoch=2*train_generator.samples/train_generator.batch_size ,
      epochs=40,
      validation_data=validation_generator,
      validation_steps=validation_generator.samples/validation_generator.batch_size,
      verbose=1)

# 保存模型
model.save('da_last4_layers.h5')

# 显示正确率和loss曲线
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

结果:

Keras Fine Tuning(微调)(3)_第1张图片

Keras Fine Tuning(微调)(3)_第2张图片

你可能感兴趣的:(Keras,Keras)