An overview of the torchvision documentation on PyPI
PyTorch 0.3.0 documentation
Overview: the torchvision package serves the PyTorch deep learning framework. It provides common image and video datasets, as well as popular model architectures and pre-trained models.
torchvision consists of the following four parts:
torchvision.datasets
torchvision.models
torchvision.transforms
torchvision.utils
Each part is introduced below.
# Part 1: torchvision.datasets
The classes in torchvision.datasets are subclasses of torch.utils.data.Dataset. They can therefore be loaded in parallel with torch.utils.data.DataLoader (via Python multiprocessing).
For example:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
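Here coco_cap is a COCO captioning dataset that the snippet above leaves undefined. A minimal sketch of how it might be constructed (the paths are hypothetical, and the pycocotools package needs to be installed):
import torchvision.datasets as dset
import torchvision.transforms as transforms

# hypothetical locations of the COCO images and caption annotations
coco_cap = dset.CocoCaptions(root='coco/train2014',
                             annFile='coco/annotations/captions_train2014.json',
                             transform=transforms.ToTensor())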
Most datasets in torchvision.datasets accept transform and target_transform arguments: transform is applied to the input image (e.g. a PIL image), and target_transform is applied to the target (e.g. the class label or caption).
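For instance, a minimal sketch using MNIST (the download directory './data' and the odd/even label mapping are just illustrative choices):
import torchvision.datasets as dset
import torchvision.transforms as transforms

mnist = dset.MNIST(root='./data', train=True, download=True,
                   transform=transforms.ToTensor(),       # PIL image -> FloatTensor in [0, 1]
                   target_transform=lambda y: y % 2)      # map the digit label to odd/even
img, target = mnist[0]
print(img.size(), target)   # torch.Size([1, 28, 28]) and 0 or 1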
torchvision.datasets includes the following datasets:
MNIST
COCO (Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
SVHN
PhotoTour
Among these, ImageFolder is a generic data loader for images stored on disk in the following layout:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png # images of each class live in their own subfolder
dset.ImageFolder(root="root folder path", [transform, target_transform])
An ImageFolder instance then has the following three attributes:
(1) self.classes - the class names, as a list
(2) self.class_to_idx - a dict mapping each class name to its index
(3) self.imgs - a list of (image path, class index) tuples
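Putting this together, a minimal sketch (the folder 'my_images' is hypothetical and must follow the layout shown above):
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

dataset = dset.ImageFolder(root='my_images',
                           transform=transforms.Compose([transforms.Scale(256),
                                                         transforms.CenterCrop(224),
                                                         transforms.ToTensor()]))
print(dataset.classes)        # e.g. ['cat', 'dog']
print(dataset.class_to_idx)   # e.g. {'cat': 0, 'dog': 1}
print(dataset.imgs[0])        # e.g. ('my_images/cat/123.png', 0)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)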
# Part 2: torchvision.models
torchvision.models contains definitions for popular model architectures such as AlexNet, VGG, ResNet, and SqueezeNet.
Usage 1: construct a model with randomly initialized weights:
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
Usage 2: construct a model initialized with pre-trained weights.
Pre-trained models are provided for the ResNet variants, SqueezeNet 1.0 and 1.1, and AlexNet, using the PyTorch model zoo (torch.utils.model_zoo). They can be constructed by passing pretrained=True:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
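Once constructed, a pre-trained model is typically switched to evaluation mode and run on a normalized input batch (see the note and normalization example below). A minimal 0.3.0-style sketch, where the random tensor merely stands in for a real image batch:
import torch
from torch.autograd import Variable
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)
resnet18.eval()                            # disable dropout, use running BatchNorm statistics
x = Variable(torch.randn(1, 3, 224, 224))  # stand-in for a normalized 224x224 RGB image batch
logits = resnet18(x)                       # Variable of shape (1, 1000): ImageNet class scores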
Note: these pre-trained models expect their input in a specific format: mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded into the range [0, 1] and then normalized with mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
Example of the recommended ImageNet normalization:
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# args holds the command-line options of the original ImageNet example
# (data path, batch size, number of workers), e.g. parsed with argparse
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
# Part 3: torchvision.transforms
torchvision.transforms contains the common image transformation (preprocessing) operations. These transforms can be chained together with torchvision.transforms.Compose.
The transforms in torchvision.transforms fall into the following categories (a combined example follows the list):
1. Transforms on PIL.Image
2. Transforms on torch.Tensor
3. Conversion Transforms (conversions between data formats)
4. Generic Transforms (general-purpose transforms)
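A minimal sketch combining one transform from each category (the Lambda at the end is an arbitrary example of a generic transform):
import torchvision.transforms as transforms

preprocess = transforms.Compose([
    transforms.Scale(256),                            # on PIL.Image
    transforms.CenterCrop(224),                       # on PIL.Image
    transforms.ToTensor(),                            # conversion: PIL.Image -> torch.Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # on torch.Tensor
                         std=[0.229, 0.224, 0.225]),
    transforms.Lambda(lambda t: t.unsqueeze(0)),      # generic: add a batch dimension
])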
# Part 4: torchvision.utils
utils is, as the name suggests, just a handful of helper functions; at the moment there seem to be only two, make_grid and save_image.
An example:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import random
%matplotlib inline
def show(img):
    # make_grid returns a (3, H, W) tensor; transpose to (H, W, 3) for matplotlib
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

import scipy.misc
lena = scipy.misc.face()           # a (768, 1024, 3) uint8 test image shipped with SciPy
img = transforms.ToTensor()(lena)  # FloatTensor of shape (3, 768, 1024) with values in [0, 1]
print(img.size())
# torch.Size([3, 768, 1024])
imglist = [img, img, img, img.clone().fill_(-10)]
show(make_grid(imglist, padding=100))
show(make_grid(imglist, padding=100, normalize=True))
show(make_grid(imglist, padding=100, normalize=True, range=(0, 1)))
show(make_grid(imglist, padding=100, normalize=True, range=(0, 0.5)))
show(make_grid(imglist, padding=100, normalize=True, scale_each=True))
show(make_grid(imglist, padding=100, normalize=True, range=(0, 0.5), scale_each=True))
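The other utility is save_image, which arranges a tensor (or a list of tensors) into a grid and writes it straight to an image file. A minimal sketch reusing imglist from above (the filename is arbitrary):
from torchvision.utils import save_image

save_image(imglist, 'grid.png', padding=100, normalize=True)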