kaggle实战(三)——纯手搓CNN做图像识别

题目介绍

这次做的还是新手题,是一道图像识别的题目,而数据集也是机器学习中非常经典的MNIST手写数字集。虽然接触ML和DL的同学应该都了解这个数据集,但是为了让大家明白我们要干什么还是简单介绍一下。
在这里插入图片描述
原题的链接(kaggle会被墙,所以请大家各显神通)

数据集介绍

MNIST数据集其实就是0-9这10个数字的手写版图片,我们要通过训练集对模型进行训练,让模型能够识别各个数字的不同样子,之后再对测试集中的数据进行预测,得到我们最终的预测标签。
我一开始以为题目的数据集会直接给图片,但是题目给的是CSV文件。
kaggle实战(三)——纯手搓CNN做图像识别_第1张图片
文件的大小是10×785,那么我们除去第一行,也就是label标签行,大小就是10×784。这个10行相信不难理解,因为一共就10个数字,也就是10个类别。
那么784是什么意思?这里的784是用28×28得来的,因为我们的MNIST数据集中的每一张照片的大小都是1×28×28这里的1表示的是图片只有一个通道,之所以只有一个通道是因为MNIST数据集中的图片都是黑白的
在上图的部分数据中我们看到pixel的数据好像都是0,这是为什么,因为单通道的黑白图片,只有白色的地方是有灰度值数值的,而黑色区域的灰度值都是0。又因为一张图片其实大部分区域还是黑的,只有写字的地方是白色,所以其实大部分的像素的灰度值都是0

数据清洗

因为这次的数据集较特殊,都是代表一张图片的像素,并且没有缺失,所以这题并没有数据清洗的环节

模型搭建

第一种 纯卷积网络+全连接层

首先我们定义两层卷积,第一层卷积输入通道是1,输出通道是32(输入通道为1是因为我们的图片是单通道图片,只有一个通道,而输出通道是超参数,由我们自己设定。)
第二层卷积的输入通道就是上一层卷积的输出通道,输出通道依旧是我们自己定义。
并且在每次卷积之后,我们都用了LeakyRuLU作为激活函数对数据进行处理,并使用最大池化对数据降维。
kaggle实战(三)——纯手搓CNN做图像识别_第2张图片
之后我们做完了卷积之后用全连接层对卷积得到的所有特征进行处理,这里我圈起来的是第一层全连接的输入特征数,这个特征数是需要我们自己进行计算的,计算方法是我们将卷积的每张图的像素数目再乘以最后的总的通道数就得到了总的特征个数(这个原理的理解需要大家对卷积网络有一定的基础)
kaggle实战(三)——纯手搓CNN做图像识别_第3张图片

第二种 残差卷积网络 + 全连接层

如果了解残差网络的同学应该也知道从卷积网络改到残差网络只要对输入再进行一个1×1的全局卷积来改变输入的通道数就可以了,所以代码改起来非常简单。
kaggle实战(三)——纯手搓CNN做图像识别_第4张图片
kaggle实战(三)——纯手搓CNN做图像识别_第5张图片
搭建好神经网络之后,我们就要把数据投到网络去训练网络。那么在此之前,我们要先把训练集分为训练集和验证集,这个方法已经在之前的题目涉及过,所以直接放代码。
kaggle实战(三)——纯手搓CNN做图像识别_第6张图片
我们利用utils.data.DataLoader方法生成了DataLoader的迭代器,我们可以通过对这个迭代器迭代得到我们的图片数据和每个图片对应的标签
kaggle实战(三)——纯手搓CNN做图像识别_第7张图片
我们在整理好训练数据的迭代器之后就可以开始写神经网络的训练函数。在这里博主想强调一行代码,就是用红笔画出来的第一条线那里。博主一开始并没有加这行代码就直接训练网络导致报错,报错写的是我的数据和是在GPU上,但是我网络的权重仍在CPU上,导致出错,后来去翻了别人的
notebook,发现少了这一行代码。
所以如果大家有条件使用GPU的话,记得不光把数据放在GPU上,还要把神经网络的权重也放在GPU上。
那么第二条红线是什么意思呢,这一行代码的作用是用来预热学习率,使得学习率根据迭代的进程自己进行调整。(具体不展开了,大家想了解的话自己去搜一下)
kaggle实战(三)——纯手搓CNN做图像识别_第8张图片
接下啦我们定义验证集的函数。因为是验证集,所以我们是不需要在函数中进行优化算法的,直接算出来计算个准确率就行了。
kaggle实战(三)——纯手搓CNN做图像识别_第9张图片
最后就是我们的主函数
kaggle实战(三)——纯手搓CNN做图像识别_第10张图片
最最后就是我跑出来的结果了。
kaggle实战(三)——纯手搓CNN做图像识别_第11张图片
上面的结果是博主加了残差块之后重新跑出来的数据,可以看到最高的准确率是在99.357,一开始我只是用了纯卷积的方式,正确率是99.26,实在没想到在这么高的基础之上竟然残差块还能让结果更好,只能说一句恺明nb。

复盘

这种图像识别的题目个人觉得就是拼网络和硬件了,现在pytorch也已经有很多现成的网络可以直接impor,大家也不妨自己去试一试。
这次的题目就介绍到这里了,下一道题是房价预测题,已经在做了,但是由于数据维度实在太高而且相关性有点复杂所以特征工程很花时间,应该还要个几天吧!

转载请标明出处!

你可能感兴趣的:(深度学习(pytorch),机器学习,pytorch,python,深度学习)