tensorflow保存和重载模型

1.保存一个模型

注意要给变量命名,要把op用tf.add_to_collection()添加
saver = tf.train.Saver(max_to_keep=20)  #注意要在初始化变量以后,迭代以前,参数为最大保存数量
saver.save(sess, 'my_test_model',global_step)#在迭代中保存模型
第一个参数是保存模型,第二个参数是保存路径,第三个参数是保存多次模型时用来命名
    注意:要把导入模型时需要用到的(例如需要feed_dict的变量,以及op命名(name=''))
2.导入一个训练好的模型

new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')#在创建会话以后
new_saver.restore(sess,'../model/model_LR_test')#加载模型中各种变量的值,注意这里不用文件的后缀
op=tf.get_collection('op_name')[0]
graph = tf.get_default_graph() 
X = graph.get_operation_by_name('X').outputs[0]#为了将placeholder加载出来

你可能感兴趣的:(tensorflow)