【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移

模型的保存和加载可以直接通过Model类的save_weightsload_weights实现。默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件。

model.save_weights('weights', save_format='h5')

加载时默认为根据网络的拓扑结构进行加载,这适用于不对网络进行更改,直接进行测试的情况。但如果只希望加载部分权重,可以更改为根据变量名进行加载。

model.load_weights('weights', by_name=True)

有个很坑的点是,加载checkpoint格式时源码中似乎没有实现by_name,所以尽管设置了by_name=True,他仍然会按照拓扑结构加载,然后报错提示部分变量不匹配,所以还是尽量都存成h5文件。

有了保存和读取模型的方法后,就可以在大型数据库上先进行预训练,然后将权重迁移到小数据库或其他任务上。举例i来说,实际应用场景如VGG16先在ImageNet上进行训练,再将除最后一层全连接以外的参数迁移到SSD完成目标检测任务。

实现时就涉及到两个问题:部分网络层不同(不同的分类任务最后一个全连接层的输出维度不同)和调用网络时的输出不同(目标检测任务需要提取网络的中间层的特征图输出),我们可以通过继承Model类来解决上述问题。

继承Model类需要实现两个函数,__init__()call(),下面以ResNet为例。

class ResNet(models.Model):
    def __init__(self, layer_num, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        if block_type[layer_num] == 'basic block':
            self.block = BasicBlock
        else:
            self.block = BottleneckBlock

        self.conv0 = Conv2D(64, (7, 7), strides=(2, 2), name='conv0', padding='same', use_bias=False)

        self.block_collector = []
        for layer_index, (b, f) in enumerate(zip(block_num[layer_num], filter_num), start=1):
            if layer_index == 1:
                if block_type[layer_num] == 'basic block':
                    self.block_collector.append(self.block(f, name='conv1_0'))
                else:
                    self.block_collector.append(self.block(f, projection=True, name='conv1_0'))
            else:
                self.block_collector.append(self.block(f, strides=(2, 2), name='conv{}_0'.format(layer_index)))

            for block_index in range(1, b):
                self.block_collector.append(self.block(f, name='conv{}_{}'.format(layer_index, block_index)))

        self.bn = BatchNormalization(name='bn', momentum=0.9, epsilon=1e-5)
        self.global_average_pooling = GlobalAvgPool2D()
        self.fc = Dense(1000, name='fully_connected', activation='softmax', use_bias=False)

    def call(self, inputs, training):
        net = self.conv0(inputs)
        print('input', inputs.shape)
        print('conv0', net.shape)
        net = tf.nn.max_pool2d(net, ksize=(3, 3), strides=(2, 2), padding='SAME')
        print('max-pooling', net.shape)

        for block in self.block_collector:
            net = block(net, training)
            print(block.name, net.shape)
        net = self.bn(net, training)
        net = tf.nn.relu(net)

        net = self.global_average_pooling(net)
        print('global average-pooling', net.shape)
        net = self.fc(net)
        print('fully connected', net.shape)
        return net

__init__中实例化网络所需的各个层,call中定义网络的运算。在迁移到其他任务时,改写call即可。

实例化各个层时尽量自定义name,因为变量名是网络层的名字和变量名本身共同决定的,举例来说最后的全连接层中权重名为fully_connected/kernel:0,自定义各个层的名称能保证采用by_name方式加载模型不会出现问题。在只迁移部分权重时只需要设定model.load_weights()中的by_name参数即可。

ResNet的完整代码可以在我的github找到
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/ResNet.py

你可能感兴趣的:(tensorflow)