TensorFlow:使用卷积神经网络对 MNIST 手写数字数据集进行分类

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
def creatweight(shape):
    """Create a ``tf.Variable`` of the requested shape.

    Args:
        shape: Shape of the variable to create, e.g. ``[5, 5, 1, 32]``.

    Returns:
        A ``tf.Variable`` whose initial value is produced by
        ``tf.random_normal(shape=shape)``.
    """
    initial = tf.random_normal(shape=shape)
    return tf.Variable(initial_value=initial)
# Load the MNIST data sets from ./work/data with one-hot encoded labels.
minist = input_data.read_data_sets("./work/data", one_hot=True)
print(minist.train.images.shape, minist.train.labels.shape)

# x receives flattened images (784 values each); y receives the one-hot labels
# (10 classes). The leading None leaves the batch size open.
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

# Reshape each flat image into a 28x28 single-channel picture for conv2d;
# -1 lets the batch dimension be inferred.
input_x = tf.reshape(x, [-1, 28, 28, 1])

因为 tf.nn.conv2d 需要的是 4D 数据,形状为【batch,高,宽,深度(通道数)】。

读进来后图片是一维数据(每张 784 个像素),要用 reshape 转成【-1,28,28,1】,其中 -1 代表自动匹配这一批一共有多少张图片。

x接受读入的img。y接受读入的标签。

这里的 creatweight 函数会返回一个形状为 shape 的参数矩阵(变量),用于初始化权重和偏置。

