Paper link: https://arxiv.org/abs/1610.02357
Xception is Google's further improvement on Inception-v3: it replaces the convolutions inside the Inception modules with separable convolutions, which can be seen as an "extreme" version of the Inception module (see the original paper for details).
1.1 The Separable Convolution module:
The output of the previous layer first goes through an ordinary 1×1 convolution; a 3×3 convolution is then applied to each resulting channel separately, and the results are concatenated and passed to the next layer, as in the sketch below.
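A minimal Keras sketch of this idea, with a placeholder filter count: a 1×1 pointwise convolution followed by a per-channel 3×3 DepthwiseConv2D. (Note that Keras's built-in SeparableConv2D performs the same two operations in the opposite order, depthwise first and then pointwise.)
import tensorflow as tf

def extreme_inception_block(x, filters=128):  # filters is an arbitrary example value
    # 1x1 pointwise convolution mixes information across channels
    x = tf.keras.layers.Conv2D(filters, kernel_size=1, padding='same')(x)
    # 3x3 convolution applied to each channel independently
    x = tf.keras.layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    return x

inputs = tf.keras.Input(shape=(299, 299, 3))
tf.keras.Model(inputs, extreme_inception_block(inputs)).summary()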
1.2 Overall Xception architecture:
The tensorflow.keras.applications module ships with many pretrained models, including MobileNet, InceptionV3, and VGG. We can use the built-in Xception model and only need to change the final fully connected layer to output our number of classes.
The data come from the 200 Bird Species dataset on Kaggle.
The dataset covers 200 bird species. Each class has 100+ training images, while the validation and test sets each contain five images per class; every image is 224x224.
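flow_from_directory, used below, expects one subdirectory per class under each split, so the data should be laid out roughly as follows (the class name here is just an example):
data/100-bird-species/
    train/
        ALBATROSS/
            001.jpg
            ...
    valid/
        ALBATROSS/
            ...
    test/
        ALBATROSS/
            ...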
3.1 Import the required libraries
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models, Model, Sequential
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import tensorflow as tf
import os
3.2 Configuration
Set both the image width and height to 299 (Xception's default input size is 299x299) and the batch_size to 128, train for 10 epochs in total, and define the paths to the training, validation, and test sets.
im_height = 299
im_width = 299
batch_size = 128
epochs = 10
image_path = "./data/100-bird-species/"
train_dir = image_path + "train"
validation_dir = image_path + "valid"
test_dir = image_path + "test"
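Since os is already imported, a quick sanity check that the three directories actually exist can catch path typos before training starts (a minimal sketch):
for d in (train_dir, validation_dir, test_dir):
    assert os.path.isdir(d), f"missing directory: {d}"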
3.3 Data preprocessing
Apply data augmentation to the training images; the validation and test sets are only rescaled to [0, 1].
# data generator with data augmentation
train_image_generator = ImageDataGenerator(rescale=1./255,
                                           rotation_range=40,
                                           width_shift_range=0.2,
                                           height_shift_range=0.2,
                                           zoom_range=0.2,
                                           horizontal_flip=True,
                                           fill_mode='nearest')
validation_image_generator = ImageDataGenerator(rescale=1./255)
test_image_generator = ImageDataGenerator(rescale=1./255)
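To see what the augmentation produces, ImageDataGenerator.random_transform can be applied to a single image array (a sketch; the random array stands in for a real image, and note that rescale is applied separately by the generator, not by random_transform):
import numpy as np
img = np.random.rand(im_height, im_width, 3)   # stand-in for a real image array
augmented = train_image_generator.random_transform(img)
print(augmented.shape)                         # (299, 299, 3)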
3.4 Generate the data
Shuffle the training data but not the validation and test data; all three generators use one-hot ('categorical') encoding.
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           target_size=(im_height, im_width),
                                                           class_mode='categorical')
total_train = train_data_gen.n
val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                              batch_size=batch_size,
                                                              shuffle=False,
                                                              target_size=(im_height, im_width),
                                                              class_mode='categorical')
total_val = val_data_gen.n
test_data_gen = test_image_generator.flow_from_directory(directory=test_dir,
                                                         shuffle=False,
                                                         target_size=(im_height, im_width),
                                                         class_mode='categorical')
total_test = test_data_gen.n
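The generators also expose class_indices, the class-name-to-index mapping, which is handy later for decoding predictions:
print(total_train, total_val, total_test)
# invert the class-name -> index mapping so predictions can be decoded later
idx_to_class = {v: k for k, v in train_data_gen.class_indices.items()}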
3.5 Build the model
Here we use TensorFlow's built-in Xception model: freeze all but the last 32 layers of the pretrained network (i.e. the first 100 of its 132 layers), then append a global average pooling layer and the output layer.
conv_base = tf.keras.applications.xception.Xception(weights='imagenet', include_top=False)
conv_base.trainable = True
# freeze everything except the last 32 layers
# (note: the loop variable must not shadow the imported `layers` module)
for layer in conv_base.layers[:-32]:
    layer.trainable = False
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(200))
model.summary()
As the summary shows, the Xception model has 21,271,280 parameters in total, but only 9,888,144 of them need to be trained.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
xception (Model) (None, None, None, 2048) 20861480
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048) 0
_________________________________________________________________
dense (Dense) (None, 200) 409800
=================================================================
Total params: 21,271,280
Trainable params: 9,888,144
Non-trainable params: 11,383,136
_________________________________________________________________
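A quick way to confirm the freeze did what we intended is to count the trainable layers in the base model (a sketch; with Xception's 132 layers this should report 32 trainable):
n_trainable = sum(1 for layer in conv_base.layers if layer.trainable)
print(len(conv_base.layers), "layers total,", n_trainable, "trainable")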
3.6 Compile the model
We use the Adam optimizer with an initial learning rate of 0.0001 and the categorical cross-entropy loss. Because the model's last layer has no softmax activation, from_logits is set to True.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])
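Since the model outputs raw logits, they need to be pushed through a softmax at prediction time to obtain probabilities (a minimal sketch; image_batch here stands for any batch of preprocessed images):
logits = model.predict(image_batch)       # raw scores, shape (batch, 200)
probs = tf.nn.softmax(logits, axis=-1)    # convert logits to probabilities
top_class = tf.argmax(probs, axis=-1)     # predicted class index per image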
3.7 Train the model
Early_sp: monitors 'val_accuracy'; if it does not improve for five consecutive epochs, training stops and the best weights are restored.
reduce_lr: monitors 'val_loss'; if it does not improve for two epochs, the learning rate decays to 1/10 of its current value.
Early_sp = EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1)
history = model.fit(x=train_data_gen,
                    steps_per_epoch=total_train // batch_size,
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=total_val // batch_size,
                    callbacks=[Early_sp, reduce_lr])
The training log shows that validation accuracy has reached about 97% after 10 epochs.
Train for 214 steps, validate for 7 steps
Epoch 1/10
214/214 [==============================] - 733s 3s/step - loss: 3.1320 - accuracy: 0.5123 - val_loss: 0.9767 - val_accuracy: 0.8460
Epoch 2/10
214/214 [==============================] - 688s 3s/step - loss: 0.8930 - accuracy: 0.8639 - val_loss: 0.4742 - val_accuracy: 0.9353
Epoch 3/10
214/214 [==============================] - 703s 3s/step - loss: 0.4492 - accuracy: 0.9197 - val_loss: 0.3091 - val_accuracy: 0.9554
Epoch 4/10
214/214 [==============================] - 710s 3s/step - loss: 0.2958 - accuracy: 0.9449 - val_loss: 0.2153 - val_accuracy: 0.9587
Epoch 5/10
214/214 [==============================] - 694s 3s/step - loss: 0.2137 - accuracy: 0.9582 - val_loss: 0.1709 - val_accuracy: 0.9688
Epoch 6/10
214/214 [==============================] - 681s 3s/step - loss: 0.1621 - accuracy: 0.9689 - val_loss: 0.1370 - val_accuracy: 0.9732
Epoch 7/10
214/214 [==============================] - 684s 3s/step - loss: 0.1268 - accuracy: 0.9749 - val_loss: 0.1182 - val_accuracy: 0.9665
Epoch 8/10
214/214 [==============================] - 688s 3s/step - loss: 0.1029 - accuracy: 0.9802 - val_loss: 0.1132 - val_accuracy: 0.9777
Epoch 9/10
214/214 [==============================] - 709s 3s/step - loss: 0.0850 - accuracy: 0.9827 - val_loss: 0.1043 - val_accuracy: 0.9699
Epoch 10/10
214/214 [==============================] - 718s 3s/step - loss: 0.0702 - accuracy: 0.9853 - val_loss: 0.1057 - val_accuracy: 0.9732
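Besides restore_best_weights, the best model could also be persisted to disk during training with a ModelCheckpoint callback (a sketch; the file name is arbitrary):
from tensorflow.keras.callbacks import ModelCheckpoint
ckpt = ModelCheckpoint("xception_birds.h5",   # arbitrary file name
                       monitor='val_accuracy',
                       save_best_only=True)
# then train with callbacks=[Early_sp, reduce_lr, ckpt]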
3.8 Plot the loss and accuracy curves
# plot loss and accuracy curves
history_dict = history.history
train_loss = history_dict["loss"]
train_accuracy = history_dict["accuracy"]
val_loss = history_dict["val_loss"]
val_accuracy = history_dict["val_accuracy"]
# use the number of epochs actually run, in case EarlyStopping fired
epochs_range = range(len(train_loss))
# figure 1
plt.figure()
plt.plot(epochs_range, train_loss, label='train_loss')
plt.plot(epochs_range, val_loss, label='val_loss')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')
# figure 2
plt.figure()
plt.plot(epochs_range, train_accuracy, label='train_accuracy')
plt.plot(epochs_range, val_accuracy, label='val_accuracy')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.show()
3.9 Evaluate on the test set
scores = model.evaluate(test_data_gen, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
The model reaches 98.9% accuracy on the test set with a loss of 0.0658, which is a solid result.
Found 1000 images belonging to 200 classes.
32/32 [==============================] - 6s 190ms/step - loss: 0.0658 - accuracy: 0.9890
Test loss: 0.06583867644076236
Test accuracy: 0.989
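Finally, the trained model can classify a single new image (a sketch; the file path is a placeholder, and idx_to_class is the index-to-name mapping built from train_data_gen.class_indices earlier):
from tensorflow.keras.preprocessing import image
img = image.load_img("some_bird.jpg", target_size=(im_height, im_width))  # placeholder path
x = image.img_to_array(img) / 255.0   # same rescaling as the generators
x = tf.expand_dims(x, 0)              # add the batch dimension
pred = int(tf.argmax(model.predict(x), axis=-1)[0])
print(idx_to_class[pred])             # decode the index back to a class name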