tensorflow加载resnet v2 chekpoint方法

版本:
python = 3.6
tensorflow-gpu = 1.11

代码及模型下载地址:https://github.com/tensorflow/models/tree/master/research/slim

问题:在使用tf.train.Saver(tf.global_variables())加载预训练模型时,出现加载的chekpoint中的graph key和代码构建的网络模型不对应的情况,显示缺失一个biases。错误如下:

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files resnet_v2_152.ckpt

问题原因:使用了错误的scope

解决方法:在初始化resnet时,使用resnet_arg_scope()(参数可选)

import tensorflow as tf
# resnet_v2是官方代码
import resnet_v2

slim = tf.contrib.slim

def main():
	
	# 输入图片大小为224*224
	batch_size = 1
	height, width = 224, 224
    image = tf.random_uniform((batch_size, height, width, 3))
	
	# 错误方法:
	# net, end_points = resnet_v2.resnet_v2_152(inputs, is_training=False)
    # 正确方法
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        net, end_points = resnet_v2.resnet_v2_152(inputs=image, is_training=False, num_classes=1001)
    # print(net)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(tf.global_variables())
    checkpoint_path = 'resnet_v2_152.ckpt'
    with tf.Session() as sess:
        sess.run(init)
        # 加载模型
        saver.restore(sess, checkpoint_path)
        
main()

解决问题参考:https://github.com/tensorflow/models/issues/2527
https://github.com/tensorflow/tensorflow/issues/4249

你可能感兴趣的:(tensorflow)