这里提供一个简单的利用tensorflow训练的例子。
调用keras的mobilenetv2 api,去掉其顶层,使用其特征提取部分,然后重新设定顶层部分。选择使用预训练权重可以降低我们的训练难度。
有一个比较细节的地方是,这里将归一化直接集成在网络上,这样的好处是:一方面不需要在art上进行设定,另一方面,在制作数据集时也不用进行归一化的预处理(如果忘了归一化会导致模型不收敛,预训练权重对归一化也会有要求,这样相当于减少出问题时需要排错的事项)。
def Mobilenet_v2(input_size,weights,Dropout_rate,Trainable,alpha = 0.35):
base_model = keras.applications.MobileNetV2(
input_shape=(input_size, input_size, 3),
alpha=alpha,
weights=weights,
include_top=False
)
inputs = keras.Input(shape=(input_size, input_size,3))
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)
x = base_model(x, training=False)
if Trainable:
base_model.trainable =True
else:
base_model.trainable = False
print("特征层已冻结")
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(Dropout_rate, name='Dropout')(x)
outputs = keras.layers.Dense(15, activation='softmax')(x)
model = keras.Model(inputs, outputs)
return model
这里的数据集直接使用的是官方数据集,没有独立的验证集,直接从训练集中划分。
增强使用的是tensorflow自带的,这里仅进行了旋转,缩放等增强。
这样的增强显然是不够的,有关增强的方法会在后边单独讲。
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_root = "./train/"
train_generator = ImageDataGenerator(rotation_range=360,
zoom_range =0.2,
horizontal_flip = True,
validation_split =0.2
)
train_dataset = train_generator.flow_from_directory(batch_size=batch_size,
directory=train_root,
shuffle=True,
target_size=(input_size,input_size),
subset='training')
valid_dataset = train_generator.flow_from_directory(batch_size=batch_size,
directory=train_root,
shuffle=True,
target_size=(input_size,input_size),
subset='validation')
print(train_dataset.class_indices)
需要注意,数据集的目录结构需如下所示,不要像逐飞那样搞分大类小类的二级目录:
在这里我设置了3个回调,其中学习降低回调是比较重要的。一开始的学习率在训练到后期可能是不太合适的,会引起震荡,可以通过降低学习率来避免这种情况,通常通过降低学习率会使得模型的精度有少量提升。
你也可以使用LearningRateScheduler回调对学习率进行动态的调整。这也是一个常用的方法。
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=10,verbose=1)
early_stop =keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10,verbose=1)
save_weights = keras.callbacks.ModelCheckpoint(save_path + "/model_{epoch:02d}_{val_accuracy:.4f}.h5",
save_best_only=True, monitor='val_accuracy')
在学习深度学习框架时,看官方文档是非常有用的,这里列举两个经常会用到的,或者说,看这两个就够了。
Keras API reference
tensorflow
其中也提供了一些入门的教程,跟着走一遍会对框架有一定的了解。