I have recently been working through the book 《TensorFlow实战Google深度学习框架》 by 郑泽宇. The sample code in it is concise, easy to follow, and well structured, so I am recording this example here.
The code below uses TensorFlow to build a two-layer fully connected neural network (one hidden layer) that classifies the MNIST handwritten-digit dataset. It applies two optimization techniques: a moving average model that smooths how quickly the weight parameters change, and an exponentially decaying learning rate. Early in training the parameters change by large amounts and the model moves quickly toward the optimum; as it approaches the optimum, the step size drops and the model converges on the solution in progressively smaller steps.
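Before the full program, here is a minimal pure-Python sketch of the two update rules (my own illustration, not from the book; the formulas follow the TensorFlow 1.x behavior of tf.train.exponential_decay and tf.train.ExponentialMovingAverage, and the constants mirror the ones used below):
# Sketch of the two schedules used in the program below (pure Python, no TF).
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
MOVING_AVERAGE_DECAY = 0.99

def decayed_learning_rate(global_step, decay_steps):
    # tf.train.exponential_decay with staircase=False (the default):
    # lr = base_rate * decay_rate ** (global_step / decay_steps)
    return LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps)

def ema_update(shadow, variable, num_updates):
    # tf.train.ExponentialMovingAverage with num_updates set:
    # decay = min(MOVING_AVERAGE_DECAY, (1 + num_updates) / (10 + num_updates))
    # shadow = decay * shadow + (1 - decay) * variable
    decay = min(MOVING_AVERAGE_DECAY, (1.0 + num_updates) / (10.0 + num_updates))
    return decay * shadow + (1.0 - decay) * variable

print(decayed_learning_rate(0, 550))      # 0.8   -- full base rate at step 0
print(decayed_learning_rate(30000, 550))  # ~0.46 -- decayed over ~55 epochs
print(ema_update(0.0, 1.0, 0))            # 0.9   -- early on, the shadow chases the variable quickly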
#coding:utf-8
"""
Created by cheng star at 2018/9/2 15:57
@email : xxcheng0708@163.com
"""
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
module_path = os.getcwd()
mnist_data = module_path + "/../mnist_data/"
# MNIST dataset parameters
INPUT_NODE = 784                 # number of input-layer neurons (28 * 28 pixels)
OUTPUT_NODE = 10                 # number of output-layer neurons (digit classes 0-9)
# Network hyperparameters
LAYER1_NODE = 500                # number of hidden-layer neurons
BATCH_SIZE = 100                 # number of samples per training batch
LEARNING_RATE_BASE = 0.8         # base learning rate
LEARNING_RATE_DECAY = 0.99       # learning-rate decay rate
REGULARIZATION_RATE = 0.0001     # L2 regularization coefficient
TRAINING_STEPS = 30000           # total number of training steps
MOVING_AVERAGE_DECAY = 0.99      # decay rate of the moving average model
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """
    Build the forward pass of the two-layer network.
    :param input_tensor: input feature vectors
    :param avg_class: moving average class (or None to use the raw parameters)
    :param weights1: hidden-layer weights
    :param biases1: hidden-layer biases
    :param weights2: output-layer weights
    :param biases2: output-layer biases
    :return: the output-layer logits
    """
    # Without the moving average model, use the current parameter values
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        # With the moving average model, use the shadow (averaged) values
        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
def train(mnist):
    """
    Train the network on the MNIST dataset.
    :param mnist: the MNIST dataset object
    :return:
    """
    x = tf.placeholder(shape=[None, INPUT_NODE], dtype=tf.float32, name="x-input")
    y_ = tf.placeholder(shape=[None, OUTPUT_NODE], dtype=tf.float32, name="y-input")

    # Hidden-layer parameters
    with tf.variable_scope("layer1"):
        weights1 = tf.get_variable(name="weights1",
                                   shape=[INPUT_NODE, LAYER1_NODE],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases1 = tf.get_variable(name="biases1",
                                  shape=[LAYER1_NODE],
                                  initializer=tf.constant_initializer(value=0.1))
    # Output-layer parameters
    with tf.variable_scope("output-layer"):
        weights2 = tf.get_variable(name="weights2",
                                   shape=[LAYER1_NODE, OUTPUT_NODE],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases2 = tf.get_variable(name="biases2",
                                  shape=[OUTPUT_NODE],
                                  initializer=tf.constant_initializer(value=0.1))

    # Equivalent tf.Variable form of the parameter definitions above:
    # weights1 = tf.Variable(initial_value=tf.truncated_normal(shape=[INPUT_NODE, LAYER1_NODE], stddev=0.1))
    # biases1 = tf.Variable(initial_value=tf.constant(value=0.1, shape=[LAYER1_NODE]))
    #
    # weights2 = tf.Variable(initial_value=tf.truncated_normal(shape=[LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    # biases2 = tf.Variable(initial_value=tf.constant(value=0.1, shape=[OUTPUT_NODE]))
    # Forward pass without the moving average model (used for training)
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # global_step records the number of training steps taken; it is not trainable
    global_step = tf.Variable(name="global_step", initial_value=0, trainable=False)

    # Define the moving average model and apply it to all trainable variables
    variable_average = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
    variable_average_op = variable_average.apply(tf.trainable_variables())

    # Forward pass using the moving average model (used for evaluation)
    average_y = inference(x, variable_average, weights1, biases1, weights2, biases2)
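    # Note (my annotation, not from the book): apply() creates a "shadow" copy of
    # every trainable variable and returns the op that refreshes those copies with
    # the update rule sketched at the top of this post; average(v) reads v's
    # shadow value. So average_y is the forward pass computed with the smoothed
    # parameters, while y uses the raw ones.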
    # Cross-entropy between the logits and the true labels
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
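    # Note (my annotation, not from the book): the sparse variant expects integer
    # class indices, which is why the one-hot labels y_ are converted with
    # tf.argmax. With one-hot labels one could equivalently write
    #     tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_)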
    # Average the cross-entropy over all samples in the batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularization on the weights (biases are conventionally not regularized)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)

    # Total loss = cross-entropy loss + regularization loss
    loss = cross_entropy_mean + regularization
    # Exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        learning_rate=LEARNING_RATE_BASE,                   # base learning rate
        global_step=global_step,                            # current training step
        decay_steps=mnist.train.num_examples / BATCH_SIZE,  # steps needed to see all training data once
        decay_rate=LEARNING_RATE_DECAY                      # learning-rate decay rate
    )
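    # Note (my annotation, not from the book): staircase is left at its default
    # (False), so the rate decays smoothly rather than in epoch-sized steps.
    # With 55000 training images and BATCH_SIZE = 100, decay_steps is 550, so the
    # learning rate shrinks by a factor of LEARNING_RATE_DECAY = 0.99 per epoch.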
    # One training step: gradient descent on the loss, incrementing global_step
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss=loss, global_step=global_step)
    # Bundle the gradient update and the moving average update into a single op
    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name="train")
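    # Note (my annotation, not from the book): this is the standard
    # control_dependencies + no_op idiom -- running train_op forces both
    # train_step and variable_average_op to run first. It is equivalent to
    #     train_op = tf.group(train_step, variable_average_op)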
    # Model accuracy, evaluated with the moving-average forward pass
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    # correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
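    # Note (my annotation, not from the book): training updates the raw weights
    # through y, while accuracy is measured through average_y, i.e. with the
    # shadow weights. On problems like MNIST the averaged model tends to be
    # slightly more robust than the raw one.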
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Validation dataset
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }
        # Test dataset
        test_feed = {
            x: mnist.test.images,
            y_: mnist.test.labels
        }

        for i in range(TRAINING_STEPS):
            # Report validation accuracy every 1000 steps
            if i % 1000 == 0:
                validate_acc, loss_value, global_step_value = sess.run([accuracy, loss, global_step], feed_dict=validate_feed)
                print("After {0} rounds training , the global step is {1} ,"
                      "the loss is {2} , the accuracy on validate dataset is {3}.".format(i, global_step_value, loss_value, validate_acc))
            # Train on the next mini-batch
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # Final evaluation on the test dataset
        test_acc, loss_value = sess.run([accuracy, loss], feed_dict=test_feed)
        print("After {0} rounds training , the loss is {1} , the accuracy on test dataset is {2}.".format(TRAINING_STEPS, loss_value, test_acc))
def main(argv=None):
    mnist = input_data.read_data_sets(train_dir=mnist_data, one_hot=True)
    train(mnist)

if __name__ == "__main__":
    tf.app.run()
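One practical follow-up that this listing does not show: if the trained parameters are saved with tf.train.Saver and later loaded for serving, the shadow (averaged) values can be restored directly into the variables via variables_to_restore(). A minimal sketch, assuming the same graph has been rebuilt first (the checkpoint path is hypothetical):
# Restore the moving-average (shadow) values in place of the raw weights.
variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
saver = tf.train.Saver(variable_average.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model.ckpt")  # hypothetical checkpoint path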
The model output is shown below:
C:\ProgramData\Anaconda3\python.exe E:/程序/python代码/LearningAI/learning_tensorflow/test_mnist_zzy_demo.py
WARNING:tensorflow:From E:/程序/python代码/LearningAI/learning_tensorflow/test_mnist_zzy_demo.py:150: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/train-labels-idx1-ubyte.gz
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/t10k-images-idx3-ubyte.gz
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
After 0 rounds training , the global step is 0 ,the loss is 3.164243459701538 , the accuracy on validate dataset is 0.0957999974489212.
After 1000 rounds training , the global step is 1000 ,the loss is 0.23361584544181824 , the accuracy on validate dataset is 0.9760000109672546.
After 2000 rounds training , the global step is 2000 ,the loss is 0.2088349461555481 , the accuracy on validate dataset is 0.9814000129699707.
After 3000 rounds training , the global step is 3000 ,the loss is 0.19048267602920532 , the accuracy on validate dataset is 0.9833999872207642.
After 4000 rounds training , the global step is 4000 ,the loss is 0.1719156801700592 , the accuracy on validate dataset is 0.9833999872207642.
After 5000 rounds training , the global step is 5000 ,the loss is 0.16165196895599365 , the accuracy on validate dataset is 0.9855999946594238.
After 6000 rounds training , the global step is 6000 ,the loss is 0.1522310972213745 , the accuracy on validate dataset is 0.9850000143051147.
After 7000 rounds training , the global step is 7000 ,the loss is 0.14527904987335205 , the accuracy on validate dataset is 0.9847999811172485.
After 8000 rounds training , the global step is 8000 ,the loss is 0.13154302537441254 , the accuracy on validate dataset is 0.9850000143051147.
After 9000 rounds training , the global step is 9000 ,the loss is 0.12546882033348083 , the accuracy on validate dataset is 0.9851999878883362.
After 10000 rounds training , the global step is 10000 ,the loss is 0.1196407675743103 , the accuracy on validate dataset is 0.9854000210762024.
After 11000 rounds training , the global step is 11000 ,the loss is 0.11399652063846588 , the accuracy on validate dataset is 0.9864000082015991.
After 12000 rounds training , the global step is 12000 ,the loss is 0.109804168343544 , the accuracy on validate dataset is 0.9850000143051147.
After 13000 rounds training , the global step is 13000 ,the loss is 0.11330102384090424 , the accuracy on validate dataset is 0.9854000210762024.
After 14000 rounds training , the global step is 14000 ,the loss is 0.10213477909564972 , the accuracy on validate dataset is 0.9843999743461609.
After 15000 rounds training , the global step is 15000 ,the loss is 0.09962807595729828 , the accuracy on validate dataset is 0.9854000210762024.
After 16000 rounds training , the global step is 16000 ,the loss is 0.09647612273693085 , the accuracy on validate dataset is 0.9851999878883362.
After 17000 rounds training , the global step is 17000 ,the loss is 0.0948617160320282 , the accuracy on validate dataset is 0.9854000210762024.
After 18000 rounds training , the global step is 18000 ,the loss is 0.09350297600030899 , the accuracy on validate dataset is 0.98580002784729.
After 19000 rounds training , the global step is 19000 ,the loss is 0.09059648215770721 , the accuracy on validate dataset is 0.9855999946594238.
After 20000 rounds training , the global step is 20000 ,the loss is 0.08989834785461426 , the accuracy on validate dataset is 0.9851999878883362.
After 21000 rounds training , the global step is 21000 ,the loss is 0.08760837465524673 , the accuracy on validate dataset is 0.98580002784729.
After 22000 rounds training , the global step is 22000 ,the loss is 0.08716955780982971 , the accuracy on validate dataset is 0.9854000210762024.
After 23000 rounds training , the global step is 23000 ,the loss is 0.08485446870326996 , the accuracy on validate dataset is 0.98580002784729.
After 24000 rounds training , the global step is 24000 ,the loss is 0.08533652126789093 , the accuracy on validate dataset is 0.9855999946594238.
After 25000 rounds training , the global step is 25000 ,the loss is 0.08394122123718262 , the accuracy on validate dataset is 0.9851999878883362.
After 26000 rounds training , the global step is 26000 ,the loss is 0.08382612466812134 , the accuracy on validate dataset is 0.9851999878883362.
After 27000 rounds training , the global step is 27000 ,the loss is 0.08189772069454193 , the accuracy on validate dataset is 0.9861999750137329.
After 28000 rounds training , the global step is 28000 ,the loss is 0.08306366205215454 , the accuracy on validate dataset is 0.98580002784729.
After 29000 rounds training , the global step is 29000 ,the loss is 0.08183949440717697 , the accuracy on validate dataset is 0.98580002784729.
After 30000 rounds training , the loss is 0.08004538714885712 , the accuracy on test dataset is 0.9836000204086304.
Process finished with exit code 0
Reference: 《TensorFlow实战Google深度学习框架》, 郑泽宇, 顾思宇, et al.