import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import os
# --- Training schedule ---
batch_size = 100    # examples requested from the dataset per step
num_epochs = 2      # number of passes of the outer training loop
display_step = 2    # the loss is printed every `display_step` steps

# --- Data geometry ---
img_size = 28       # side length used for the image placeholder shape
num_class = 10      # width of the one-hot label / logit vectors
IMAGE_SIZE = img_size   # same side length, used when reshaping mini-batches
NUM_CHANNELS = 1        # channel count used when reshaping mini-batches
# Graph inputs: a batch of single-channel img_size x img_size images, the
# matching one-hot label rows, and the dropout keep probability, which is
# fed at run time.
X = tf.placeholder(tf.float32, [None, img_size, img_size, 1], name='input')
Y = tf.placeholder(tf.float32, [None, num_class])
p_keep = tf.placeholder(tf.float32, name='p_keep_rate')

# Read MNIST from the 'dataset' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('dataset', one_hot=True)

# Expose the train/test splits as module-level arrays. The images are
# reshaped to (N, img_size, img_size, 1) so they match the layout of X.
train_X = mnist.train.images.reshape(-1, img_size, img_size, 1)
train_Y = mnist.train.labels
test_X = mnist.test.images.reshape(-1, img_size, img_size, 1)
test_Y = mnist.test.labels
# Network definition: three conv(3x3) -> ReLU -> max-pool(2x2) stages,
# followed by a 625-unit fully connected layer and a linear output layer
# that produces one logit per class. Dropout with keep probability `p_keep`
# is applied after the first two pooling stages and around the fully
# connected layer.

# Shared conv / pool geometry (NHWC order: batch, height, width, channels).
_UNIT_STRIDES = [1, 1, 1, 1]
_POOL_WINDOW = [1, 2, 2, 1]
_POOL_STRIDES = [1, 2, 2, 1]

# Stage 1: 1 -> 32 feature maps.
with tf.name_scope('cnn_layer_01') as cnn_01:
    w1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
    conv1 = tf.nn.conv2d(X, w1, strides=_UNIT_STRIDES, padding="SAME")
    conv_y1 = tf.nn.relu(conv1)

with tf.name_scope('pool_layer_01') as pool_01:
    pool_y2 = tf.nn.max_pool(conv_y1, ksize=_POOL_WINDOW,
                             strides=_POOL_STRIDES, padding='SAME')
    pool_y2 = tf.nn.dropout(pool_y2, p_keep)

# Stage 2: 32 -> 64 feature maps.
with tf.name_scope('cnn_layer_02') as cnn_02:
    w2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
    conv2 = tf.nn.conv2d(pool_y2, w2, strides=_UNIT_STRIDES, padding="SAME")
    conv_y3 = tf.nn.relu(conv2)

with tf.name_scope('pool_layer_02') as pool_02:
    pool_y4 = tf.nn.max_pool(conv_y3, ksize=_POOL_WINDOW,
                             strides=_POOL_STRIDES, padding='SAME')
    pool_y4 = tf.nn.dropout(pool_y4, p_keep)

# Stage 3: 64 -> 128 feature maps (no dropout directly after this pool).
with tf.name_scope('cnn_layer_03') as cnn_03:
    w3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))
    conv3 = tf.nn.conv2d(pool_y4, w3, strides=_UNIT_STRIDES, padding="SAME")
    conv_y5 = tf.nn.relu(conv3)

with tf.name_scope('pool_layer_03') as pool_03:
    pool_y6 = tf.nn.max_pool(conv_y5, ksize=_POOL_WINDOW,
                             strides=_POOL_STRIDES, padding='SAME')

# Fully connected layer. The flattened width 128*4*4 assumes three stride-2
# 'SAME' poolings take a 28-pixel side to 14 -> 7 -> 4; it is read back from
# w4's own shape so the reshape always agrees with the matmul.
with tf.name_scope('full_layer_01') as full_01:
    w4 = tf.Variable(tf.random_normal([128 * 4 * 4, 625], stddev=0.01))
    FC_layer = tf.reshape(pool_y6, [-1, w4.get_shape().as_list()[0]])
    FC_layer = tf.nn.dropout(FC_layer, p_keep)
    FC_y7 = tf.matmul(FC_layer, w4)
    FC_y7 = tf.nn.relu(FC_y7)
    FC_y7 = tf.nn.dropout(FC_y7, p_keep)

