# 一、导入相关库 (1. Import the required libraries)
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Conv2D,MaxPool2D,Flatten
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.callbacks import EarlyStopping
import json
from tensorflow.keras.applications.vgg16 import VGG16
# 定义超参数 (Define hyperparameters)
# Training hyperparameters used throughout the rest of the script.
num_classes, batch_size, epochs, image_size = (
    2,    # number of output classes (size of the final softmax layer)
    8,    # images per generator batch
    10,   # maximum number of training epochs (early stopping may end sooner)
    224,  # images are resized to image_size x image_size pixels
)
# 数据增强 (Data augmentation)
# Training images are randomly rotated, shifted, sheared, zoomed, flipped
# horizontally and brightness-jittered. Pixel values are multiplied by 1/255
# (presumably mapping 8-bit images into [0, 1] -- confirm input format).
train_datagen = ImageDataGenerator(
    rescale=1 / 255,
    rotation_range=20,          # degrees
    width_shift_range=0.1,      # fraction of the image width
    height_shift_range=0.1,     # fraction of the image height
    shear_range=10,             # shear angle in degrees
    zoom_range=0.1,
    horizontal_flip=True,
    brightness_range=(0.7, 1.3),
    fill_mode='nearest',        # how pixels exposed by a transform are filled
)
# Test images receive only the same rescaling, no augmentation.
test_datagen = ImageDataGenerator(rescale=1 / 255)
# 构造生成器 (Build the data generators)
# Stream batches from the class sub-directories of data/train and data/test,
# resizing every image to a square of image_size pixels. With the default
# class_mode ('categorical') labels come out one-hot, matching the
# categorical_crossentropy loss used when the model is compiled.
train_generator = train_datagen.flow_from_directory(
    directory='data/train',
    batch_size=batch_size,
    target_size=(image_size,) * 2,
)
test_generator = test_datagen.flow_from_directory(
    directory='data/test',
    batch_size=batch_size,
    target_size=(image_size,) * 2,
)
# 做标签 (Build the index-to-label mapping)
# Invert Keras' {class_name: index} mapping into {index: class_name}, so that a
# predicted class index can be turned back into a readable class name, and
# save it to label_cat_dog.json.
label = {index: name for name, index in train_generator.class_indices.items()}
with open('label_cat_dog.json', 'w', encoding='utf-8') as f:
    # JSON object keys are always strings, so the integer indices are written
    # as "0", "1", ...; a reader must convert them back with int().
    json.dump(label, f)
# Show the original mapping (the bare expression here only displayed in a
# notebook; in a script it had no effect).
print(train_generator.class_indices)
# 查看模型 (Inspect the VGG16 base model)
# ImageNet-pretrained VGG16 convolutional base with its classifier head
# removed; the input shape matches the generators' target_size with 3 channels.
vgg16 = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(image_size, image_size, 3),
)
vgg16.summary()
# 改变VGG16的全连接层 (Replace VGG16's fully connected classifier layers)
# New classification head: flatten the VGG16 feature maps, then one 256-unit
# ReLU layer, 50% dropout, and a softmax over num_classes.
top_model = Sequential([
    Flatten(input_shape=vgg16.output_shape[1:]),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])
# Full model = convolutional base followed by the new head.
# NOTE(review): vgg16 is not frozen anywhere in this script, so its weights are
# fine-tuned together with the head -- confirm this is intended.
model = Sequential([vgg16, top_model])
model.summary()
# 构造回调函数 (Build callbacks)
# Stop training once validation accuracy has not improved for 3 epochs.
early_stoping = EarlyStopping(
    verbose=1,
    patience=3,
    monitor='val_accuracy',
)
# 调节学习率 (Learning-rate schedule)
def adjust_learning_rate(epoch):
    """Return the learning rate to use for the given epoch.

    Used as the schedule of a ``LearningRateScheduler``: epoch indices 0-10
    use 1e-5 and every later epoch uses 1e-6.

    Args:
        epoch: Zero-based epoch index supplied by the scheduler.

    Returns:
        The learning rate as a float.
    """
    # With epochs = 10 only indices 0-9 run, so the 1e-6 branch is reached
    # only when training is configured for 12 or more epochs.
    return 1e-5 if epoch <= 10 else 1e-6
# 训练模型 (Train the model)
# Adam with a small learning rate for fine-tuning. The documented argument is
# `learning_rate`; the old `lr` alias is deprecated and not supported by
# newer Keras releases.
adam = Adam(learning_rate=1e-5)
# Pass every callback to fit() so the learning-rate schedule actually runs
# alongside early stopping.
callbacks = [
    LearningRateScheduler(adjust_learning_rate),
    early_stoping,
]
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(
    x=train_generator,
    epochs=epochs,
    validation_data=test_generator,
    callbacks=callbacks,
)
# 准确率曲线 (Accuracy curves)
# Plot training vs. validation accuracy per epoch. EarlyStopping can end
# training before `epochs` epochs, so the x axis comes from the number of
# epochs that actually ran; using `epochs` would make plt.plot fail on a
# length mismatch.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs_run = np.arange(len(train_acc))
plt.plot(epochs_run, train_acc, c='b', label='train_accuracy')
plt.plot(epochs_run, val_acc, c='y', label='val_accuracy')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.show()
# Save the trained model in TensorFlow SavedModel format.
# NOTE(review): Keras 3 no longer accepts `save_format`; confirm the target
# TF/Keras version before upgrading.
model.save('cat_dog_model', save_format='tf')