Keras - 加载预训练模型并冻结网络的层

目录

加载预训练模型:

冻结网络层:

冻结预训练模型中的层

加载所有预训练模型的层


    在解决一个任务时,我会选择加载预训练模型并逐步fine-tune。比如,分类任务中,优异的深度学习网络有很多。ResNet, VGG, Xception等等... 并且这些模型参数已经在imagenet数据集中训练的很好了,可以直接拿过来用。根据自己的任务,训练一下最后的分类层即可得到比较好的结果。此时,就需要“冻结”预训练模型的所有层,即这些层的权重永不会更新。以Xception为例:

加载预训练模型:

from tensorflow.python.keras.applications import Xception

model = Sequential()

model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))

model.add(Dense(NUM_CLASS, activation='softmax'))

include_top = False :  不包含顶层的3个全链接网络

weights : 加载预训练权重

随后,根据自己的分类任务加一层网络即可。

网络具体参数:

model.summary

得到两个网络层,第一层是xception层,第二层为分类层。

由于未冻结任何层,trainable params为:20, 811, 050

Keras - 加载预训练模型并冻结网络的层_第1张图片

 

冻结网络层:

    由于第一层为xception,不想更新xception层的参数,可以加以下代码:

model.layers[0].trainable = False

Keras - 加载预训练模型并冻结网络的层_第2张图片

 

冻结预训练模型中的层

    如果想冻结xception中的部分层,可以如下操作:

from tensorflow.python.keras.applications import Xception

model = Sequential()

model.add(Xception(include_top=False, pooling='avg', weights='imagenet'))

model.add(Dense(NUM_CLASS, activation='softmax'))

for i, layer in enumerate(model.layers[0].layers):
    if  i > 115:
        layer.trainable = True
    else:
        layer.trainable = False
    print(i, layer.name, layer.trainable)

Keras - 加载预训练模型并冻结网络的层_第3张图片

Keras - 加载预训练模型并冻结网络的层_第4张图片

 

加载所有预训练模型的层

若想把xeption的所有层应用在训练自己的数据,并改变分类数。可以如下操作:

model = Sequential()
model.add(Xception(include_top=True,  weights=None, classes=NUM_CLASS))

  • * 如果想指定classes,有两个条件:include_top:True, weights:None。否则无法指定classes

你可能感兴趣的:(keras,python)