# Output layer: raw logits, no softmax here.
# NOTE(review): w5 relies on random_normal's default stddev, unlike the 0.01
# used for w1-w4 -- confirm this is intentional.
with tf.name_scope('output_layer') as output_layer:
    w5 = tf.Variable(tf.random_normal([625, num_class]))
    model_Y = tf.matmul(FC_y7, w5, name='output')
# Loss: softmax cross-entropy of the logits against the one-hot labels,
# computed per example and then averaged into a scalar.
Y_ = tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=model_Y)
cost = tf.reduce_mean(Y_)

# Accuracy: fraction of rows where the arg-max logit equals the arg-max label.
_predicted_class = tf.argmax(model_Y, axis=1)
_expected_class = tf.argmax(Y, axis=1)
correct_prediction = tf.equal(_predicted_class, _expected_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Training op: one RMSProp update that minimises the mean loss.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001,
                                      decay=0.9).minimize(cost)

# Summaries: scalar curves for loss/accuracy plus a histogram per weight
# tensor, merged into a single op and written under 'mnist_summary'.
tf.summary.scalar('loss', cost)
tf.summary.scalar('accuracy', accuracy)
for _tag, _weights in (('w1', w1), ('w2', w2), ('w3', w3),
                       ('w4', w4), ('w5', w5)):
    tf.summary.histogram(_tag, _weights)
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter('mnist_summary')
# Training / validation driver.
#
# For each epoch: run one optimisation step per training mini-batch
# (logging the loss and writing summaries every `display_step` steps),
# then measure mean accuracy over the test mini-batches and save a
# checkpoint under 'mnist_model'.

# Whole mini-batches per pass over each split (integer division).
train_batches_per_epoch = len(train_X) // batch_size
test_batches_per_epoch = len(test_X) // batch_size

# Saver.save() raises if the parent directory is missing, so create it
# before training starts rather than failing at the end of the first epoch.
checkpoint_dir = 'mnist_model'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Build the Saver once: constructing it inside the epoch loop would add a
# fresh set of save/restore ops to the graph on every epoch.
# max_to_keep=num_epochs so that no per-epoch checkpoint is pruned.
saver = tf.train.Saver(max_to_keep=num_epochs)

try:
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        writer.add_graph(sess.graph)

        for epoch in range(num_epochs):
            print("{}: Epoch number: {} start".format(datetime.now(), epoch+1))

            # ---- training pass (dropout active: keep probability 0.8) ----
            for step in range(train_batches_per_epoch):
                img_batch, label_batch = mnist.train.next_batch(batch_size)
                reshaped_img_batch = np.reshape(
                    img_batch, (-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
                _, summary, loss = sess.run(
                    [optimizer, merged_summary, cost],
                    feed_dict={X: reshaped_img_batch, Y: label_batch,
                               p_keep: 0.8})
                if step % display_step == 0:
                    print("{}: step = {} loss = {}".format(datetime.now(), step, loss))
                    # Global step index keeps the summary x-axis monotonic
                    # across epochs.
                    writer.add_summary(summary, epoch * train_batches_per_epoch + step)

            # ---- validation pass ----
            print("{}: Start validation".format(datetime.now()))
            test_accu = 0.
            test_count = 0
            for _ in range(test_batches_per_epoch):
                img_batch, label_batch = mnist.test.next_batch(batch_size)
                reshaped_img_batch = np.reshape(
                    img_batch, (-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
                # Keep probability 1.0 disables dropout, so the reported
                # accuracy is deterministic and reflects the full network.
                accu = sess.run(accuracy,
                                feed_dict={X: reshaped_img_batch,
                                           Y: label_batch,
                                           p_keep: 1.0})
                test_accu += accu
                test_count += 1
            try:
                test_accu /= test_count
            except ZeroDivisionError:
                # No test batches were run (test split smaller than one
                # batch); test_accu stays at 0.
                print('ZeroDivisionError!')
            print("{}: Validation Accuracy = {:.4f}".format(datetime.now(), test_accu))

            # ---- checkpoint ----
            checkpoint_name = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch + 1))
            path = saver.save(sess, checkpoint_name)
            print("{}: Epoch number: {} end".format(datetime.now(), epoch + 1))
finally:
    # Flush pending summary events and release the event file even if
    # training is interrupted.
    writer.close()