保存和提取使用tensorflow训练的模型,以mnist数据集和lenet5为例

  • 如果你想保存在tensorflow上辛苦训练了很久的模型,随时地去使用它
  • tensorflow将上述过程分成了两个部分:
    训练和保存
    提取和使用
  • 训练部分为:加载训练数据,前向传播计算,代价函数评估,反向传播更新,保存
  • 提取部分为:
    加载测试数据(格式与训练数据保持一致)
    前向传播计算(框架与训练部分一致)
  • 我将以经典的手写数字识别数据集和lenet5框架举例

代码:训练与保存篇

# 导入tensorflow框架
import tensorflow as tf
import numpy as np
# 导入mnist数据集输入方法
from tensorflow.examples.tutorials.mnist import input_data

# 随机种子
tf.set_random_seed(1)

# 读取数据 如果数据存在则直接读取,如果不存在,则联网下载
mnist = input_data.read_data_sets(r'MNIST_data')

# 构建模型
# 数据样本本身是一行784个特征的像素矩阵
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.int64,[None])
# 将像素矩阵转为卷积网络识别的图片格式
x_img = tf.reshape(x,[-1,1,28,28])
x_img = tf.transpose(x_img,perm=[0,2,3,1])

# 前向传播计算
conv1_1 = tf.layers.conv2d(x_img,8,(3,3),padding='same',activation=tf.nn.relu,name='conv1_1')
pool1 = tf.layers.max_pooling2d(conv1_1,(2,2),(2,2),name='pool1')

conv2_1 = tf.layers.conv2d(pool1,16,(3,3),padding='same',activation=tf.nn.relu,name='conv2_1')
pool2 = tf.layers.max_pooling2d(conv2_1,(2,2),(2,2),name='pool2')

flatten = tf.layers.flatten(pool2,name='flatten')

fc1 = tf.layers.dense(flatten,120,activation=tf.nn.relu,name='fc1')
fc2 = tf.layers.dense(fc1,84,activation=tf.nn.relu,name='fc2')

y_ = tf.layers.dense(fc2,10)

# 计算代价 使用这个代价函数不需要独热编码,函数内部自动处理
loss = tf.losses.sparse_softmax_cross_entropy(logits=y_,labels=y)

# 训练
train_op = tf.train.AdamOptimizer(0.003).minimize(loss)

# 预测和准确率
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_,1),y),dtype=tf.float32))

# 创建模型的保存和提取对象
saver = tf.train.Saver()
# 开启会话
config = tf.ConfigProto()
config.gpu_options.allow_growth = True # 这两行是tensorflow-gpu版的运行指令,如果是cpu的话,要注释掉
with tf.Session(config=config) as sess:
	# 激活所有变量
    sess.run(tf.global_variables_initializer())
    # 执行训练
    for i in range(10000):
    	# 批次训练,每次100张
        x_train,y_train = mnist.train.next_batch(100)
        los_val,acc_val,_ = sess.run([loss,accuracy,train_op],feed_dict={
            x:x_train,y:y_train
        })
        # 每100次打印一次准确率等参数
        if (i+1) % 100 == 0:
            print('批次',i+1)
            print('代价',los_val)
            print('测试集准确率',acc_val)
		# 每训练1000次保存一个模型,模型以ckpt结尾
        if (i+1) % 1000 == 0:
            saver.save(sess,r'model/mnist%d'%(i+1)+'.ckpt')

保存和提取使用tensorflow训练的模型,以mnist数据集和lenet5为例_第1张图片
如上图所示
保存的模型3个文件为一组,分别以data,index,meta:

  • meta file保存了graph结构,包括 GraphDef, SaverDef等,当存在meta
    file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值

  • index file为一个 string-string table,table的key值为tensor名,value为BundleEntryProto, BundleEntryProto

  • data file保存了模型的所有变量的值

  • 默认是最多存储5个模型,当第六个模型生成时,会覆盖掉第一个模型,以此类推,保存最后5个模型

代码:模型提取使用

‘’‘

	模型提取篇,要拥有与训练模型时相同的前向计算结构,测试集的数据也要与训练时相同
	本篇是训练,保存和提取
	如果我们在日常工作中,要复用已经训练好的大型的网络结构,那么必要的步骤就是先去了解这个模型
	的前向计算结构,以及适应的数据格式,然后处理自己的数据,保证一致
	
’‘’
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# 随机种子
tf.set_random_seed(1)

# 读取数据
mnist = input_data.read_data_sets(r'MNIST_data')

# 构建模型
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.int64,[None])
x_img = tf.reshape(x,[-1,1,28,28])
x_img = tf.transpose(x_img,perm=[0,2,3,1])

# 前向传播
conv1_1 = tf.layers.conv2d(x_img,8,(3,3),padding='same',activation=tf.nn.relu,name='conv1_1')
pool1 = tf.layers.max_pooling2d(conv1_1,(2,2),(2,2),name='pool1')

conv2_1 = tf.layers.conv2d(pool1,16,(3,3),padding='same',activation=tf.nn.relu,name='conv2_1')
pool2 = tf.layers.max_pooling2d(conv2_1,(2,2),(2,2),name='pool2')

flatten = tf.layers.flatten(pool2,name='flatten')

fc1 = tf.layers.dense(flatten,120,activation=tf.nn.relu,name='fc1')
fc2 = tf.layers.dense(fc1,84,activation=tf.nn.relu,name='fc2')

y_ = tf.layers.dense(fc2,10)

# 预测和准确率
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_,1),y),dtype=tf.float32))

# 创建保存模型器
saver = tf.train.Saver()
# 开启会话
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    # 要提取的模型的路径
    saver.restore(sess,r'model/mnist10000.ckpt')
    acc_val = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
    print('测试集预测准确率为:',acc_val)
‘’‘
	测试集预测准确率为: 0.9886
’‘’

你可能感兴趣的:(保存和提取使用tensorflow训练的模型,以mnist数据集和lenet5为例)