看pytorch official tutorials的新收获(持续更新)

pytorch标志

2019/11/3
今天才正式开始看pytorch的官网教程,把一些基础的操作先搞明白吧,虽然之前跑过简单的demo,总是感觉有些地方不能解释得很好,所以这次记录一下新的收获:

1 torch tensor

pytorch里面最基础格式的量就是torch tensor了,但这么基础的量在初始化的时候,还是可以指定很多参数的,例如torch.tensor(data, dtype, device, requires_grad),一个tensor还有.grad和.grad_fn属性。

2 in-place operation

原文:

Any operation that mutates a tensor in-place is post-fixed with an_. For example: x.copy_(y), x.t_(), will change x.

3 numpy vs tensor

原文:numpy的数据和torch tensor格式是可以相互转化的,用tensor.numpy()和torch.from_numpy()就可以互相转化,不过得注意下面一点:

The Torch Tensor and NumPy array will share their underlying memory locations (if the Torch Tensor is on CPU), and changing one will change the other.

4 关于torch.nn建立网络

原文:以前没有留意这问题,不过一般也都符合这个输入要求,就是必须是有batch的那个维度

torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.
If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.

虽然以前知道可以print(net)来查看网络的结构,但是不知道还可以用net.conv1来对第一个卷积核进行操作(当然前提是你在init函数里面已经定义了conv1)

5 torch.utils.data

Dataset

torch.utils.data.Dataset,附上原文的一段介绍(因为感觉翻译没有原来的味道):

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

主要就是有自己训练集的时候,可以构建一个类来继承Dataset类,并且重写__getitem__()方法和__len__()方法,这样也就可以用后面的的Dataloader来加载数据集了。

DataLoader

torch.utils.data.DataLoader是pytorch里面核心的数据加载的模块,他的接口是这样的,看到可选的参数很多,常用的有dataset, batch_size, shuffle, num_workers

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 
                                                           sampler=None, batch_sampler=None, num_workers=0, 
                                                           collate_fn=None, pin_memory=False, drop_last=False, 
                                                           timeout=0, worker_init_fn=None, multiprocessing_context=None)

6 torchvision

datasets

torchvision.datasets里面有很多的可以加载的数据集,如MNIST、Fashion-MNIST、KMNIST、EMNIST、QMNIST、FakeData、COCO(Captions\Detection)、LSUN、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD、USPS、Kinetics-400、HMDB51、UCF101。这些都继承了torch.utils.data.Dataset这个类,所以这些数据集都可以用torch.utils.data.DataLoader的多线程来进行快速的加载(如果我们自己构建自己的dataset,去重写lengetitem方法,也可以调用torch.utils.data.DataLoader来对数据进行加载),而且他们的API接口都很像,差不多都有下面几个参数(以ImageNet为例,不用解释,一看能猜出来):

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)

当然里面也有一个通用的接口,可以让你自己构建数据集:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=, is_valid_file=None)
只要你按照下面的文件目录结构存放自己的图像就行
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

这里面还有一些是比较常用的数量,一般情况下我们会构建两个dataset,一个是训练的,一个是测试的,如trainset,valset;与之对应就有两个DataLoader,一个是trainloader,一个是valloader:

  • len(trainset)就是训练图片的数量,len(valset)验证图片的数量,其实len(trainloader.dataset)是和len(trainset)一样的,都是返回训练图片总的数量
  • len(trainloader)就是一次喂入所有的训练集要多少个batch,len(valloader)就是一次喂入所有的验证集要多少个batch。所以就有,
  • trainset是直接可以索引的,如:image, target = trainset[0];而trainloader就不行,但是trainloader是iterable,可以通过循环来进行获取(很多时候会看到别人用enumerate的方法),如
for data in trainloader:
    batch_img, batch_target = data
# 或者有的地方也这样写
dataiter = iter(trainloader)
images, labels = dataiter.next()

transforms

torchvision.transforms主要用于对图像进行变换,也就是图像增强data augmentation。

  • Transforms on PIL Image
    这个意思就是说,对还不是torch tensor类型的图像的变换操作,一般指通过一些图片导入包如PIL导入的图片:CenterCrop、FiveCrop、TenCrop、ColorJitter、Grayscale、Pad、RandomAffine、RandomApply、RandomChoice、RandomCrop、RandomGrayscale、RandomHorizontalFlip、RandomOrder、RandomPerspective、RandomResizedCrop、RandomRotation、RandomSizedCrop、RandomVerticalFlip、Resize、Scale
  • Transforms on torch.*Tensor
    这里就是指对torch tensor类型的图像的变换操作,所以一般在这之前都会有一个transforms.ToTensor()的操作:LinearTransformation、Normalize、RandomErasing
  • Conversion Transforms
    这个就有点像是类型转化的变换操作吧:ToPILImage、ToTensor
  • Generic Transforms
    这个就是一般性的变换操作:Lambda
  • Functional Transforms
    函数式的变换操作更加的细粒度,对于一些要求高的任务比较好,如图像分割:adjust_brightness、adjust_contrast、adjust_gamma、adjust_hue、adjust_saturation、affine、crop、erase、five_crop、hflipnormalizepad、perspective、resize、resized_crop、rotate、ten_crop、to_grayscale、to_pil_image、to_tensor、vflip

最后可以用 torchvision.transforms.Compose(transforms)把想要对图像做的变换都集中起来,transforms是各种之前特定变换操作的列表

small summary

所以结合6中两个就可以差不多这样写

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

Common import

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

Note

The output of torchvision datasets are PILImage images of range [0, 1].

More

See here for more details on saving PyTorch models.
If you want to see even more MASSIVE speedup using all of your GPUs, please check out Optional: Data Parallelism.

你可能感兴趣的:(看pytorch official tutorials的新收获(持续更新))