TensorFlow学习——第三章(一)

MNIST数字识别问题

全连接层实现手写数字识别

采用了L2正则化、滑动平均模型和指数衰减学习率
训练结果为:训练集93%,验证集95.36%,测试集95.01%
第一部分:前向传播和网络参数(保存为 mnist_inference.py,后两部分通过 import mnist_inference 引用它)

# 定义前向传播和神经网络中的参数
import tensorflow as tf

# Network configuration (layer sizes).
INPUT_NODE = 28 * 28  # input layer: one node per pixel of a 28x28 image
OUTPUT_NODE = 10      # output layer: one logit per digit class
LAYER1_NODE = 500     # hidden layer width

def get_weight_variable(shape, regularizer):
    """Create (or, under a reusing variable scope, fetch) a 'weights' variable.

    The variable is initialised from a truncated normal distribution with
    mean 0 and stddev 0.1. When ``regularizer`` is given, its penalty for
    these weights is appended to the 'losses' collection, which the training
    script sums into the total loss.

    Note: the penalty is appended on every call, including calls that reuse
    an already-existing variable. Callers that build the network a second
    time should pass ``regularizer=None`` to avoid duplicate penalties.

    Args:
        shape: shape of the weight matrix, e.g. ``[in_nodes, out_nodes]``.
        regularizer: callable mapping a tensor to a scalar penalty
            (e.g. ``tf.contrib.layers.l2_regularizer(rate)``), or None to
            skip regularization.

    Returns:
        The weights variable.
    """
    weights = tf.get_variable(
        'weights', shape,
        initializer=tf.truncated_normal_initializer(mean=0, stddev=0.1))
    # Identity test against None; `!=` would go through the object's __ne__.
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights

# 前向传播
def _fully_connected(input_tensor, scope_name, in_nodes, out_nodes,
                     regularizer, avg_class, reuse, activation=None):
    """Build one dense layer ``activation(input @ W + b)`` inside ``scope_name``."""
    with tf.variable_scope(scope_name, reuse=reuse):
        weights = get_weight_variable([in_nodes, out_nodes], regularizer)
        biases = tf.get_variable(
            'biases', [out_nodes], initializer=tf.constant_initializer(0.0))
        if avg_class is not None:
            # Compute with the moving-average shadow values, not the raw ones.
            weights = avg_class.average(weights)
            biases = avg_class.average(biases)
        output = tf.matmul(input_tensor, weights) + biases
        if activation is not None:
            output = activation(output)
    return output


def inference(input_tensor, regularizer, avg_class, reuse):
    """Forward pass of the two-layer fully connected network.

    The network is input -> ReLU hidden layer (LAYER1_NODE units) -> linear
    output layer (OUTPUT_NODE units). Variables live in scopes 'layer1' and
    'layer2' and are named 'weights' / 'biases'. Checkpoints depend on
    these names.

    Args:
        input_tensor: float tensor of shape [batch, INPUT_NODE].
        regularizer: penalty callable forwarded to get_weight_variable, or
            None for no regularization.
        avg_class: object exposing ``average(var)`` (e.g. a
            tf.train.ExponentialMovingAverage), or None. When given, the
            layers use the averaged values of the weights and biases.
        reuse: passed to tf.variable_scope. False creates the variables;
            True / tf.AUTO_REUSE fetches existing ones.

    Returns:
        Unnormalized logits of shape [batch, OUTPUT_NODE] (no softmax).
    """
    layer1 = _fully_connected(input_tensor, 'layer1', INPUT_NODE, LAYER1_NODE,
                              regularizer, avg_class, reuse,
                              activation=tf.nn.relu)
    layer2 = _fully_connected(layer1, 'layer2', LAYER1_NODE, OUTPUT_NODE,
                              regularizer, avg_class, reuse)
    return layer2

第二部分:训练,并记录训练集和验证集上的损失与准确率(保存为 mnist_train.py,第三部分会 import mnist_train)

# 神经网络训练程序
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference

# Optimisation hyper-parameters.
BATCH_SIZE = 100             # examples per training step
LEARNING_RATE_BASE = 0.8     # initial learning rate
LEARNING_RATE_DECAY = 0.99   # decay factor applied per epoch's worth of steps
REGULARAZTION_RATE = 0.0001  # L2 penalty weight (misspelled name kept for callers)
TRAINING_STEP = 200          # number of training iterations
MOVING_AVERAGE_DECAY = 0.99  # decay of the variable moving averages

# Checkpoint location.
MODEL_SAVE_PATH = './'
MODEL_NAME = 'model.ckpt'

# Per-step metrics that train() appends to, used for plotting.
epochs = []
train_loss = []
valid_loss = []
train_acc = []
valid_acc = []

def _plot_curves():
    """Plot the recorded loss (left) and accuracy (right) curves and show them."""
    plt.figure(1)
    panels = (
        (1, 'Loss', train_loss, valid_loss),
        (2, 'Acc', train_acc, valid_acc),
    )
    for position, title, train_series, valid_series in panels:
        plt.subplot(1, 2, position)
        # plt.grid acts on the current axes, so enable it per subplot.
        plt.grid(True)
        plt.plot(epochs, train_series, color='red', label='train')
        plt.plot(epochs, valid_series, color='blue', label='valid')
        plt.legend()
        plt.xlabel('Epochs', fontsize=15)
        plt.ylabel('Y', fontsize=15)
        plt.title(title, fontsize=15)
    plt.show()


