Tensorflow 如何存取网络模型

    当我训练完网络模型之后,会想到如何去保存训练好的weightsbias等网络参数,并在将来进行分类或者识别的任务中重新载入(restore)这个训练好的网络。那么在tensorflow中是如何实现对网络模型的保存的呢?
    在tensorflow中,变量存储在二进制文件中,主要包含从变量名到tensor值的映射关系。当创建一个Saver对象时,可以选择性地为检查点文件中的变量设置变量名。
    具体的,首先,给变量赋值,不过要在其后加上参数name=“”,注意,这里的name即要保存到网络模型的变量名称,未来在进行网络模型的载入时需要通过该变量值进行数据读取,类似字典的感觉。
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2”)
    之后,创建一个saver对象,来进行保存,同时不要忘记设定保存的路径。
saver = tf.train.Saver()
save_path = saver.save(sess, "./MNISTmodel/model.ckpt")
print ("Model saved in file: ", save_path)
    模型保存好之后,在需要再次使用这个模型时,同样需要再创建一个saver对象。不要忘记,要将模型中之前保存好的变量名称再赋给需要载入的模型,即
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name=“v2”)
不过此时不需要对这些变量进行初始化了
saver = tf.train.Saver()
......
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "./MNISTmodel/model.ckpt")
  print "Model restored."
    这样就可以直接恢复之前训练好的模型了。经过我的验证,准确度与之前训练好的时刻准确度一致。证明网络模型确实被成功恢复了。
    模型的保存不仅为了将来再次使用它进行分类等任务,也可以用来做fine-tuning。

你可能感兴趣的:(深度学习)