# Stage 1: 5x5 convolution, 1 input channel -> 32 feature maps, then ReLU and
# 2x2 max pooling with stride 2.
conv1_weights = creatweight([5, 5, 1, 32])
conv1_bais = creatweight([32])
conv1_raw = tf.nn.conv2d(
    input=input_x,
    filter=conv1_weights,
    strides=[1, 1, 1, 1],
    padding="SAME",
)
conv1_x = conv1_raw + conv1_bais
relu1_x = tf.nn.relu(conv1_x)
pool1_x = tf.nn.max_pool(
    value=relu1_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

# Stage 2: 5x5 convolution, 32 -> 64 feature maps, then ReLU and 2x2 max
# pooling with stride 2.
conv2_weights = creatweight([5, 5, 32, 64])
conv2_bais = creatweight([64])
conv2_raw = tf.nn.conv2d(
    input=pool1_x,
    filter=conv2_weights,
    strides=[1, 1, 1, 1],
    padding="SAME",
)
conv2_x = conv2_raw + conv2_bais
relu2_x = tf.nn.relu(conv2_x)
pool2_x = tf.nn.max_pool(
    value=relu2_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

这里是两层卷积

讲解卷积的视频在 B 站有很多,这里的结构也比较清晰,就不多讲述了。

# Flatten the pooled feature maps (7 * 7 * 64 values per image) and map them
# to 10 class logits with a single fully connected layer.
flat_dim = 7 * 7 * 64
x_fc = tf.reshape(pool2_x, shape=[-1, flat_dim])
fc_weights = creatweight([flat_dim, 10])
fc_bais = creatweight([10])
y_predict = tf.matmul(x_fc, fc_weights) + fc_bais

# Mean softmax cross-entropy between the one-hot labels and the logits.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=y_predict
)
error = tf.reduce_mean(per_example_loss)

# Plain gradient descent on the loss, plus the variable initializer op.
sgd = tf.train.GradientDescentOptimizer(learning_rate=0.01)
optimizer = sgd.minimize(error)
init = tf.global_variables_initializer()

这里是全连接层和损失函数,以及优化器和变量初始化。

def _read_choice(prompt="1 ta 2 te 0 out"):
    """Prompt for a menu choice until the user enters an integer.

    Args:
        prompt: Text shown to the user.

    Returns:
        The entered choice as an ``int`` (1 = train, 2 = test, 0 = quit).
    """
    while True:
        raw = input(prompt)
        # Parse with int() rather than eval(): eval would execute whatever
        # expression the user types.
        try:
            return int(raw)
        except ValueError:
            print(f"invalid choice: {raw!r}")


# Build the saver once; constructing it inside the menu loop would add a new
# set of save/restore ops to the graph on every training round.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    k = _read_choice()
    while k != 0:
        if k == 1:
            for i in range(100):
                # Draw a new mini-batch every step so training is not repeated
                # on the same 20 examples.
                image, label = minist.train.next_batch(20)
                opt, loss = sess.run(
                    [optimizer, error], feed_dict={x: image, y: label}
                )
                print(f'{loss}')
            # Saver.save fails when the parent directory is missing.
            tf.gfile.MakeDirs("./work/model")
            saver.save(sess, "./work/model/mymodel")
        if k == 2:
            sample = minist.train.images[2]
            plt.imshow(sample.reshape(28, 28))
            plt.show()
            # x is declared with shape (None, 784), so a single flat image
            # has to be fed as a batch of one.
            opt = sess.run([y_predict], feed_dict={x: sample.reshape(1, 784)})
            print(opt)
        k = _read_choice()

开启会话后,输入 1 进行训练,2 进行测试,0 退出。这里可以自己改一下。因为学校服务器太慢,迟迟没有反应,所以 2 还没有调试好,但可以正常训练。保存模型的路径要改成自己的,不然会报错。

这里是全部汇总



import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
def creatweight(shape):
    """Create a ``tf.Variable`` of the requested shape.

    Args:
        shape: Shape of the variable to create, e.g. ``[5, 5, 1, 32]``.

    Returns:
        A ``tf.Variable`` whose initial value is produced by
        ``tf.random_normal(shape=shape)``.
    """
    initial = tf.random_normal(shape=shape)
    return tf.Variable(initial_value=initial)
# Load the MNIST data sets from ./work/data with one-hot encoded labels.
minist = input_data.read_data_sets("./work/data", one_hot=True)
print(minist.train.images.shape, minist.train.labels.shape)

# x receives flattened images (784 values each); y receives the one-hot labels
# (10 classes). The leading None leaves the batch size open.
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

# Reshape each flat image into a 28x28 single-channel picture for conv2d;
# -1 lets the batch dimension be inferred.
input_x = tf.reshape(x, [-1, 28, 28, 1])

# Stage 1: 5x5 convolution, 1 input channel -> 32 feature maps, then ReLU and
# 2x2 max pooling with stride 2.
conv1_weights = creatweight([5, 5, 1, 32])
conv1_bais = creatweight([32])
conv1_raw = tf.nn.conv2d(
    input=input_x,
    filter=conv1_weights,
    strides=[1, 1, 1, 1],
    padding="SAME",
)
conv1_x = conv1_raw + conv1_bais
relu1_x = tf.nn.relu(conv1_x)
pool1_x = tf.nn.max_pool(
    value=relu1_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

# Stage 2: 5x5 convolution, 32 -> 64 feature maps, then ReLU and 2x2 max
# pooling with stride 2.
conv2_weights = creatweight([5, 5, 32, 64])
conv2_bais = creatweight([64])
conv2_raw = tf.nn.conv2d(
    input=pool1_x,
    filter=conv2_weights,
    strides=[1, 1, 1, 1],
    padding="SAME",
)
conv2_x = conv2_raw + conv2_bais
relu2_x = tf.nn.relu(conv2_x)
pool2_x = tf.nn.max_pool(
    value=relu2_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

# Flatten the pooled feature maps (7 * 7 * 64 values per image) and map them
# to 10 class logits with a single fully connected layer.
flat_dim = 7 * 7 * 64
x_fc = tf.reshape(pool2_x, shape=[-1, flat_dim])
fc_weights = creatweight([flat_dim, 10])
fc_bais = creatweight([10])
y_predict = tf.matmul(x_fc, fc_weights) + fc_bais

# Mean softmax cross-entropy between the one-hot labels and the logits.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=y_predict
)
error = tf.reduce_mean(per_example_loss)

# Plain gradient descent on the loss, plus the variable initializer op.
sgd = tf.train.GradientDescentOptimizer(learning_rate=0.01)
optimizer = sgd.minimize(error)
init = tf.global_variables_initializer()

# Accuracy: fraction of examples whose arg-max logit matches the arg-max of
# the one-hot label.
true_class = tf.argmax(y, 1)
predicted_class = tf.argmax(y_predict, 1)
equal_list = tf.equal(true_class, predicted_class)
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
def _read_choice(prompt="1 ta 2 te 0 out"):
    """Prompt for a menu choice until the user enters an integer.

    Args:
        prompt: Text shown to the user.

    Returns:
        The entered choice as an ``int`` (1 = train, 2 = test, 0 = quit).
    """
    while True:
        raw = input(prompt)
        # Parse with int() rather than eval(): eval would execute whatever
        # expression the user types.
        try:
            return int(raw)
        except ValueError:
            print(f"invalid choice: {raw!r}")


# Build the saver once; constructing it inside the menu loop would add a new
# set of save/restore ops to the graph on every training round.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    k = _read_choice()
    while k != 0:
        if k == 1:
            for i in range(100):
                # Draw a new mini-batch every step so training is not repeated
                # on the same 20 examples.
                image, label = minist.train.next_batch(20)
                opt, loss = sess.run(
                    [optimizer, error], feed_dict={x: image, y: label}
                )
                print(f'{loss}')
            # Saver.save fails when the parent directory is missing.
            tf.gfile.MakeDirs("./work/model")
            saver.save(sess, "./work/model/mymodel")
        if k == 2:
            sample = minist.train.images[2]
            plt.imshow(sample.reshape(28, 28))
            plt.show()
            # x is declared with shape (None, 784), so a single flat image
            # has to be fed as a batch of one.
            opt = sess.run([y_predict], feed_dict={x: sample.reshape(1, 784)})
            print(opt)
        k = _read_choice()

你可能感兴趣的:(tensorflow,分类,python)