Tensorflow:保存,加载模型

有时常常需要将训练的模型保存下来,用以预测或之后接着训练。

1. 保存模型

Tensorflow:保存,加载模型_第1张图片

主要是通过红色框圈出的两句话来保存模型的,下面是对这两句话的解释:

  • saver = tf.train.Saver()

在定义saver类时,可以指定需要保存的变量,若不指定,则是保存计算图中所有的变量,即像我这样处理。max_to_keeo参数是指最多保存模型的个数,此处设置为5,即最多保存5个模型。

  • saver.save()

图中三个参数一次是:指定需要保存的计算图;模型保存路径及模型前缀名;global_step = epoch_i 以epoch_i为参考量,此处在epoch_i %100 == 0 时保存模型。模型保存后生成以下几个文件:

  (每个文件的具体含义此处不再做具体解释了)

关于 saver = tf.train.Saver()和saver.save()的具体参数含义,可以查看tensorflow 官方网站解析.

2. 加载模型

Tensorflow:保存,加载模型_第2张图片

      前两句话指定加载的模型数据,后几句话表示从加载的模型中获取哪些变量。注意是通过变量名获取的,所以在训练模型时,最好手动的为变量赋名字,若在定义变量时未指定名字,tensorflow也会自动的为变量赋名字,但是变量名规律有点怪,所以最好自己手动赋名字。


以下是我在做保存模型和加载模型的完整代码,重要部分用红色的框圈出来了。

  • 保存模型

Tensorflow:保存,加载模型_第3张图片

  • 加载模型

Tensorflow:保存,加载模型_第4张图片

你可能感兴趣的:(Tensorflow:保存,加载模型)