pytorch深度学习实战lesson4

目录

解决识别手写数字的方法(理论部分):

解决识别手写数字的方法(实践部分):


参考教材:

pytorch深度学习实战lesson4_第1张图片

课程网站:

https://www.bilibili.com/video/BV1xB4y1m7f4/?spm_id_from=333.1007.top_right_bar_window_custom_collection.content.click

第四课 手写数字体识别问题(举例+实战,对应视频课时6-13)

认识mnist数据集:

pytorch深度学习实战lesson4_第2张图片

Mnist数据集叫做手写数字数据集,它的训练集有6万个,测试集有一万个。每个图片的像素28*28,每个图片的像素是由0,1组成的,0表示白色,1表示黑色。

解决识别手写数字的方法(理论部分):

1、把28*28的图片“打平”,也就是把二维图片弄成一维(784位的一维0,1数组,设为X)。

2、对一维数组进行线性变换,此处要进行多次线性变换,不能只变换一次。使用y=wx+b的形式进行变换。

pytorch深度学习实战lesson4_第3张图片

计算H1的过程:

计算H2的过程:

其中“@”表示矩阵乘法。

计算H3的过程:

以上中括号中,前面的“1”表示维度,后面的数据表示多少个。

3、然后要选择一种合适的编码方式

这里使用的编码方式是“one-hot”编码方式。

进行手写数字识别任务时,我需要对图片的labor进行编码

如上图所示,如果是1的话就对第二个位置写1,如果是3的话就对第4个位置写3。

4、计算loss

使用欧氏距离的算法计算loss,也就是将10维以内的向量先相减在求平方和,结果越小说明差距越少。

5、非线性处理

其实手写字体里面有好多千奇百怪的字体,但是这些对于人脑来讲是很容易就能识别的,其原因就是人脑有很强的非线性能力,因此对于神经网络来说,也不能光进行线性变换,也要有一个非线性的过程。采用下图所示的非线性激活函数——relu函数。

pytorch深度学习实战lesson4_第4张图片

加入激活函数:

pytorch深度学习实战lesson4_第5张图片

6、利用梯度下降算法,计算出三组W和b

pytorch深度学习实战lesson4_第6张图片

7、算好预测值后

使用argmax算出预测值最接近的真实值的labor。

解决识别手写数字的方法(实践部分):

0、准备工作

在另一个.py文件中写入画loss曲线的函数、画图片的函数以及one-hot编码函数。

先导入库

pytorch深度学习实战lesson4_第7张图片

1、加载图片

pytorch深度学习实战lesson4_第8张图片

2、搭建模型

pytorch深度学习实战lesson4_第9张图片

3、训练

pytorch深度学习实战lesson4_第10张图片

pytorch深度学习实战lesson4_第11张图片

到此为止我们已经得到了一组比较不错的【w1,b1,w2,b2,w3,b3】

此时运行程序时,可以看到loss在很稳定的下降。

pytorch深度学习实战lesson4_第12张图片

4、准确度测试

pytorch深度学习实战lesson4_第13张图片

pytorch深度学习实战lesson4_第14张图片

pytorch深度学习实战lesson4_第15张图片

可以看出,预测的结果正确率还是比较可观的。

PS:

由于我发现本课程的课程视频到不完整,再加上没有现成的代码,所以我决定下次更新李沐大神的《动手学深度学习》这门课,也是用的pytorch架构的,我还是非常期待李沐大神的课的!

你可能感兴趣的:(深度学习)