# MNIST CNN with TensorBoard summaries. Requires TensorFlow 1.x:
# tensorflow.examples.tutorials.mnist was removed in TensorFlow 2.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot labels (downloads to the given directory if absent).
mnist = input_data.read_data_sets(r"D:\Jupyter\MNIST-data", one_hot=True)
def get_weight(shape, name):
    """Weight variable drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name=name)

def get_bias(shape, name):
    """Bias variable initialized to zeros."""
    return tf.Variable(tf.zeros(shape), name=name)

def conv2d(X, W):
    """2-D convolution, stride 1, SAME padding (spatial size preserved)."""
    return tf.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding="SAME")

def pool2d(X):
    """2x2 max pooling with stride 2 (halves height and width)."""
    return tf.nn.max_pool(X, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
with tf.name_scope("input"):
x=tf.placeholder('float',shape=[None,784],name="x")
y=tf.placeholder('float',shape=[None,10],name="y")
x_images=tf.reshape(x,[-1,28,28,1],name="x_images")
tf.summary.image("pic",x_images,9)
with tf.name_scope("conv1"):
W1=get_weight([5,5,1,64],"W1")
tf.summary.histogram("W1",W1)
b1=get_bias([64],"b1")
tf.summary.histogram("b1",b1)
conv1=tf.nn.relu(conv2d(x_images,W1)+b1,name="conv1")
pool1=pool2d(conv1)
with tf.name_scope("conv2"):
W2=get_weight([3,3,64,32],"W2")
b2=get_bias([32],"b2")
conv2=tf.nn.relu(conv2d(pool1,W2)+b2,name="conv2")
pool2=pool2d(conv2)
with tf.name_scope("FC1"):
flatten=tf.reshape(pool2,[-1,7*7*32],name="flatten")
W3=get_weight([7*7*32,128],"W3")
b3=get_bias([128],"b3")
FC1=tf.nn.relu(tf.matmul(flatten,W3)+b3,name="FC1")
with tf.name_scope("FC2"):
W4=get_weight([128,10],"W4")
b4=get_bias([10],"b4")
out=tf.add(tf.matmul(FC1,W4),b4,name="out")
with tf.name_scope("loss"):
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=out),name="loss")
tf.summary.scalar("loss",loss)
tf.summary.histogram("loss",loss)
with tf.name_scope("train"):
optimizer=tf.train.AdamOptimizer(0.001)
train_step=optimizer.minimize(loss,name="train")
with tf.name_scope("acc"):
pre_correct=tf.equal(tf.argmax(out,1),tf.argmax(y,1))
acc=tf.reduce_mean(tf.cast(pre_correct,'float'))
tf.summary.scalar("acc",acc)
merged = tf.summary.merge_all()
init = tf.global_variables_initializer()
batch_size = 128
with tf.Session() as sess:
    sess.run(init)
    # Write the graph and all merged summaries to ./log for TensorBoard.
    writer = tf.summary.FileWriter("./log", sess.graph)
    # Train for 1000 mini-batches, logging summaries at every step.
    for i in range(1000):
        xx, yy = mnist.train.next_batch(batch_size)
        feed_dict = {x: xx, y: yy}
        _, summary = sess.run([train_step, merged], feed_dict=feed_dict)
        writer.add_summary(summary, i)
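    # Evaluation sketch (not in the original script): after training, report
    # accuracy on the held-out MNIST test split, then flush the writer.
    # Feeding all 10k test images in one run is fine for a model this small.
    test_acc = sess.run(acc, feed_dict={x: mnist.test.images,
                                        y: mnist.test.labels})
    print("test accuracy after 1000 steps: %.4f" % test_acc)
    writer.close()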