Python TensorFlow 三层全连接神经网络实现手写数字识别

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import cv2
import os

# Graph inputs: a batch of flattened 28x28 images and their one-hot labels
# over the 10 digit classes.
x = tf.placeholder(tf.float32, [None, 28 * 28], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')

# MNIST dataset read from ./MNIST, with one-hot encoded labels.
mnist = input_data.read_data_sets('./MNIST', one_hot=True)
# Mini-batch size used when iterating the training set.
batch_size = 1000
def add_layer(input_data, input_num, output_num, activation_function=None):
    """Create one fully connected layer and return its output tensor.

    Weights and biases are new ``tf.Variable``s initialised from a standard
    normal distribution.

    Args:
        input_data: tensor whose last dimension is ``input_num``.
        input_num: number of input features.
        output_num: number of units in this layer.
        activation_function: optional callable applied to the affine result;
            when falsy, ``input_data @ w + b`` is returned unchanged.
    """
    weights = tf.Variable(tf.random_normal([input_num, output_num]))
    biases = tf.Variable(tf.random_normal([1, output_num]))
    affine = tf.add(tf.matmul(input_data, weights), biases)
    return activation_function(affine) if activation_function else affine


def build_nn(data):
    """Stack the 784 -> 100 -> 50 -> 10 network on ``data`` and return logits.

    Both hidden layers use sigmoid activations. The last layer stays linear,
    so its output is fed to a softmax cross-entropy loss during training.
    """
    hidden = data
    for fan_in, fan_out in ((28 * 28, 100), (100, 50)):
        hidden = add_layer(hidden, fan_in, fan_out, activation_function=tf.sigmoid)
    return add_layer(hidden, 50, 10)


def train_nn(data, epochs=50, learning_rate=1, checkpoint_path='./mnist.ckpt',
             image_path='./1.jpg'):
    """Train the network on MNIST, or restore it and classify one image.

    If a checkpoint with prefix ``checkpoint_path`` exists, its weights are
    restored and ``predict`` is run on ``image_path``. Otherwise the model is
    trained with plain gradient descent, the test-set accuracy is printed,
    and the weights are saved under ``checkpoint_path``.

    Args:
        data: input tensor the network is built on (the module-level ``x``).
        epochs: number of passes over the training set.
        learning_rate: gradient-descent step size.
        checkpoint_path: prefix passed to ``tf.train.Saver`` save/restore.
        image_path: image classified when a checkpoint is restored.
    """
    output = build_nn(data)
    # Mean softmax cross-entropy between the one-hot labels y and the logits.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output))
    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate).minimize(loss)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Look for this model's own index file (Saver's default V2 format
        # writes "<prefix>.index"). The directory-wide "checkpoint" file is
        # shared by every Saver in that directory and can exist without it.
        if os.path.exists(checkpoint_path + '.index'):
            saver.restore(sess, checkpoint_path)
            predict(image_path, sess, output)
            return
        _run_training(sess, loss, optimizer, epochs)
        print(_test_accuracy(sess, output))
        saver.save(sess, checkpoint_path)


def _run_training(sess, loss, optimizer, epochs):
    """Run ``epochs`` passes of mini-batch gradient descent over mnist.train."""
    steps_per_epoch = int(mnist.train.num_examples) // batch_size
    for epoch in range(epochs):
        epoch_cost = 0
        for _step in range(steps_per_epoch):
            x_data, y_data = mnist.train.next_batch(batch_size)
            cost, _ = sess.run([loss, optimizer], feed_dict={x: x_data, y: y_data})
            epoch_cost += cost
        # Sum of the per-batch mean losses over this epoch.
        print('Epoch', epoch, ': ', epoch_cost)


def _test_accuracy(sess, output):
    """Return the fraction of mnist.test images whose top logit matches the label."""
    correct = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})





def reconstuct_image():
    """将图片数据变为图片形式分组"""
    for i in range(10):
        if not os.path.exists('./{}'.format(i)):
            os.makedirs('./{}'.format(i))
    batch_size = 1
    for j in range(int(mnist.train.num_examples)//batch_size):
        # 注意x_data的格式为[[28*28个数]]
        x_data, y_data = mnist.train.next_batch(batch_size)
        img = Image.fromarray(np.reshape(np.array(x_data[0]*255, dtype = 'uint8'), newshape=(28, 28)))
        dir = np.argmax(y_data[0])
        img.save('./{}/{}.bmp'.format(dir,j))
        if j%1000==0:
            print("已完成", j, "/", mnist.train.num_examples)


def read_data(path):
    """Read an image as grayscale with OpenCV and flatten it for the network.

    Args:
        path: path of the image file to load.

    Returns:
        A tuple ``(image, processed_image)``. ``image`` is the grayscale image
        as loaded. ``processed_image`` is that image resized to 28x28 and
        flattened row by row into a ``(1, 784)`` array of the same dtype.

    Raises:
        OSError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None rather than raising.
    if image is None:
        raise OSError('cannot read image: {}'.format(path))
    processed_image = cv2.resize(image, dsize=(28, 28))
    processed_image = processed_image.reshape(1, 28 * 28)
    return image, processed_image

def predict(image_path, sess, output):
    """Classify a hand-written digit image with the trained network.

    Args:
        image_path: path of the image to classify.
        sess: session holding the trained or restored variables.
        output: logits tensor returned by ``build_nn``.

    Returns:
        The argmax over the logits: a length-1 integer array with the
        predicted digit. It is also printed.
    """
    _, processed_image = read_data(image_path)
    # Training data is in [0, 1] (reconstuct_image multiplies it by 255 to
    # get pixels), so rescale raw 0-255 values the same way. Otherwise they
    # would saturate the sigmoid units.
    # NOTE(review): MNIST digits are light strokes on a dark background; a
    # dark-on-light photo probably needs inverting (255 - pixels) first.
    features = processed_image.astype(np.float32) / 255.0
    result = np.argmax(sess.run(output, feed_dict={x: features}), 1)
    print('the prediction is ', result)
    return result


if __name__ == '__main__':
    # To first dump the MNIST training set as per-digit .bmp files, call
    # reconstuct_image() here.
    train_nn(data=x)

你可能感兴趣的:(算法)