def train(mnist):
    """Train the MNIST network, save a checkpoint and plot learning curves.

    Each step does the following:
    - Runs one mini-batch update: gradient descent, then the moving-average update.
    - Appends the global step to the module-level list ``epochs``.
    - Appends the training-batch loss and accuracy to ``train_loss`` and ``train_acc``.
    - Appends the validation loss and accuracy to ``valid_loss`` and ``valid_acc``.

    Accuracies use the moving-average weights; losses use the raw weights.
    The module-level lists are never cleared, so repeated calls keep
    appending to them.

    Args:
        mnist: dataset returned by ``input_data.read_data_sets(..., one_hot=True)``.
    """
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = mnist_inference.inference(x, regularizer=regularizer, avg_class=None, reuse=False)
    global_step = tf.Variable(0, trainable=False)

    # Shadow moving averages of every trainable variable.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    # Labels are one-hot; the sparse loss expects class indices.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once per epoch's worth of steps.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # One op that runs both the gradient step and the moving-average update.
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    # Forward pass with the moving averages. The weights already exist, so
    # pass no regularizer: otherwise duplicate penalties go into 'losses'.
    average_y = mnist_inference.inference(
        x, regularizer=None, avg_class=variable_averages, reuse=tf.AUTO_REUSE)

    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favour of this.
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEP):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            train_feed = {x: xs, y_: ys}
            _, tra_loss, step = sess.run([train_op, loss, global_step], feed_dict=train_feed)
            # Fetch plain scalars (not a one-element list) and evaluate the
            # validation loss and accuracy in a single run.
            val_loss, val_acc = sess.run([loss, accuracy], feed_dict=validate_feed)

            epochs.append(step)
            train_acc.append(sess.run(accuracy, feed_dict=train_feed))
            train_loss.append(tra_loss)
            valid_acc.append(val_acc)
            valid_loss.append(val_loss)

            # Report every 10 steps.
            if (i + 1) % 10 == 0:
                print('<==%d==>,loss on training batch is %g.' % (i + 1, tra_loss))

        if train_acc:
            print(train_acc[-1])
            print(valid_acc[-1])

        # Save before plotting: plt.show() usually blocks until closed.
        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME))

    _plot_curves()

def main(argv=None):
    """Entry point for tf.app.run(): load MNIST (one-hot labels) and train.

    The data directory can be overridden with the ``MNIST_DATA_DIR``
    environment variable. It defaults to the author's local path.
    """
    data_dir = os.environ.get(
        'MNIST_DATA_DIR',
        'E:/User-Duanduan/python/Deep-Learning/tensorflow/data/MNIST_data/')
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    train(mnist)

if __name__ == '__main__':
    # Hand control to tf.app.run, naming the entry point explicitly.
    tf.app.run(main=main)

第三部分:在测试集上评估已保存的模型

# 测试模型
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt

import mnist_inference
import mnist_train

def evaluate(mnist, image_index=1000):
    """Restore the moving-average weights and evaluate them on the test set.

    The function does the following:
    - Prints the test-set accuracy.
    - Classifies the single test image at ``image_index`` and prints the
      true digit next to the prediction.
    - Displays that image.

    Args:
        mnist: dataset returned by ``input_data.read_data_sets(..., one_hot=True)``.
        image_index: index into ``mnist.test`` of the image to classify
            and display (default 1000).
    """
    import os  # this script does not import os at module level

    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        show_image = mnist.test.images[image_index]
        label = mnist.test.labels[image_index]
        one_image = {x: [show_image], y_: [label]}
        # One-hot label: the true digit is the (first) position of the max.
        result_image = int(label.argmax())

        # Plain forward pass. The Saver below loads the moving-average shadow
        # values into these variables, so no avg_class is needed here.
        y = mnist_inference.inference(x, None, None, reuse=False)
        prediction = tf.argmax(y, 1)
        correct_prediction = tf.equal(prediction, tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        # Same checkpoint path the training script writes to.
        checkpoint = os.path.join(mnist_train.MODEL_SAVE_PATH, mnist_train.MODEL_NAME)
        with tf.Session() as sess:
            saver.restore(sess, checkpoint)
            accuracy_score = sess.run(accuracy, feed_dict=test_feed)
            print('Test accuracy is %g%%' % (accuracy_score * 100))

            # The fetch has shape (1,). Take the element rather than relying
            # on implicit array-to-scalar conversion inside %g.
            result = sess.run(prediction, feed_dict=one_image)[0]
            print('Actual:%g,predtion:%g' % (result_image, result))

    # The image is a flat INPUT_NODE vector; reshape it with NumPy instead
    # of building and evaluating a TF op.
    plt.figure('Show')
    plt.imshow(show_image.reshape(28, 28))
    plt.show()
def main(argv=None):
    """Entry point for tf.app.run(): load MNIST (one-hot labels) and evaluate.

    The data directory can be overridden with the ``MNIST_DATA_DIR``
    environment variable. It defaults to the author's local path.
    """
    import os  # this script does not import os at module level

    data_dir = os.environ.get(
        'MNIST_DATA_DIR',
        'E:/User-Duanduan/python/Deep-Learning/tensorflow/data/MNIST_data/')
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    evaluate(mnist)

if __name__ == '__main__':
    # Hand control to tf.app.run, naming the entry point explicitly.
    tf.app.run(main=main)

你可能感兴趣的:(TensorFlow学习——第三章(一))