import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
def creatweight(shape, stddev=1.0):
    """Create a trainable variable initialised with normally distributed noise.

    Args:
        shape: Shape of the variable to create, e.g. ``[5, 5, 1, 32]`` for a
            convolution kernel or ``[32]`` for a bias vector.
        stddev: Standard deviation of the initial values. The default of 1.0
            matches the default of ``tf.random_normal``, so existing
            one-argument calls behave exactly as before.

    Returns:
        A ``tf.Variable`` of the requested shape.
    """
    return tf.Variable(initial_value=tf.random_normal(shape=shape, stddev=stddev))
# Load the MNIST data set from ./work/data with one-hot encoded labels and
# report the shapes of the training images and labels.
minist = input_data.read_data_sets("./work/data", one_hot=True)
print(minist.train.images.shape, minist.train.labels.shape)

# Graph inputs: x is fed flattened 784-value images, y is fed 10-way labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# Reshape each flat image to 28x28x1; -1 lets the batch size be inferred.
input_x = tf.reshape(x, [-1, 28, 28, 1])
因为 tf.nn.conv2d 需要的是 4D 数据:[batch, 高, 宽, 深度]。
读进来的每张图片是一维数据,所以要用 reshape 转成 [-1, 28, 28, 1];其中 -1 表示自动推算这一批一共有多少张图片。
x 接收读入的图片,y 接收读入的标签。
下面用到的 creatweight 函数会返回一个指定 shape 的参数矩阵,用来做参数的初始化。
# Stage 1: 5x5 convolution from 1 input channel to 32 feature maps, followed
# by ReLU and a 2x2 max-pool with stride 2.
conv1_weights = creatweight([5, 5, 1, 32])
conv1_bais = creatweight([32])
conv1_x = (
    tf.nn.conv2d(input=input_x, filter=conv1_weights,
                 strides=[1, 1, 1, 1], padding="SAME")
    + conv1_bais
)
relu1_x = tf.nn.relu(conv1_x)
pool1_x = tf.nn.max_pool(
    value=relu1_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

# Stage 2: same layout, 32 input channels to 64 feature maps.
conv2_weights = creatweight([5, 5, 32, 64])
conv2_bais = creatweight([64])
conv2_x = (
    tf.nn.conv2d(input=pool1_x, filter=conv2_weights,
                 strides=[1, 1, 1, 1], padding="SAME")
    + conv2_bais
)
relu2_x = tf.nn.relu(conv2_x)
pool2_x = tf.nn.max_pool(
    value=relu2_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)
以上是两层卷积,每一层依次是卷积、ReLU 激活和 2×2 最大池化。
B 站上讲解卷积的视频很多,这里的结构也比较清晰,就不多讲述了。
# Flatten the pooled feature maps into one 7*7*64 vector per image.
x_fc = tf.reshape(pool2_x, [-1, 7 * 7 * 64])

# Fully connected layer producing 10 logits per image.
fc_weights = creatweight([7 * 7 * 64, 10])
fc_bais = creatweight([10])
y_predict = tf.matmul(x_fc, fc_weights) + fc_bais

# Loss: softmax cross-entropy between labels and logits, averaged over the batch.
error = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_predict)
)

# Plain gradient descent on the loss, plus the op that initialises all variables.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)
init = tf.global_variables_initializer()
以上是全连接层、损失函数、优化器和变量初始化。
def _read_choice():
    """Prompt until the user types an integer menu choice and return it."""
    while True:
        raw = input("1 ta 2 te 0 out")
        try:
            # int() instead of eval(): eval would execute whatever is typed.
            return int(raw)
        except ValueError:
            print("please enter 0, 1 or 2")


# Build the checkpoint writer once. Constructing tf.train.Saver() on every
# training round would add a new set of save/restore ops to the graph each time.
saver = tf.train.Saver()

# Interactive loop: 1 trains for 100 steps and saves a checkpoint, 2 shows one
# training image and prints the network output for it, 0 exits.
with tf.Session() as sess:
    sess.run(init)
    k = _read_choice()
    while k != 0:
        if k == 1:
            for _ in range(100):
                # Draw a fresh mini-batch on every step; fetching it once
                # before the loop trained on the same 20 images throughout.
                image, label = minist.train.next_batch(20)
                _, loss = sess.run([optimizer, error], feed_dict={x: image, y: label})
                print(f'{loss}')
            # Write one checkpoint per training round.
            saver.save(sess, "./work/model/mymodel")
        if k == 2:
            plt.imshow(minist.train.images[2].reshape(28, 28))
            plt.show()
            # Slice [2:3] keeps the leading batch axis so the single image has
            # shape (1, 784) and matches the (None, 784) placeholder.
            opt = sess.run([y_predict], feed_dict={x: minist.train.images[2:3]})
            print(opt)
        k = _read_choice()
