【深度学习笔记整理-5.3】使用预训练模型

1.为什么神经网络可以使用预训练模型?

所谓使用预训练的模型,就是使用别人已经训练好的模型及参数,虽然利用到的数据集不同,但是前几层几乎做的事情是相同的,就拿CNN举例,前几层都是在找一些简单的线条,对于足够大的网络和足够多的数据,前几层找到的线条是普适的,这也就是为什么我们可以使用预训练模型。

2.如何使用预训练模型?

使用预训练模型,我们就需要固定住前几层的参数,而后几层的参数(尤其是全连接的部分)需要我们根据手头的任务进行训练得出,以CNN为例,我们在使用这个预训练模型时,就要固定住(frozen)其卷积基,训练全连接的部分,也就是所谓分类器的部分。

【深度学习笔记整理-5.3】使用预训练模型_第1张图片

以keras内置的模型VGG16为例(此模型是由两个连续的卷积层加一个池化层重复两次和三个连续的卷积层加一个池化层重复三次构成),其引入的代码如下

from keras.applications import VGG16

conv_base = VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))

其中weights指定其参数由imagnet这个数据集训练获得,inlude_top指明是否引入全连接部分,input_shape用于修改为自己输入图片的大小,此参数可以不填。

from keras import models,layers

network = models.Squential()

network.add(conv_base)

network.add(layers.Flatten())

network.add(layers.Dense(512,activation='relu'))

network.add(layers.Dropout(0.5))

network.add(layers.Dense(1,activation='sigmoid'))

conv_base.trainable = False

构建好神经网络后,训练方式与前面类似

from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers 
train_datagen = ImageDataGenerator(       
rescale=1./255,       
rotation_range=40,       
width_shift_range=0.2,      
height_shift_range=0.2,       
shear_range=0.2,       
zoom_range=0.2,       
horizontal_flip=True,       
fill_mode='nearest') 
test_datagen = ImageDataGenerator(rescale=1./255)   
train_generator = train_datagen.flow_from_directory(         
train_dir,           
target_size=(150, 150),           
batch_size=20,         
class_mode='binary')   
validation_generator = test_datagen.flow_from_directory(         
validation_dir,         
target_size=(150, 150),         
batch_size=20,         
class_mode='binary') 
model.compile(loss='binary_crossentropy',               
optimizer=optimizers.RMSprop(lr=2e-5),               
metrics=['acc']) 
history = model.fit_generator(       
train_generator,       
steps_per_epoch=100,       
epochs=30,       
validation_data=validation_generator,       
validation_steps=50)

3.Fine-tuning

使用fine-tuning时,我们要力求不对模型做太大的改变,并且即使改变也尽量只改变模型的最后几层,因为这几层更贴近我们自己的任务,此外,我们需要先将自己的全连接层部分的权重参数训练好后,采用较小的学习率再进行微调,这样才可以避免由于预测类标与实际类标相差过大而导致模型后几层权重参数也改变过大的情况发生。

你可能感兴趣的:(深度学习,机器学习,python)