上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解。
一、数据集
MNIST是一个非常经典的数据集,下载链接:http://yann.lecun.com/exdb/mnist/
下载下来的文件如下:
该手写数字数据库具有60,000个示例的训练集和10,000个示例的测试集。它是NIST提供的更大集合的子集。数字已经过尺寸标准化,并以固定尺寸的图像为中心。
手写数字识别是一个比较简单的任务,它是一个10分类问题,(0-9),之所以选这个数据集,是因为识别难度低,计算量小,数据容易获得。
二、模型搭建
1、网络节点的确定
对于不同的目的,网络的选择也是不一样的。一般来说,网络容量和数据集大小是对应的。一个小型数据集也只需要一个小型的网络。
这里有一个经验值:
1)model_size=sqrt(in_size*out_size)
2)model_size=log(in_size)
3) model_size=sqrt(in_size*out_size)
model_size:网络的节点量
in_size:输入的节点量
out_size输出的节点量
2、导入pytorch包
import torch import torchvision import trochvision import datasets import trochvision import transforms from torch.autograd import Variable
3、获取训练集和测试集
#root用于指定数据集下载后的存放路径
#transform用于指定导入数据集需要对数据进行变换操作
#train指定在数据集下载后需要载入哪部分数据,true为训练集,false为测试集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True) data_test=datasets.MNIST(root='./data/',transform=transform,train=False)
4、数据预览和装载
#数据装载,可以理解为对图片的处理 #处理完成后,将图片送给模型训练,装载就是打包的过程 #dataset 用于指定载入的数据集名称 #batch_size设置了每个包的图片数据数据个数 #shuffle 装载过程将数据随机打乱并打包 data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True) data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)