import tensorflow as tf
import matplotlib.pyplot as plt #画图对应的库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("data/", one_hot=True)
X = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
# 如果将权重初始化为0,则准确率非常低,10%左右
# W = tf.Variable(tf.zeros(shape=[784, 256]))
# 使用标准正态分布,标准差0.05,准确率为97%左右。
# W = tf.Variable(tf.random_normal(shape=[784,256], stddev=0.1))
# 也可以使用截断正态分布,准确率与标准正态分布差不多
W = tf.Variable(tf.truncated_normal(shape=[784,256], stddev=0.1)) #stddev:标准差 truncated_normal:截断正态分布:只要2倍标准差之内的数据
b = tf.Variable(tf.zeros(shape=[1,256]))
print("tf.truncated_normal(shape=[784,256], stddev=0.1)):",tf.truncated_normal(shape=[784,256], stddev=0.1))
#进行矩阵的点乘运算
z = tf.matmul(X, W) + b
print("z\n", z)
# 使用relu激活函数
a = tf.nn.relu(z)
# W2 = tf.Variable(tf.random_normal(shape=[256, 10], stddev=0.05))
W2 = tf.Variable(tf.truncated_normal(shape=[256,10], stddev=0.1))
b2 = tf.Variable(tf.zeros(shape=[1,10]))
z2 = tf.matmul(a,W2) + b2
print("z2\n", z2)
a2 = tf.nn.softmax(z2)
loss = -tf.reduce_sum(y*tf.log(a2))
# 这里不再计算softmax,在计算交叉熵,而是直接使用tf.nn.softmax_cross_entropy_with_logits直接计算。
# 但是,使用该方法后,准确率有所下降。
# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=z2, labels=y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
correct = tf.equal(tf.argmax(y, axis=1), tf.argmax(a2, axis=1))
# correct = tf.equal(tf.argmax(y, axis=1),tf.argmax(z2, axis=1))
#tensorflow对类型要求非常严格,我们不能使用布尔类型的张量进行数学上的运算
#如果需要运算,则必须进行类型转换
rate = tf.reduce_mean(tf.cast(correct, tf.float32))
#定义会话,来进行求值。
with tf.Session() as sess: #with:上下文管理器好处可以不用关闭
#定义全局初始化器,用来对变量进行初始化。tensorflow中,变量必须进行初始化。
sess.run(tf.global_variables_initializer()) #进行真正的运行
for i in range(3000):
#获取下一个批次样本的数量。返回的是一个元组
batch_X, batch_y = mnist.train.next_batch(100)
#print("batch_X.shape",batch_X.shape)
#使用feed_dict来填充占位符张量的数据
sess.run(train_step, feed_dict={X: batch_X, y: batch_y})
if i % 500 == 0:
print("训练数量为:",i)
print("准确率为:")
print(sess.run(rate, feed_dict={X: mnist.test.images, y:mnist.test.labels}))
#获取下一张image值
for i in range(5):
batch_X, batch_y = mnist.train.next_batch(1)
#以图片形式展现
print("这张图片是:")
plt.imshow(batch_X.reshape((28,28)), cmap="gray")
plt.show()
#打印真实值的索引
print("这张图片的标签是:")
print(sess.run(tf.argmax(y, axis=1), feed_dict={X: batch_X, y:batch_y}))
print("这张图片的预测值是:")
#print(sess.run(tf.argmax(a, axis=1), feed_dict={X: batch_X, y:batch_y}))
print(sess.run(tf.argmax(a2, axis=1), feed_dict={X: batch_X, y:batch_y}))
#tf.argmax(a2, axis=1)