Tensorflow学习笔记:CNN篇(7)——Finetuning,模型的保存与恢复

Tensorflow学习笔记:CNN篇(7)——Finetuning,模型的保存与恢复


前序

— 前文,我们介绍了VGGNet的组成结构,并在CIFAR-10数据集上进行了实现,本文着重介绍Finetuning的能力。Finetuning的意思是在已有模型之后进行参数和训练模型复用的缩写,也是真实工程应用中最常用的的是由既有模型的手段。


实战Tensorflow模型的存储与恢复

— 随着模型形式的越来越复杂,对模型存储的要求和格式也越来越重要。借鉴敏捷开发的模型,首先对于常用变量的定义,笔者建议使用全部变量进行存储;而对于模型专用的类,也建议创建专门的模型控制。工程文件的分类如图所示:
Tensorflow学习笔记:CNN篇(7)——Finetuning,模型的保存与恢复_第1张图片


代码示例

Step 1: 全局数据与模型类的定义

首先是对全局数据的定义,对模型的保存和读取来说,存储文件夹是个通用的变量,故在工程目录下新建一个名为global_variable .py文件,定义了文件的存储位置,其内容如下:

save_path = '.\\model\\'

Step 2: 模型的定义

这里定义了一个线性回归模型,在工程目录下新建一个名为lineRegulation_model .py文件,将其定义为类使用,这样做的好处是使用相同的创建方法将类的定义放在不同文件中,也就是在训练模型和恢复模型中保存和重新加载,而不会因为定义或输入错误而产生不好的结果。

import tensorflow as tf


class LineRegModel:
    def __init__(self):
        self.a_val = tf.Variable(tf.random_normal([1]))
        self.b_val = tf.Variable(tf.random_normal([1]))
        self.x_input = tf.placeholder(tf.float32)
        self.y_label = tf.placeholder(tf.float32)
        self.y_output = tf.add(tf.multiply(self.x_input, self.a_val), self.b_val)
        self.loss = tf.reduce_mean(tf.pow(self.y_output - self.y_label, 2))

    def get_op(self):
        return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

Step 3: 模型的训练

在工程目录下新建一个名为model_train .py文件,代码如下:

import tensorflow as tf
import numpy as np
import global_variable
import lineRegulation_model as model

train_x = np.random.rand(5)
train_y = 5 * train_x + 3.2   # y = 5 * x + 3
model = model.LineRegModel()

a_val = model.a_val
b_val = model.b_val

x_input = model.x_input
y_label = model.y_label

y_output = model.y_output

loss = model.loss
optimize = model.get_op()
saver = tf.train.Saver()
if __name__ == "__main__":
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    flag = True
    epoch = 0
    while flag:
        epoch += 1
        _ , loss_val = sess.run([optimize,loss],feed_dict={x_input:train_x,y_label:train_y})
        if loss_val < 1e-6:
            flag = False
    print(a_val.eval(sess) , "   ", b_val.eval(sess))
    print("-----------%d-----------"%epoch)

    saver.save(sess,global_variable.save_path)
    print("model save finished")
    sess.close()

将训练好的模型进行保存,使用Saver函数的save方法将对话保存到相应的路径中,保存结果如图所示:
Tensorflow学习笔记:CNN篇(7)——Finetuning,模型的保存与恢复_第2张图片

Step 4: 模型的恢复

对于模型的恢复所使用的方法为Saver类中的restore函数,恢复指定保存文件夹中的训练模型,在工程目录下新建一个名为model_restore.py文件,代码如下:

import tensorflow as tf
import global_variable
import lineRegulation_model as model

model = model.LineRegModel()

x_input = model.x_input
y_output = model.y_output

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess,global_variable.save_path)

result = sess.run(y_output,feed_dict={x_input:[1]})
print(result)

首先是对模型的使用,这里采用的是类的实现,可以在最大程度上复用已有参数的定义,在对话内部Saver函数中对模型进行了恢复,之后通过对话输入待计算的数值后打印结果。
这里写图片描述

你可能感兴趣的:(Tensorflow学习笔记)