import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Load the data ------------------------------------------------------------------
# Labels are one-hot encoded (one_hot=True).
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# Inspect the sizes/shapes of the loaded data ------------------------------------
print('train', mnist.train.num_examples,
      ',validation', mnist.validation.num_examples,
      ',test', mnist.test.num_examples)
print('train images:', mnist.train.images.shape,
      'labels:', mnist.train.labels.shape)

# Show the (one-hot) label of the first training example. A bare expression only
# displays its value in a notebook cell; print it so it also shows when run as a
# plain script.
print(mnist.train.labels[0])
# Helper: display a single image from the data set.
def plot_image(image):
    """Show one flat MNIST image as a 28x28 picture with a binary colormap.

    The images read above are one-dimensional, so the vector is reshaped to a
    28x28 matrix before it is displayed.
    """
    pixels = image.reshape((28, 28))
    plt.imshow(pixels, cmap='binary')
    plt.show()


plot_image(mnist.train.images[0])
# Helper: show a grid of images titled with the true label and, optionally,
# the predicted label.
def plot_images_labels_prediction(images, labels, prediction, idx, num=10):
    """Plot ``num`` consecutive images starting at index ``idx`` in a 5x5 grid.

    Args:
        images: sequence of flat images, each reshapeable to 28x28.
        labels: one-hot label rows; each title shows ``argmax`` of the row.
        prediction: predicted class per image, indexed like ``images``; pass an
            empty sequence to show the true labels only.
        idx: index of the first image to show.
        num: how many images to show. It is capped at 25 (the size of the 5x5
            grid) and at the number of images available from ``idx`` onward.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    # The grid only has 25 slots, and we must not index past the end of images.
    num = max(0, min(num, 25, len(images) - idx))
    show_prediction = len(prediction) > 0
    for i in range(num):
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(np.reshape(images[idx], (28, 28)), cmap='binary')
        title = 'label=' + str(np.argmax(labels[idx]))
        if show_prediction:
            title += ',prediction=' + str(prediction[idx])
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx += 1
    plt.show()


plot_images_labels_prediction(mnist.validation.images, mnist.validation.labels, [], 0)
# Build the MLP model -------------------------------------------------------------
# Custom fully-connected layer.
def layer(output_dim, input_dim, inputs, activation=None):
    """Create a dense layer and return ``activation(inputs @ W + b)``.

    W has shape [input_dim, output_dim] and b has shape [1, output_dim]. Both
    are new ``tf.Variable``s initialised with ``tf.random_normal`` using its
    default arguments. When ``activation`` is None the affine result is
    returned as-is.
    """
    weights = tf.Variable(tf.random_normal([input_dim, output_dim]))
    bias = tf.Variable(tf.random_normal([1, output_dim]))
    affine = tf.matmul(inputs, weights) + bias
    return affine if activation is None else activation(affine)
# Network: input x -> two ReLU hidden layers h1, h2 (1000 units each) -> 10 logits.
x = tf.placeholder('float', [None, 784])
h1 = layer(output_dim=1000, input_dim=784, inputs=x, activation=tf.nn.relu)
h2 = layer(output_dim=1000, input_dim=1000, inputs=h1, activation=tf.nn.relu)
y_predict = layer(output_dim=10, input_dim=1000, inputs=h2, activation=None)

# Placeholder for the ground-truth labels (one-hot, 10 classes).
y_label = tf.placeholder('float', [None, 10])

# Loss: mean softmax cross-entropy between the logits and the labels.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
    logits=y_predict, labels=y_label)
loss_function = tf.reduce_mean(per_example_loss)

# Optimizer: Adam with learning rate 0.001, minimising the loss.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)

# Accuracy: whether argmax(label) == argmax(logits) per example, cast to float
# and averaged over the batch.
correct_prediction = tf.equal(tf.argmax(y_label, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
# Training ----------------------------------------------------------------------
from time import time

trainEpochs = 15  # number of epochs
batchSize = 100
# Mini-batches per epoch (integer division; any remainder is not a full batch).
totalBatchs = len(mnist.train.images) // batchSize
# Per-epoch history: epoch index, validation loss, validation accuracy.
loss_list = []
epoch_list = []
accuracy_list = []

startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# The validation feed is the same every epoch, so build it once.
validation_feed = {x: mnist.validation.images, y_label: mnist.validation.labels}
for epoch in range(trainEpochs):
    for _ in range(totalBatchs):
        # Read the next training batch; successive calls keep reading through the data.
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})
    # Evaluate loss and accuracy on the validation set after each epoch.
    loss, acc = sess.run([loss_function, accuracy], feed_dict=validation_feed)
    epoch_list.append(epoch)
    loss_list.append(loss)
    accuracy_list.append(acc)
    print('Train Epoch:', '%02d' % (epoch + 1), 'Loss=',
          '{:.9f}'.format(loss), ' Accuracy=', acc)

duration = time() - startTime
print('Train Finished takes:', duration)
# NOTE: sess is deliberately left open because the evaluation code later in the
# file still runs it.
# Plot training results ------------------------------------------------------------
# NOTE: the notebook version had the IPython magic `%matplotlib inline` here.
# That line is a SyntaxError in a plain .py file, so it is kept only as a
# comment. Re-enable it as a cell line when running inside Jupyter:
# %matplotlib inline

# Validation loss per epoch, in its own figure.
fig = plt.figure(figsize=(4, 2))
plt.plot(epoch_list, loss_list, label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.show()

# Validation accuracy per epoch, in a separate figure. This keeps the 0.8-1
# y-range from clipping the loss curve and stops the legends overwriting each other.
fig = plt.figure(figsize=(4, 2))
plt.plot(epoch_list, accuracy_list, label='accuracy')
plt.ylim(0.8, 1)  # y-axis range
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.show()
# Evaluation -------------------------------------------------------------------------
# Accuracy on the test set.
test_feed = {x: mnist.test.images, y_label: mnist.test.labels}
print('Accuracy:', sess.run(accuracy, feed_dict=test_feed))

# Predicted class for every test image (argmax over the 10 logits).
predicted_class = tf.argmax(y_predict, 1)
prediction_result = sess.run(predicted_class, feed_dict={x: mnist.test.images})

# Show test images alongside their true and predicted labels.
plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 0)