Pytorch学习笔记(三)——手写数字识别

1. 首先导入所需要的包,其中torchvision包主要实现数据的处理、导入和预览

import torch
from torchvision import datasets, transforms
from torch.autograd import Variable

2.torchvision中的datasets可以实现对数据集的下载,例如MNIST、COCO、ImageNet、CIFCAR,代码如下:

# Download the datasets
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)
data_test = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = False)

其中:root指定了数据集下载后的存放路径,train指定了当前为测试集还是训练集,transfrom指定了所应用的变换,其代码如下:

# Set the transform format
transform = transforms.Compose([transforms.ToTensor(),
                                ])

3.transforms的具体应用:

transforms主要负责对载入的数据进行变换,主要是变为Tensor类型,以及归一化和大小缩放的操作。除此以外,当数据集比较有限时,可以通过变换训练集生成更多的数据进行训练。(数据增强)

上一段的Compose可以看作一种容器,所传入的是一个列表,能够容纳多种数据变换。常用的数据变换操作有:

torchvision.transforms.Resize(h,w):对载入的图片数据按照需求大小进行缩放;

torchvision.transforms.Scale(h,w) : 同上

torchvision.transforms.CenterCrop(h,w) : 以图片中心为参考点,对载入的图片进行裁剪

torchvision.transforms.RondomHorizontalFlip(rate) : 对载入图片随机水平翻转

torchvision.transforms.RondomVerticalFlip(rate) : 对载入图片随机垂直翻转

torchvision.transforms.ToTensor()

torchvision.transforms.ToPILImage()

你可能感兴趣的:(AI,Pytorch,Python,人工智能,人工智能)