CNN在mnist数据集上实现

这次我们使用CNN中最经典的Lenet网络在mnist数据集上进行训练和预测。

  • 卷积NN
    主要有两部分组成,一部分是对输入图片特征提取,一部分是全连接网络,主要组成操作包括卷积、池化、激活等。

  • Lenet网络模型
    Lenet是提出比较早,能有效解决手写数字图片识别的卷积模型,模型结构如下:


    CNN在mnist数据集上实现_第1张图片
    0.PNG

其中,padding=valid代表非全0填充,输出图片尺寸=(输入尺寸-卷积核尺寸+1)/步长;padding=same代表全0填充,输出尺寸=输入尺寸/步长;pooling不改变深度。
对Lenet进行调整使其使用于mnist数据集,结构如下:


CNN在mnist数据集上实现_第2张图片
Lenet_on_mnist.PNG

实现还是分三模块:forward,backwa,test,主要改变是在forward:


CNN在mnist数据集上实现_第3张图片
lenet1.png

定义获得权重、偏执,增加对卷积,池化的函数。
CNN在mnist数据集上实现_第4张图片
lenet2.png

按上层结构前向传播,返回预测值。

backward和test跟上一篇中改动不大,主要是要注意输入的大小:

CNN在mnist数据集上实现_第5张图片
leb1.png

输入占位大小改变
CNN在mnist数据集上实现_第6张图片
leb2.png

喂入的barch_size大小改变
同理,在test文件中,测试数据的大小也相应改变。


新手学习,欢迎指教!

你可能感兴趣的:(CNN在mnist数据集上实现)