pytorch学习之路一:识别MINIST手写数字数据集

1.首先,参考pytorch官网安装pytorch,我选择的是windows系统+python3.7+pip安装,所以选择好后,直接在cmd中运行图中的命令行,pytorch官网:https://pytorch.org/

pytorch学习之路一:识别MINIST手写数字数据集_第1张图片

2.打开pycharm。新建工程,有的童鞋可能会找不到下载的pytorch包,请参考:https://blog.csdn.net/weixin_39954922/article/details/105606956

3.开始撸代码:(代码见文末网址)

3.1首先导入需要的包,

pytorch学习之路一:识别MINIST手写数字数据集_第2张图片

3.2 导入MINIST数据集,如果没有的话可以通过代码下载,但是真的灰常慢~,还是科学上网好(手动滑稽),这里说明一下,pycharm下载MINIST数据集后,会自动放在当前工程文件夹下,链接里已下载好共各位看官食用

pytorch学习之路一:识别MINIST手写数字数据集_第3张图片

3.3开始构件卷积神经网络。这里使用class类来建立

pytorch学习之路一:识别MINIST手写数字数据集_第4张图片

pytorch学习之路一:识别MINIST手写数字数据集_第5张图片

3.4选择优化器和损失函数

pytorch学习之路一:识别MINIST手写数字数据集_第6张图片

3.5开始训练,并保存模型

pytorch学习之路一:识别MINIST手写数字数据集_第7张图片

3.6训练一次得到模型之后,没有必要重复训练,只需要调用模型即可,所以可以将3.5节的代码注释掉,但是这里需要注意的是,调用模型时,仍然需要用到前面定义的cnn,因此,若重新打开一个.py文件调用模型,需要重新声明一下 CNN class并定义cnn=CNN(),,为了看清图片,本代码中用了opencv-python的imshow()函数来显示要识别的图片

pytorch学习之路一:识别MINIST手写数字数据集_第8张图片

3.6预测结果:
pytorch学习之路一:识别MINIST手写数字数据集_第9张图片

 

代码GITHUB传送门:https://github.com/1240117300/MINIST

你可能感兴趣的:(日常,python,深度学习,神经网络)