import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the "data/" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("data/", one_hot=True)

# Training hyperparameters.
# NOTE(review): training_iters is not referenced anywhere in the visible script.
learning_rate, training_iters, batch_size = 0.001, 2000, 100

# Each input is a flattened 28x28 image; labels cover 10 classes.
n_input, n_classes = 28 * 28, 10

# Mini-batches per pass over the training split (any remainder is dropped).
n_batch = mnist.train.num_examples // batch_size
print("n_batch == ", n_batch)
# Trainable weights of the CNN. Convolution filters use the tf.nn.conv2d
# layout [height, width, in_channels, out_channels].
weights = {
    # 5x5 convolution: 1 input channel -> 32 feature maps.
    'wc1': tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1)),
    # 5x5 convolution: 32 -> 64 feature maps.
    'wc2': tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1)),
    # Fully connected: a 28x28 image halved by two 2x2 poolings leaves a
    # 7x7x64 map, flattened -> 1024 hidden units.
    'wd1': tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1)),
    # Output layer: 1024 hidden units -> one logit per class. Sized from
    # n_classes so it always agrees with the label placeholder `y`.
    'out': tf.Variable(tf.truncated_normal([1024, n_classes], stddev=0.1)),
}

# Bias vectors, one per layer. The misspelled name `baises` is kept because
# conv_net's parameter and its call site use that spelling.
baises = {
    'bc1': tf.Variable(tf.random_normal([32])),
    'bc2': tf.Variable(tf.random_normal([64])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([n_classes])),
}

# Graph inputs: flattened images, one-hot labels, and the dropout keep
# probability (a scalar fed at run time).
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)
def conv_2d(x, w, b, strides=1):
    """Apply a SAME-padded 2-D convolution, add a bias, then a ReLU.

    Args:
        x: input tensor in tf.nn.conv2d's default NHWC layout.
        w: filter tensor shaped [height, width, in_channels, out_channels].
        b: bias vector with one entry per output channel.
        strides: spatial stride applied to both height and width.

    Returns:
        The rectified feature map.
    """
    # Stride only over the spatial axes; batch and channel strides stay 1.
    step = [1, strides, strides, 1]
    return tf.nn.relu(
        tf.nn.bias_add(tf.nn.conv2d(x, w, strides=step, padding='SAME'), b)
    )
def max_pool(x, k=2):
    """Downsample ``x`` with k x k max pooling and SAME padding.

    The window size equals the stride, so the pooling windows do not overlap.
    """
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
def conv_net(x, weights, baises, dropout):
    """Build the CNN forward pass and return unscaled logits.

    Args:
        x: batch of flattened images; reshaped here to [-1, 28, 28, 1].
        weights: dict with keys 'wc1', 'wc2', 'wd1' and 'out'.
        baises: dict with keys 'bc1', 'bc2', 'bd1' and 'out'.
        dropout: keep probability for the dropout on the hidden layer.

    Returns:
        Logits of the output layer (no softmax applied).
    """
    features = tf.reshape(x, [-1, 28, 28, 1])

    # Two conv + 2x2 max-pool stages: 28x28 -> 14x14 -> 7x7 spatially.
    for w_key, b_key in (('wc1', 'bc1'), ('wc2', 'bc2')):
        features = max_pool(conv_2d(features, weights[w_key], baises[b_key]))

    # Flatten the 7x7x64 map for the fully connected layer.
    flat = tf.reshape(features, [-1, 7 * 7 * 64])
    hidden = tf.nn.relu(tf.add(tf.matmul(flat, weights['wd1']), baises['bd1']))
    # TF1 dropout: the second argument is the probability of KEEPING a unit.
    hidden = tf.nn.dropout(hidden, keep_prob=dropout)

    return tf.add(tf.matmul(hidden, weights['out']), baises['out'])
# Forward pass; the keep_prob placeholder drives the dropout layer.
logits = conv_net(x, weights, baises, keep_prob)

# Mean softmax cross-entropy between the one-hot labels and the raw logits.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Use the configured learning_rate hyperparameter rather than a duplicated literal.
train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# Class probabilities: softmax applied directly to the logits. Passing the
# logits through a sigmoid first would squash every score into (0, 1) and
# yield near-uniform, miscalibrated probabilities.
prediction = tf.nn.softmax(logits)

# A sample counts as correct when the most probable class matches its label.
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Train for 21 epochs, reporting test-set accuracy after each one.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Evaluation feed: the whole test split with dropout disabled (keep 1.0).
    # It is identical every epoch, so build it once.
    test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}

    for epoch in range(21):
        # One pass over the training split in n_batch mini-batches, keep 0.7.
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            train_feed = {x: batch_xs, y: batch_ys, keep_prob: 0.7}
            sess.run(train_step, feed_dict=train_feed)

        acc = sess.run(accuracy, feed_dict=test_feed)
        print("Iter: " + str(epoch) + ", acc: " + str(acc))