A TensorFlow model restore problem

When restoring a trained model, I ran into the following error:

Attempting to use uninitialized value Variable
Caused by op u'Variable/read'
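
To see which variables an error like this refers to, one option is to ask the session for the names that are still uninitialized. The sketch below is a minimal illustration, not the original model: the variable names `initialized` and `forgotten` are hypothetical, and it assumes the TF1-style API is reachable through `tensorflow.compat.v1`.

```python
import tensorflow.compat.v1 as tf  # assumption: TF1-style API via compat.v1

tf.disable_v2_behavior()

with tf.Graph().as_default(), tf.Session() as sess:
    # Two hypothetical variables; only the first one gets initialized.
    initialized = tf.get_variable(
        "initialized", shape=[], initializer=tf.zeros_initializer())
    forgotten = tf.get_variable(
        "forgotten", shape=[], initializer=tf.zeros_initializer())
    sess.run(initialized.initializer)

    # report_uninitialized_variables() yields the names of variables
    # that do not hold a value yet in this session.
    raw_names = sess.run(tf.report_uninitialized_variables())
    names = [n.decode("utf-8") for n in raw_names]
    print(names)
```

Running the same check right after a restore should show whether the checkpoint covered every variable in the graph.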

The restore code, however, looks like this:

import time

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)


# Load model parameters from a checkpoint.
def load_model(model, ckpt_path, session):
    start_time = time.time()
    try:
        model.saver.restore(session, ckpt_path)
    except tf.errors.NotFoundError as e:
        # The failure is only printed; execution continues afterwards.
        print("Can't load checkpoint")
        print("%s" % str(e))

    session.run(tf.tables_initializer())
    print("loaded model parameters from %s, time %.2fs" % (ckpt_path, time.time() - start_time))
    return model


def create_or_load_model(model, model_dir, session):
    # Restore from the latest checkpoint if one exists; otherwise start fresh.
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        print("created model with fresh parameters, time %.2fs" % (time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step

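This code restores the trained model if a checkpoint exists; otherwise it initializes the global variables and trains from scratch.
The problem lies in where the global variables get initialized. Now look at the calling code: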

with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as sess:
        model = Model('train')
        model, current_global_step = _mh.create_or_load_model(
            model, _hp.model_path, sess)

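Global initialization happens only inside create_or_load_model, after the model has been created, and only on the fresh-start path. On the restore path the model's parameters are created, but any parameter the checkpoint does not restore is never initialized, which leads to the error above.
So the calling code was changed as follows: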

with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        model = Model('train')
        model, current_global_step = _mh.create_or_load_model(
            model, _hp.model_path, sess)
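
One caveat, stated as my understanding of TensorFlow 1.x graph mode rather than something the original code demonstrates: `tf.global_variables_initializer()` builds an op over the variables that already exist in the graph when it is called, so a call placed before the model is constructed may not cover the model's variables. A minimal sketch of an alternative order (build the model, initialize everything, then restore so that checkpoint values overwrite the initial ones) could look like this. `build_toy_model`, the variable names, and the temporary checkpoint path are hypothetical stand-ins for `Model('train')` and `_hp.model_path`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF1-style API via compat.v1

tf.disable_v2_behavior()


def build_toy_model():
    """Hypothetical stand-in for Model('train'): two variables, saver for one."""
    w = tf.get_variable("w", shape=[2], initializer=tf.ones_initializer())
    step = tf.get_variable("global_step", shape=[], dtype=tf.int32,
                           initializer=tf.zeros_initializer(), trainable=False)
    saver = tf.train.Saver(var_list=[w])  # the checkpoint will only contain w
    return w, step, saver


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "toy.ckpt")

# First session: write a checkpoint in which w == [5, 5].
with tf.Graph().as_default(), tf.Session() as sess:
    w, step, saver = build_toy_model()
    sess.run(tf.global_variables_initializer())
    sess.run(w.assign([5.0, 5.0]))
    saver.save(sess, ckpt_prefix)

# Second session: build, initialize everything, then restore.
with tf.Graph().as_default(), tf.Session() as sess:
    w, step, saver = build_toy_model()
    # Created after the variables exist, so it covers all of them,
    # including global_step, which the checkpoint does not contain.
    sess.run(tf.global_variables_initializer())
    # Restore overwrites the freshly initialized value of w.
    saver.restore(sess, ckpt_prefix)
    restored_w, restored_step = sess.run([w, step])
```

With this order, variables stored in the checkpoint take their saved values and the rest keep their initial values, so nothing is read while uninitialized.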

The model-restore code (load_model and create_or_load_model) is unchanged from the version shown above.

After the changes above, the trained model restores normally.
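
To double-check what a checkpoint actually covers, the variable names stored in it can be compared with the variables in the current graph. The sketch below is only an illustration under assumptions: `variables_missing_from_checkpoint` is a hypothetical helper, the variables `a` and `b` are toy examples, and the TF1-style API is assumed to be available through `tensorflow.compat.v1`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF1-style API via compat.v1

tf.disable_v2_behavior()


def variables_missing_from_checkpoint(ckpt_path):
    """Hypothetical helper: global variables in the default graph
    whose names do not appear in the checkpoint."""
    saved = {name for name, _shape in tf.train.list_variables(ckpt_path)}
    return sorted(v.op.name for v in tf.global_variables()
                  if v.op.name not in saved)


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "toy.ckpt")

# Save a checkpoint that contains only the variable "a".
with tf.Graph().as_default(), tf.Session() as sess:
    tf.get_variable("a", shape=[], initializer=tf.zeros_initializer())
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_prefix)

# A later graph adds a variable "b" that the checkpoint knows nothing about.
with tf.Graph().as_default():
    tf.get_variable("a", shape=[], initializer=tf.zeros_initializer())
    tf.get_variable("b", shape=[], initializer=tf.zeros_initializer())
    missing = variables_missing_from_checkpoint(ckpt_prefix)
    print(missing)
```

Any name reported here would need to be initialized separately, because restoring the checkpoint cannot give it a value.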
