2019独角兽企业重金招聘Python工程师标准>>>
训练完成以后我们就可以直接使用训练好的模板进行预测了
但是每次在预测之前都要进行训练,不是一个常规操作,毕竟有些复杂的模型需要训练好几天甚至更久
所以将训练好的模型进行保存,当有需要的时候重新加载这个模型进行预测或者继续训练,这才是一个常规操作
我们依然使用最简单的例子进行说明,这里沿用Tensorflow入门——实现最简单的线性回归模型的预测 这个例子进行
====================================================
模型的保存
在tensorflow中保存模型使用的是tf.train.Saver对象,我们需要在保存之前先实例化这个对象
saver = tf.train.Saver()
对于模型的保存,其实就是保存整个session对象,再给定一个path就实现了模型的保存(对应的path需要存在,如果不存在会报错)
saver.save(sess, SAVE_PATH + 'model')
保存完成以后,可以看到对应的目录下面生成了4个文件
model.meta中保存的是模型,而这个模型仅仅是计算流和参数的定义,可以认为是一个未经训练的模型
model.index和model.data-00000-of-00001中保存的是参数值,也就是真正训练的结果
checkpoint中保存的是最后几次保存的信息,从文件名就可以看出它是一个检查点,记录了其他几个文件之间的关系,这是一个txt文件,我们可以打开看一下(在这个例子中我们只保存了一次,如果保存多次的话这个文件中会记录多次保存结果的信息)
下面是运行的log
epoch= 0 _loss= 6029.333 _w= [0.005] _n= [0.005] epoch= 5000 _loss= 10.897877 _w= [4.2031364] _n= [-1.905781] epoch= 10000 _loss= 112.455055 _w= [4.7837024] _n= [-11.81817] epoch= 15000 _loss= 6.2376847 _w= [5.1548934] _n= [-19.740992] epoch= 20000 _loss= 2.9357195 _w= [5.2787647] _n= [-22.662355] epoch= 25000 _loss= 0.022824269 _w= [5.3112087] _n= [-23.141117] epoch= 30000 _loss= 1.3711997 _w= [5.326612] _n= [-23.255548] epoch= 35000 _loss= 0.005477888 _w= [5.3088646] _n= [-23.289743] epoch= 40000 _loss= 2.8727396 _w= [5.315157] _n= [-23.191956] epoch= 45000 _loss= 0.009563584 _w= [5.300157] _n= [-23.18857] 训练完成,开始预测。。。 x= 0.1610020536371326 y预测= [-22.44688] y实际= -22.401859054114084 x= 7.379937860774309 y预测= [16.030691] y实际= 16.075068797927063 x= 5.1744928042152685 y预测= [4.2754745] y实际= 4.320046646467379 x= 10.26990231423617 y预测= [31.434462] y实际= 31.478579334878784 x= 23.219346463697207 y预测= [100.45616] y实际= 100.49911665150611 x= 7.101197776563807 y预测= [14.544985] y实际= 14.589384149085088 x= 3.097841295090581 y预测= [-6.7932644] y实际= -6.7485058971672025 x= 6.474682013005717 y预测= [11.205599] y实际= 11.250055129320469 x= 13.811264369891983 y预测= [50.310234] y实际= 50.35403909152427 x= 29.260954830177415 y预测= [132.65846] y实际= 132.70088924484563
====================================================
模型的加载
因为保存时分成了模型和参数值两部分进行保存,所以在加载模型的时候也需要将模型和参数值(训练结果)两步分开进行加载
上面讲到了meta文件是模型,checkpoint是参数值,这里分别使用tf.train下的import_meta_graph和latest_checkpoint方法来加载
saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta') saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))
这样,之前保存起来的模型就被我们重新加载成功了,但是在预测或者继续训练之前,我们需要重新定义相关的变量
但是也不是凭空的重新定义,因为这些参数已经在之前保存的模型中定义过了,我们只需要从已经加载的模型中将相关参数的定义给找出来就可以了
为了找回参数的定义,我们需要稍微修改一下模型,将这些需要在重新加载阶段找回的参数定义给上命名(如果是用来预测,我们需要找回X和OUT,如果是用来继续训练,我们需要找回X、OUT、loss),所以这里我们将模型中相关的参数都给上命名
X = tf.placeholder(tf.float32, name='X') Y = tf.placeholder(tf.float32, name='Y') W = tf.Variable(tf.zeros([1]),