【神经网络学习笔记2】简单的CNN网络识别手写图案

本篇文章还是基于tensorflow给的官方样例,教会大家如何构建简单的CNN网络以下是官方代码

tensorflow官方样例 CNN网络

conv2d定义的是卷积层

maxpool2d定义的是池化层

conv_net定义的是具体的网络运算过程,其中fc定义的是全连接层

可以很方便地修改各层的参数,如深度,广度等

我用来解决的问题是来识别手写的O和X,MNIST不知道为啥在我的电脑上装不上去。

训练集与测试集是我自己手写出来的,总共120张,因为样本比较小,所以采取的是标准梯度下降法。

以下是我的代码。前面加的num_step和num_test是用来定义训练集和测试集的数量的,这里选取的是100和20。同时我的程序为了方便选择直接读取图片。最后加上了loss_op和acc关于时间的变化。

from __future__ import division, print_function, absolute_import
import numpy as np
import tensorflow as tf
import string
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 

# Training Parameters
learning_rate = 0.001
num_steps = 100
num_test = 20
display_step = 4

# Network Parameters
num_input = 81 # data input 
num_classes = 2 # total classes 
dropout = 0.75 # Dropout, probability to keep units

fd=open('inputimage.txt','w')
imga=np.zeros((num_steps,num_input))
try_imga=np.zeros((num_test,num_input))
for k in range(num_steps+num_test):
    img=mpimg.imread(str(k+1)+'.png')
    img=(img[:,:,2]).reshape(1,81)
    if k
z
最后运行的结果如下

【神经网络学习笔记2】简单的CNN网络识别手写图案_第1张图片【神经网络学习笔记2】简单的CNN网络识别手写图案_第2张图片

分别是loss_op和accuracy随训练时间的变化图

最终用20幅图片的测试准确率为95%,但其实每次运行的结果都不一样,但都基本稳定在了90%以上,可能是各个参数取的初始值不太一样。

可以看到关于图片的神经网络学习中,loss一开始都是很大的,因为图片信息较为复杂一些,但利用adam的优化器最终也取到了比较好的结果。

接下来是我自己的训练集加测试集,有兴趣的童鞋可以用来玩一玩

手写ox训练集+测试集

你可能感兴趣的:(AI)