tensorflow报NotFoundError (see above for traceback): Key G_b0 not found in checkpoint

错误提示

在使用TensorFlow加载ckpt文件的时候报NotFoundError (see above for traceback): Key G_b0 not found in checkpoint错误,详细错误如下:

tensorflow报NotFoundError (see above for traceback): Key G_b0 not found in checkpoint_第1张图片

加载ckpt文件的代码如下:

saver = tf.train.Saver()
saver.restore(sess,ckpt_path)

分析错误

从错误的提示来看就是指模型文件中没有找到G_b0这个变量,因为在网络中需要用到这个变量,所以导致无法给这个变量赋值,导致加载模型的时候报错。首先,我们应该确认我们的ckpt文件中是否包含G_b0这个变量的信息,通过下面的代码可以查看模型文件中的变量信息

from tensorflow.contrib.framework.python.framework import checkpoint_utils
var_list = checkpoint_utils.list_variables("log/GAN.ckpt")
for v in var_list:
    print(v)

tensorflow报NotFoundError (see above for traceback): Key G_b0 not found in checkpoint_第2张图片

通过输出模型文件中的信息可以发现,的确是存在G_b0变量的信息,所以是不是因为名字不同导致它无法找到的呢所以,这时候应该确认训练时候变量定义的代码,细心的同学可能已经发现了这个的G_b0前面还多了一个G的scope,的确问题就在这

修正错误

通过分析错误我们已经发现了是由于scope导致变量找不到的,所以在加载模型的时候,给定义变量加一个scope就好了,代码如下:

改正前:

z = tf.placeholder(dtype=tf.float32,shape=(None,100))
G_z = GAN.Generator(z)

改正后:

with tf.variable_scope('G'):
      z = tf.placeholder(dtype=tf.float32,shape=(None,100))
      G_z = GAN.Generator(z)

 

你可能感兴趣的:(tensorflow修炼之路)