官网相关内容的链接如下:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#loading-and-normalizing-cifar10
我的系列博文
Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理(本文)
一般来说,使用深度学习框架我们会经过下面几个流程:
模型定义(包括损失函数的选择) --->数据处理和加载 ---> 训练(可能包含训练过程可视化) ---> 测试
所以我们在自己写代码的时候也基本上就按照这四个大模块四步走就ok了
官方给的这个例子呢,是先进行的第二步数据处理和加载,然后定义网络,这其实没什么关系。
所以本篇博文讲解的是 数据处理和加载 这一步的内容,当然会接着在后续博文写其他步骤。
此例的步骤:
A、Load and normalizing the CIFAR10 training and test datasets using torchvision
B、Define a Convolution Neural Network
C、Define a loss function
D、Train the network on the training data
E、Test the network on the test data
下面我就直接上程序,并且添加我自己的一些注解,觉得有问题的欢迎提出,希望和大家多交流。
细节见torchvision的官方文档链接:http://pytorch.org/docs/0.3.0/torchvision/index.html
-
# 首先当然肯定要导入torch和torchvision,至于第三个是用于进行数据预处理的模块
-
import torch
-
import torchvision
-
import torchvision.transforms
as transforms
-
-
# **由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor**
-
# 首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )
-
# 把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换
-
transform = transforms.Compose(
-
[transforms.ToTensor(),
-
transforms.Normalize((
0.5,
0.5,
0.5), (
0.5,
0.5,
0.5))])
-
-
# 定义了我们的训练集,名字就叫trainset,至于后面这一堆,其实就是一个类:
-
# torchvision.datasets.CIFAR10( )也是封装好了的,就在我前面提到的torchvision.datasets
-
# 模块中,不必深究,如果想深究就看我这段代码后面贴的图1,其实就是在下载数据
-
#(不可能会慢一点吧)然后进行变换,可以看到transform就是我们上面定义的transform
-
trainset = torchvision.datasets.CIFAR10(root=
'./data', train=
True,
-
download=
True, transform=transform)
-
# trainloader其实是一个比较重要的东西,我们后面就是通过trainloader把数据传入网
-
# 络,当然这里的trainloader其实是个变量名,可以随便取,重点是他是由后面的
-
# torch.utils.data.DataLoader()定义的,这个东西来源于torch.utils.data模块,
-
# 网页链接http://pytorch.org/docs/0.3.0/data.html,这个类可见我后面图2
-
trainloader = torch.utils.data.DataLoader(trainset, batch_size=
4,
-
shuffle=
True, num_workers=
2)
-
# 对于测试集的操作和训练集一样,我就不赘述了
-
testset = torchvision.datasets.CIFAR10(root=
'./data', train=
False,
-
download=
True, transform=transform)
-
testloader = torch.utils.data.DataLoader(testset, batch_size=
4,
-
shuffle=
False, num_workers=
2)
-
# 类别信息也是需要我们给定的
-
classes = (
'plane',
'car',
'bird',
'cat',
-
'deer',
'dog',
'frog',
'horse',
'ship',
'truck')
mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 参考官网链接:https://pytorch.org/docs/stable/torchvision/models.html 用ctrl+f 搜索mean 第一个位置就有说明
图1
root表示存放dataset的位置,本例就是' ./data'
train,如果为True,就创建的是trainning set,可以看到我们的trainset调用它时用的是True
而testset调用它时,参数里填的是False
transform,这个transform是形参名,由于我们定义的变换也叫transform,所以就有transform = transform,
看起来可能有点怪,其实我们的之前的变换可以随便命名
download,如果为True,就从网上下载,如果已经有下载好的数据就不会重复下载了
------------------------------------------------------------------------------------------------------------------------------------------
图2
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
dataset:就是数据的来源,比如训练集就添入我们定义的trainset
batch_size:每批次进入多少数据,本例中填的是4
shuffle:如果为真,就打乱数据的顺序,本例为True
num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
本例中为2。这个值是什么意思呢,就是数据读入的速度到底有多快,你选的用来加载数据的
子进程越多,那么显然数据读的就越快,这样的话消耗CPU的资源也就越多,所以这个值在自己
跑实验的时候,可以自己试一试,既不要让花在加载数据上的时间太多,也不要占用太多电脑资源
所以这第一步----数据加载和处理,要注意的就是这些内容,如果程序运行完毕,会显示:
这里我提个小建议,就是下载数据的那个root参数,官网代码给的是'./data',这个其实可以改成自己的位置
而且,建议改成 绝对路径 要好一点。然后由于代码可以直接从官网复制粘帖,所以这部分程序的运行的快慢,
基本就取决于下载数据的网速了,建议,可能也不见得很快,不如晚上睡觉前开始下,说不定第二天
醒来就下好了呢.......