# 最后一节的作业是水果分类任务,共 6 个类别。请运用之前学过的知识,补全下面代码段中的空缺。
# 加载在 ImageNet 上预训练的基础模型(VGG16)
from tensorflow import keras
# VGG16 convolutional base initialised with ImageNet weights.
# include_top=False drops VGG16's own classifier so that a new 6-class head
# can be stacked on top of the feature maps for this task.
base_model = keras.applications.VGG16(
    include_top=False,
    input_shape=(224, 224, 3),
    weights="imagenet",
)
# 冻结基础模型
# Lock the pretrained weights: with the base frozen, only the layers added
# on top of it are updated by the optimizer during training.
base_model.trainable = False
# 在基础模型之上添加新的层
# Assemble the full classifier:
#   224x224 RGB image -> frozen VGG16 feature maps
#   -> global average pooling (one value per feature map)
#   -> 6-way softmax over the fruit classes.
inputs = keras.Input(shape=(224, 224, 3))
# training=False calls the base model in inference mode.
x = keras.layers.GlobalAveragePooling2D()(base_model(inputs, training=False))
outputs = keras.layers.Dense(units=6, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
# 编译模型
# Labels arrive one-hot encoded (class_mode="categorical" in the data
# iterators), which is what categorical cross-entropy expects.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# 数据增强
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Image preprocessing / augmentation pipelines.
#
# Training images are randomly rotated, zoomed, shifted and horizontally
# flipped so the classifier sees more varied examples of each fruit.
#
# featurewise_center is intentionally NOT used: it requires calling
# datagen.fit(sample_images) to compute the dataset mean first, and this
# script never does that, so the option had no effect beyond a warning.
# samplewise_center (subtract each image's own mean) needs no fitting.
datagen_train = ImageDataGenerator(
    samplewise_center=True,  # set each sample mean to 0
    rotation_range=10,  # randomly rotate images by up to 10 degrees
    zoom_range=0.1,  # randomly zoom images by up to 10%
    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
    horizontal_flip=True,  # randomly flip images left-right
    vertical_flip=False,  # fruit images are not flipped upside down
)

# Validation images must NOT be augmented: they should be identical every
# epoch so validation accuracy measures performance on unaltered data.
# Only the same per-image normalization as training is applied.
datagen_valid = ImageDataGenerator(samplewise_center=True)
# 加载数据集
# Both splits are read from disk with identical settings: images resized to
# 224x224 RGB (matching the model input) and labels one-hot encoded
# ("categorical") to match the categorical cross-entropy loss.
_flow_kwargs = dict(
    target_size=(224, 224),
    color_mode="rgb",
    class_mode="categorical",
)

# Iterator over the training split.
train_it = datagen_train.flow_from_directory("data/fruits/train", **_flow_kwargs)
# Iterator over the validation split.
valid_it = datagen_valid.flow_from_directory("data/fruits/valid", **_flow_kwargs)
# 训练模型
# 现在开始训练模型!将训练数据集和验证数据集传递给 fit 函数,并设置训练轮数(epochs)。
# Train the new classification head (the VGG16 base is frozen) and evaluate
# on the validation split at the end of every epoch.
#
# Step counts must be integers. len() of a directory iterator is the number
# of batches in one pass over the data, rounded up so the final partial
# batch is included. The previous samples / batch_size expression produced
# a float.
model.fit(
    train_it,
    validation_data=valid_it,
    steps_per_epoch=len(train_it),
    validation_steps=len(valid_it),
    epochs=10,
)