模型的保存和加载可以直接通过Model
类的save_weights
和load_weights
实现。默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件。
model.save_weights('weights', save_format='h5')
加载时默认为根据网络的拓扑结构进行加载,这适用于不对网络进行更改,直接进行测试的情况。但如果只希望加载部分权重,可以更改为根据变量名进行加载。
model.load_weights('weights', by_name=True)
有个很坑的点是,加载checkpoint格式时源码中似乎没有实现by_name
,所以尽管设置了by_name=True
,他仍然会按照拓扑结构加载,然后报错提示部分变量不匹配,所以还是尽量都存成h5文件。
有了保存和读取模型的方法后,就可以在大型数据库上先进行预训练,然后将权重迁移到小数据库或其他任务上。举例i来说,实际应用场景如VGG16先在ImageNet上进行训练,再将除最后一层全连接以外的参数迁移到SSD完成目标检测任务。
实现时就涉及到两个问题:部分网络层不同(不同的分类任务最后一个全连接层的输出维度不同)和调用网络时的输出不同(目标检测任务需要提取网络的中间层的特征图输出),我们可以通过继承Model
类来解决上述问题。
继承Model
类需要实现两个函数,__init__()
与call()
,下面以ResNet为例。
class ResNet(models.Model):
def __init__(self, layer_num, **kwargs):
super(ResNet, self).__init__(**kwargs)
if block_type[layer_num] == 'basic block':
self.block = BasicBlock
else:
self.block = BottleneckBlock
self.conv0 = Conv2D(64, (7, 7), strides=(2, 2), name='conv0', padding='same', use_bias=False)
self.block_collector = []
for layer_index, (b, f) in enumerate(zip(block_num[layer_num], filter_num), start=1):
if layer_index == 1:
if block_type[layer_num] == 'basic block':
self.block_collector.append(self.block(f, name='conv1_0'))
else:
self.block_collector.append(self.block(f, projection=True, name='conv1_0'))
else:
self.block_collector.append(self.block(f, strides=(2, 2), name='conv{}_0'.format(layer_index)))
for block_index in range(1, b):
self.block_collector.append(self.block(f, name='conv{}_{}'.format(layer_index, block_index)))
self.bn = BatchNormalization(name='bn', momentum=0.9, epsilon=1e-5)
self.global_average_pooling = GlobalAvgPool2D()
self.fc = Dense(1000, name='fully_connected', activation='softmax', use_bias=False)
def call(self, inputs, training):
net = self.conv0(inputs)
print('input', inputs.shape)
print('conv0', net.shape)
net = tf.nn.max_pool2d(net, ksize=(3, 3), strides=(2, 2), padding='SAME')
print('max-pooling', net.shape)
for block in self.block_collector:
net = block(net, training)
print(block.name, net.shape)
net = self.bn(net, training)
net = tf.nn.relu(net)
net = self.global_average_pooling(net)
print('global average-pooling', net.shape)
net = self.fc(net)
print('fully connected', net.shape)
return net
__init__
中实例化网络所需的各个层,call
中定义网络的运算。在迁移到其他任务时,改写call
即可。
实例化各个层时尽量自定义name
,因为变量名是网络层的名字和变量名本身共同决定的,举例来说最后的全连接层中权重名为fully_connected/kernel:0
,自定义各个层的名称能保证采用by_name
方式加载模型不会出现问题。在只迁移部分权重时只需要设定model.load_weights()
中的by_name
参数即可。
ResNet的完整代码可以在我的github找到
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/ResNet.py