Everyday pitfall: TF model loading error "Key Variable_xxx not found in checkpoint"

Saving the model works fine, but loading it fails with "Key Variable_xxx not found in checkpoint". First, figure out the cause: the model.ckpt files are usually all there, so the problem shows up at load time. Start by printing out the variables stored in the ckpt file. One prerequisite: specify the name argument when defining your variables, otherwise everything prints as "Variable_xxx:0" and tells you nothing!

import os
from tensorflow.python import pywrap_tensorflow

current_path = os.getcwd()
model_dir = os.path.join(current_path, 'model')
checkpoint_path = os.path.join(model_dir, 'embedding.ckpt-0')  # the saved ckpt filename; yours may differ
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))  # prints the variable's value; not needed for finding this problem, and only clutters the output

My output:

tensor_name:  w_1_1/Adam_1
tensor_name:  w_2/Adam_1
tensor_name:  b_2
tensor_name:  w_1_1
tensor_name:  w_out/Adam_1
tensor_name:  b_1_1/Adam_1
tensor_name:  w_out
tensor_name:  w_1
tensor_name:  b_out
tensor_name:  b_2/Adam
tensor_name:  b_1
tensor_name:  b_out/Adam_1
tensor_name:  b_1_1/Adam
tensor_name:  w_1_1/Adam
tensor_name:  b_1_1
tensor_name:  w_2/Adam
tensor_name:  w_2
tensor_name:  w_out/Adam
tensor_name:  beta1_power
tensor_name:  b_out/Adam
tensor_name:  b_2/Adam_1
tensor_name:  beta2_power
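
If the mismatch is not obvious from the list alone, you can also diff the graph's variable names against the checkpoint keys. A small diagnostic sketch, reusing var_to_shape_map from the snippet above and assuming the graph has already been built and tensorflow imported as tf:

# the part of a variable's name before ':0' is what the Saver looks up in the checkpoint
ckpt_keys = set(var_to_shape_map.keys())
graph_keys = {v.name.split(':')[0] for v in tf.global_variables()}
print("in graph but missing from checkpoint:", sorted(graph_keys - ckpt_keys))
print("in checkpoint but not in graph:", sorted(ckpt_keys - graph_keys))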

From the checkpoint contents the problem is obvious: my network only contains variables like "b_1, b_2, w_1, w_2", but because tf.train.AdamOptimizer() is used to update the gradients, saving a checkpoint without specifying a var_list saves everything globally, including the optimizer's slot variables with names like "w_out/Adam". Restoring then naturally fails with variables that cannot be found. The fix: pass the variables you actually need to save as an argument when creating saver = tf.train.Saver().
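
For reference, the var_list argument of tf.train.Saver() accepts either a list of variables or a dict mapping checkpoint keys to variables. A minimal sketch, where w_1 and b_1 stand for variables defined in your own graph:

# list form: each variable is stored under its own graph name
saver = tf.train.Saver(var_list=[w_1, b_1])
# dict form: you choose the checkpoint key for each variable explicitly
saver = tf.train.Saver(var_list={'w_1': w_1, 'b_1': b_1})

In my network the full dict is built inside the network definition: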

def ann_net(w_alpha=0.01, b_alpha=0.1):
    # hidden layer 1
    w_1 = tf.Variable(w_alpha * tf.random_normal(shape=(input_size, hidden1_size)), name='w_1')
    b_1 = tf.Variable(b_alpha * tf.random_normal(shape=[hidden1_size]),name='b_1')
    hidden1_output = tf.nn.tanh(tf.add(tf.matmul(X, w_1), b_1))
    hidden1_output = tf.nn.dropout(hidden1_output, keep_prob)

    # hidden layer 2
    shp1 = hidden1_output.get_shape()
    w_2 = tf.Variable(w_alpha * tf.random_normal(shape=(shp1[1].value, hidden2_size)), name='w_2')
    b_2 = tf.Variable(b_alpha * tf.random_normal(shape=[hidden2_size]),name='b_2')
    hidden2_output = tf.nn.tanh(tf.add(tf.matmul(hidden1_output, w_2), b_2))
    hidden2_output = tf.nn.dropout(hidden2_output, keep_prob)

    # output layer
    shp2 = hidden2_output.get_shape()
    w_output = tf.Variable(w_alpha * tf.random_normal(shape=(shp2[1].value, embeding_size)), name='w_out')
    b_output = tf.Variable(b_alpha * tf.random_normal(shape=[embeding_size]),name='b_out')
    output = tf.add(tf.matmul(hidden2_output, w_output), b_output)

    variables_dict = {'b_2': b_2, 'w_out': w_output, 'w_1': w_1, 'b_out': b_output, 'b_1': b_1, 'w_2': w_2}
    return output,variables_dict

In the train() function, initialize the saver with the variables_dict returned by ann_net():

with tf.device('/cpu:0'):
    # variables_dict is the dict returned by ann_net()
    saver = tf.train.Saver(variables_dict)
    with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess:  # run on CPU only
        sess.run(tf.global_variables_initializer())
        step = 0

        ckpt = tf.train.get_checkpoint_state('model/')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
            print("Model restored.")

        # training code
        # ... ...
        saver.save(sess, 'model/embedding.model', global_step=step)
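
If you'd rather not maintain the dict by hand, a lighter-weight alternative is to save only the trainable variables, which automatically leaves out the Adam slot variables and the beta1_power/beta2_power counters. A minimal sketch, assuming everything you want to persist is trainable; note that with a plain list each variable is stored under its graph name:

saver = tf.train.Saver(var_list=tf.trainable_variables())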

If the model was downloaded from the web, e.g. vgg-16, and you only want to load the first few layers into variables you defined yourself, the approach is the same: build a list or dict of the variables to restore and pass it to tf.train.Saver().
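
A hedged sketch of what that looks like: the checkpoint keys below are only assumptions about how a vgg-16 checkpoint might be laid out (print the real ones with the reader snippet above), and my_conv1_w / my_conv1_b stand for variables defined in your own graph:

# map checkpoint keys (left) to your own variables (right);
# anything not listed here is simply not restored
restore_map = {
    'vgg_16/conv1/conv1_1/weights': my_conv1_w,
    'vgg_16/conv1/conv1_1/biases': my_conv1_b,
}
restore_saver = tf.train.Saver(restore_map)
restore_saver.restore(sess, 'vgg_16.ckpt')  # path to the downloaded checkpoint
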
The same idea applies to LSTMs, except TF follows its own naming rules when saving them. The bidirectional RNN wrapper's default variable_scope is "bidirectional_rnn", and unless you override it, that name is automatically prefixed onto the variables, so the names stored in the model look something like this:

tensor_name:  train/train_1/fc_b/Adam
tensor_name:  train_1/fc_b
tensor_name:  train/fc_b
tensor_name:  train/train/bidirectional_rnn/fw/basic_lstm_cell/kernel/Adam
tensor_name:  train/bidirectional_rnn/fw/basic_lstm_cell/kernel
tensor_name:  train/train/bidirectional_rnn/bw/basic_lstm_cell/bias/Adam
tensor_name:  train/beta2_power
tensor_name:  train/train/fc_w/Adam
tensor_name:  train_1/beta1_power
tensor_name:  train/train/bidirectional_rnn/bw/basic_lstm_cell/bias/Adam_1
tensor_name:  train/train/bidirectional_rnn/fw/basic_lstm_cell/bias/Adam_1
tensor_name:  train/beta1_power
tensor_name:  train/train_1/fc_w/Adam_1
tensor_name:  train/train/bidirectional_rnn/bw/basic_lstm_cell/kernel/Adam_1
tensor_name:  train_1/beta2_power
tensor_name:  train/train/fc_w/Adam_1
tensor_name:  train/bidirectional_rnn/bw/basic_lstm_cell/kernel
tensor_name:  train/train/fc_b/Adam
tensor_name:  train/bidirectional_rnn/bw/basic_lstm_cell/bias
tensor_name:  train/fc_w
tensor_name:  train_1/fc_w
tensor_name:  train/bidirectional_rnn/fw/basic_lstm_cell/bias
tensor_name:  train/train/fc_b/Adam_1
tensor_name:  train/train/bidirectional_rnn/fw/basic_lstm_cell/kernel/Adam_1
tensor_name:  train/train/bidirectional_rnn/bw/basic_lstm_cell/kernel/Adam
tensor_name:  train/train/bidirectional_rnn/fw/basic_lstm_cell/bias/Adam
tensor_name:  train/train_1/fc_b/Adam_1
tensor_name:  train/train_1/fc_w/Adam

The leading "train" is a variable_scope I added myself, so restoring can be done like this:

include = ['train/fc_b', 'train/fc_w',
           'train/bidirectional_rnn/bw/basic_lstm_cell/bias',
           'train/bidirectional_rnn/bw/basic_lstm_cell/kernel',
           'train/bidirectional_rnn/fw/basic_lstm_cell/bias',
           'train/bidirectional_rnn/fw/basic_lstm_cell/kernel']
variables_to_restore = tf.contrib.slim.get_variables_to_restore(include=include)
saver = tf.train.Saver(variables_to_restore)
with tf.Session(config=tf.ConfigProto(device_count={'GPU': 0})) as sess:  # run on CPU only
    sess.run(tf.global_variables_initializer())
    # ... ...
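
The restore call itself goes where the "# ... ..." placeholder is and looks the same as before. A minimal sketch, assuming the LSTM checkpoint was also saved under 'model/':

ckpt = tf.train.get_checkpoint_state('model/')
if ckpt and ckpt.model_checkpoint_path:
    # only the variables in variables_to_restore are read from the checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    print("Model restored.")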
