2019/11/3
今天才正式开始看pytorch的官网教程,把一些基础的操作先搞明白吧,虽然之前跑过简单的demo,总是感觉有些地方不能解释得很好,所以这次记录一下新的收获:
1 torch tensor
pytorch里面最基础格式的量就是torch tensor了,但这么基础的量在初始化的时候,还是可以指定很多参数的,例如torch.tensor(data, dtype, device, requires_grad),一个tensor还有.grad和.grad_fn属性。
2 in-place operation
原文:
Any operation that mutates a tensor in-place is post-fixed with an_. For example: x.copy_(y), x.t_(), will change x.
3 numpy vs tensor
原文:numpy的数据和torch tensor格式是可以相互转化的,用tensor.numpy()和torch.from_numpy()就可以互相转化,不过得注意下面一点:
The Torch Tensor and NumPy array will share their underlying memory locations (if the Torch Tensor is on CPU), and changing one will change the other.
4 关于torch.nn建立网络
原文:以前没有留意这问题,不过一般也都符合这个输入要求,就是必须是有batch的那个维度
torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.
If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
虽然以前知道可以print(net)来查看网络的结构,但是不知道还可以用net.conv1来对第一个卷积核进行操作(当然前提是你在init函数里面已经定义了conv1)
5 torch.utils.data
Dataset
torch.utils.data.Dataset,附上原文的一段介绍(因为感觉翻译没有原来的味道):
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite
__getitem__()
, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__()
, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
主要就是有自己训练集的时候,可以构建一个类来继承Dataset类,并且重写__getitem__()
方法和__len__()
方法,这样也就可以用后面的的Dataloader来加载数据集了。
DataLoader
torch.utils.data.DataLoader是pytorch里面核心的数据加载的模块,他的接口是这样的,看到可选的参数很多,常用的有dataset, batch_size, shuffle, num_workers:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
sampler=None, batch_sampler=None, num_workers=0,
collate_fn=None, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, multiprocessing_context=None)
6 torchvision
datasets
torchvision.datasets里面有很多的可以加载的数据集,如MNIST、Fashion-MNIST、KMNIST、EMNIST、QMNIST、FakeData、COCO(Captions\Detection)、LSUN、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD、USPS、Kinetics-400、HMDB51、UCF101。这些都继承了torch.utils.data.Dataset这个类,所以这些数据集都可以用torch.utils.data.DataLoader的多线程来进行快速的加载(如果我们自己构建自己的dataset,去重写len和getitem方法,也可以调用torch.utils.data.DataLoader来对数据进行加载),而且他们的API接口都很像,差不多都有下面几个参数(以ImageNet为例,不用解释,一看能猜出来):
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
当然里面也有一个通用的接口,可以让你自己构建数据集:
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=
, is_valid_file=None)
只要你按照下面的文件目录结构存放自己的图像就行
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
这里面还有一些是比较常用的数量,一般情况下我们会构建两个dataset,一个是训练的,一个是测试的,如trainset,valset;与之对应就有两个DataLoader,一个是trainloader,一个是valloader:
-
len(trainset)
就是训练图片的数量,len(valset)
验证图片的数量,其实len(trainloader.dataset)
是和len(trainset)
一样的,都是返回训练图片总的数量 -
len(trainloader)
就是一次喂入所有的训练集要多少个batch,len(valloader)
就是一次喂入所有的验证集要多少个batch。所以就有, - trainset是直接可以索引的,如:image, target = trainset[0];而trainloader就不行,但是trainloader是iterable,可以通过循环来进行获取(很多时候会看到别人用enumerate的方法),如
for data in trainloader:
batch_img, batch_target = data
# 或者有的地方也这样写
dataiter = iter(trainloader)
images, labels = dataiter.next()
transforms
torchvision.transforms主要用于对图像进行变换,也就是图像增强data augmentation。
-
Transforms on PIL Image
这个意思就是说,对还不是torch tensor类型的图像的变换操作,一般指通过一些图片导入包如PIL导入的图片:CenterCrop、FiveCrop、TenCrop、ColorJitter、Grayscale、Pad、RandomAffine、RandomApply、RandomChoice、RandomCrop、RandomGrayscale、RandomHorizontalFlip、RandomOrder、RandomPerspective、RandomResizedCrop、RandomRotation、RandomSizedCrop、RandomVerticalFlip、Resize、Scale -
Transforms on torch.*Tensor
这里就是指对torch tensor类型的图像的变换操作,所以一般在这之前都会有一个transforms.ToTensor()的操作:LinearTransformation、Normalize、RandomErasing -
Conversion Transforms
这个就有点像是类型转化的变换操作吧:ToPILImage、ToTensor -
Generic Transforms
这个就是一般性的变换操作:Lambda -
Functional Transforms
函数式的变换操作更加的细粒度,对于一些要求高的任务比较好,如图像分割:adjust_brightness、adjust_contrast、adjust_gamma、adjust_hue、adjust_saturation、affine、crop、erase、five_crop、hflip、normalize、pad、perspective、resize、resized_crop、rotate、ten_crop、to_grayscale、to_pil_image、to_tensor、vflip
最后可以用 torchvision.transforms.Compose(transforms)
把想要对图像做的变换都集中起来,transforms是各种之前特定变换操作的列表
small summary
所以结合6中两个就可以差不多这样写
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
Common import
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
Note
The output of torchvision datasets are PILImage images of range [0, 1].
More
See here for more details on saving PyTorch models.
If you want to see even more MASSIVE speedup using all of your GPUs, please check out Optional: Data Parallelism.