tensorflow全连接层手写数字识别

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

def full_connection():

    """

    用全连接来对数字进行识别

    :return:

    """

    # 1.准备数据

    mnist=input_data.read_data_sets("./temp",one_hot=True)

    x=tf.placeholder(dtype=tf.float32,shape=(None,784))

    y_true=tf.placeholder(dtype=tf.float32,shape=(None,10))

    # 2.构建模型

    Weights=tf.Variable(initial_value=tf.random_normal(shape=(784,10)))

    bias=tf.Variable(initial_value=tf.random_normal(shape=[10]))

    y_predict=tf.matmul(x,Weights)+bias

    # 3.构造损失函数

    error=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))

    # 4.优化损失函数

    optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(error)

    # 5.准确率计算

    equal_list=tf.equal(tf.arg_max(y_true,1),

                        tf.arg_max(y_predict,1))

    accuracy=tf.reduce_mean(tf.cast(equal_list,tf.float32))

    # 初始化变量

    init=tf.global_variables_initializer()

    # 开启回话

    with tf.Session() as sess:

        sess.run(init)

        # 创建事件文件

        # 拿真实值

        image,lable=mnist.train.next_batch(100)

        print("训练前,损失为%f"% sess.run(error,feed_dict={x:image,y_true:lable}))

        for i in range(100):

            a,loss,accuracy_value=sess.run([optimizer,error,accuracy],feed_dict={x:image,y_true:lable})

            print("第%d次的训练,损失为%f,准确率为%f,"% (i+1,loss,accuracy_value))

full_connection()

你可能感兴趣的:(tensorflow全连接层手写数字识别)