开启会话后,输入 1 进行训练,输入 2 进行测试,输入 0 退出,这部分可以按自己的需要修改。因为学校的服务器非常慢,迟迟没有反应,所以选项 2 还没有调试好,但训练可以正常进行。保存模型的路径要改成自己的路径,否则会报错。
下面是全部代码的汇总:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
def creatweight(shape, stddev=1.0):
    """Create a trainable variable initialised with normally distributed noise.

    Args:
        shape: Shape of the variable to create, e.g. ``[5, 5, 1, 32]`` for a
            convolution kernel or ``[32]`` for a bias vector.
        stddev: Standard deviation of the initial values. The default of 1.0
            matches the default of ``tf.random_normal``, so existing
            one-argument calls behave exactly as before.

    Returns:
        A ``tf.Variable`` of the requested shape.
    """
    return tf.Variable(initial_value=tf.random_normal(shape=shape, stddev=stddev))
# Load the MNIST data set from ./work/data with one-hot encoded labels and
# report the shapes of the training images and labels.
minist = input_data.read_data_sets("./work/data", one_hot=True)
print(minist.train.images.shape, minist.train.labels.shape)

# Graph inputs: x is fed flattened 784-value images, y is fed 10-way labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# conv2d needs 4-D input, so reshape each flat image to 28x28x1; -1 lets the
# batch size be inferred.
input_x = tf.reshape(x, [-1, 28, 28, 1])
# Stage 1: 5x5 convolution from 1 input channel to 32 feature maps, followed
# by ReLU and a 2x2 max-pool with stride 2.
conv1_weights = creatweight([5, 5, 1, 32])
conv1_bais = creatweight([32])
conv1_x = (
    tf.nn.conv2d(input=input_x, filter=conv1_weights,
                 strides=[1, 1, 1, 1], padding="SAME")
    + conv1_bais
)
relu1_x = tf.nn.relu(conv1_x)
pool1_x = tf.nn.max_pool(
    value=relu1_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)

# Stage 2: same layout, 32 input channels to 64 feature maps.
conv2_weights = creatweight([5, 5, 32, 64])
conv2_bais = creatweight([64])
conv2_x = (
    tf.nn.conv2d(input=pool1_x, filter=conv2_weights,
                 strides=[1, 1, 1, 1], padding="SAME")
    + conv2_bais
)
relu2_x = tf.nn.relu(conv2_x)
pool2_x = tf.nn.max_pool(
    value=relu2_x,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding="SAME",
)
# Flatten the pooled feature maps into one 7*7*64 vector per image.
x_fc = tf.reshape(pool2_x, [-1, 7 * 7 * 64])

# Fully connected layer producing 10 logits per image.
fc_weights = creatweight([7 * 7 * 64, 10])
fc_bais = creatweight([10])
y_predict = tf.matmul(x_fc, fc_weights) + fc_bais

# Loss: softmax cross-entropy between labels and logits, averaged over the batch.
error = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_predict)
)

# Plain gradient descent on the loss, plus the op that initialises all variables.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)
init = tf.global_variables_initializer()

# Accuracy: fraction of samples whose arg-max logit equals the arg-max label.
equal_list = tf.equal(tf.argmax(y, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
def _read_choice():
    """Prompt until the user types an integer menu choice and return it."""
    while True:
        raw = input("1 ta 2 te 0 out")
        try:
            # int() instead of eval(): eval would execute whatever is typed.
            return int(raw)
        except ValueError:
            print("please enter 0, 1 or 2")


# Build the checkpoint writer once. Constructing tf.train.Saver() on every
# training round would add a new set of save/restore ops to the graph each time.
saver = tf.train.Saver()

# Interactive loop: 1 trains for 100 steps and saves a checkpoint, 2 shows one
# training image and prints the network output for it, 0 exits.
with tf.Session() as sess:
    sess.run(init)
    k = _read_choice()
    while k != 0:
        if k == 1:
            for _ in range(100):
                # Draw a fresh mini-batch on every step; fetching it once
                # before the loop trained on the same 20 images throughout.
                image, label = minist.train.next_batch(20)
                _, loss = sess.run([optimizer, error], feed_dict={x: image, y: label})
                print(f'{loss}')
            # Write one checkpoint per training round.
            saver.save(sess, "./work/model/mymodel")
        if k == 2:
            plt.imshow(minist.train.images[2].reshape(28, 28))
            plt.show()
            # Slice [2:3] keeps the leading batch axis so the single image has
            # shape (1, 784) and matches the (None, 784) placeholder.
            opt = sess.run([y_predict], feed_dict={x: minist.train.images[2:3]})
            print(opt)
        k = _read_choice()