Tensorflow 中模型持久化方法

tensorflow 实现模型持久化

本文主要介绍如何用tensorflow来实现训练好的模型的持久化以及模型的引用。

import tensorflow as tf
saver=tf.train.Saver()   #用来创建一个持久化类

在训练的时候,可以设置迭代固定的次数然后保存模型

save.save(sess,'mnist_fenlei_model/',global_step=global_step)

‘mnist_fenlei_model/’,是你保存模型的位置,这里 global_step 一定要写,不然后面引用模型的时候会有错误。
在测试程序中,用以下代码实现模型引用:

ckpt=tf.train.get_checkpoint_state('mnist_fenlei_model')     #自动找到最新的模型。
saver=tf.train.Saver()
if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
else:
        print('no ckpt ')

这样就可以用训练好的模型实现测试或者其他应用。

你可能感兴趣的:(Tensorflow)