torchvision 是 PyTorch 的一个图形图像库(专门用来处理图像和视觉的),主要用于构建计算机视觉模型。
torchvision 包含四个大类:
torchvision.transforms 是 PyTorch 中的图像预处理包,包含了很多对图像数据进行变换的函数,主要用于常见的一些图形变换。
torchvision.transforms.Compose() 类,这个类的主要作用是串联多个图形变换的操作,它会对列表里面的变换操作进行遍历。
torchvision.transforms.ToTensor() 类,把 shape=(H*W*C) 的像素值范围为 [0, 255] 的 PIL.Image 或者 numpy.ndarray 转换成 shape=(C*H*W) 的像素值范围为 [0.0, 1.0] 的 torch.FloatTensor。
torchvision.transforms.ToPILImage() 类,把 shape=(C*H*W) 的 Tensor 或者 shape=(H*W*C) 的 numpy.ndarray 转换成 shape=(H*W*C) 的 PIL.Image,值不变。
torchvision.transforms.Normalize(mean, std) 类,用给定的均值和标准差分别对每个通道的数据进行规范化。
具体来说,给定均值 ( M 1 , M 2 , . . . , M n ) (M_1, M_2, ..., M_n) (M1,M2,...,Mn) 和标准差 ( S 1 , S 2 , . . . , S n ) (S_1, S_2, ..., S_n) (S1,S2,...,Sn),其中 n(一般为 3 (R, G, B))为通道数。用公式 channel = (channel - mean) / std 来进行规范化。
对每个通道进行如下操作:output[channel] = (input[channel] - mean[channel]) / std[channel]。
比如原来的 tensor 是三个维度的,数值在 [0, 1] 之间,经过变换之后数值范围就扩展到 [-1, 1] 之间。计算如下:((0, 1) - 0.5) / 0.5 = (-1, 1)。
torchvision.transforms.CenterCrop(size) 类,将给定的 PIL.Image 进行中心裁剪,得到指定的 size。参数 size 可以是一个整数,裁剪出来的是一个正方形图像;size 也可以是一个 tuple(target_height, target_width)。
torchvision.transforms.RandomCrop(size, padding=0) 类,对 PIL.Image 进行随机裁剪,即裁剪中心点的位置随机选取。参数 size 可以是一个整数,也可以是一个 tuple。
torchvision.transforms.RandomResizedCrop(size) 类,先对 PIL.Image 进行随机裁剪,然后再将其 resize 成给定 size 大小。
torchvision.transforms.RandomHorizontalFlip(p=0.5) 类,将给定的 PIL.Image 随机水平翻转,翻转的概率默认为 0.5。
torchvision.transforms.RandomVerticalFlip(p=0.5) 类,将给定的 PIL.Image 随机垂直翻转,翻转的概率默认为 0.5。
torchvision.transforms.Pad(padding, fill=0) 类,用给定的值填充 PIL.Image 的所有边。padding–各边要填充多少个像素,fill–用什么值填充。
import torchvision.transforms as transforms
# 图像预处理
transform = transforms.Compose([
transforms.Resize(96), # 将图像缩放到 96*96 大小
transforms.ToTensor(), # 将图像数据转换成张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 对张量数据进行归一化
])
torchvision.datasets 是用于数据加载的包,PyTorch 团队在这个包中帮我们提前处理好了很多图像数据集。如 MNIST、CocoCaptions、CocoDetection、LSUN、ImageFolder、ImageNet、CIFAR10、STL10、SVHN、PhotoTour 等。
所有数据集都是 torch.utils.data.Dataset 的子类,它们实现了 _getitem_ 和 __len__ 方法,因此它们都可以传递给 torch.utils.data.DataLoader。
import torchvision.datasets as datasets
# 数据集准备
trainset = datasets.MNIST(
root='./data', # 根目录
train=True, # 用于指定需要载入数据集的哪个部分,这里载入的是训练集;如果为 False,则载入测试集
transform=transform, # 用于指定导入数据集时需要对数据进行哪些变换操作,需提前定义这些变换操作
download=True # 用于指定是否需要网上下载;如果为 True,则从网上下载数据集并将其放在根目录中;如果已下载数据集,则不会再次下载
)
# 数据集加载
trainloader = torch.utils.data.DataLoader(
trainset, # 准备的数据集
batch_size=4, # 设定图像数据的批次大小
shuffle=True, # 如果为 True,则每个 epoch 都会将数据集打乱
num_workers=2, # 设定加载数据时的线程数目;默认为 0,主线程加载数据
collate_fn=<function default_collate>, # 指定取样本的方式,可以自己定义函数来实现想要的功能
pin_memory=False, # 指定是否为锁页内存
drop_last=False # 用于指定对 len(trainset)/batch_size 余下的数据的处理方式;如果为 True,则将最后不够一个 batch_size 的数据抛弃;如果为 False,则保留
)
'''
主机中的内存有两种存在方式,一是锁页,二是不锁页。锁页内存存放的内容在任何情况下都不会与主机的虚拟内存(虚拟内存就是硬盘)进行交换;而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。显卡中的显存全部是锁页内存。
当计算机的内存充足时,可以设置 pin_memory=True;当系统卡住,或者交换内存使用过多的时候,设置 pin_memory=False。因为 pin_memory 与电脑硬件性能有关,PyTorch 开发者不能确保每一个炼丹玩家都有高端设备,因此 pin_memory 默认为 False。
'''
torchvision.models 中包含如 AlexNet、VGG、ResNet、SqueezeNet、DenseNet 等模型结构,同时为我们提供已经预训练好的模型,我们加载之后可以直接使用。
可以通过以下代码快速创建一个随机初始化权重的模型:
import torchvision.models as models
alexnet = models.alexnet()
vgg16 = models.vgg16()
resnet18 = models.resnet18()
也可以通过 pretrained=True 来加载一个预训练好的模型:
import torchvision.models as models
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
resnet18 = models.resnet18(pretrained=True)
我们把 torchvision 中的多个类组合起来使用:
# 初始的 MNIST 数据集图像大小为 28*28,我们把它们处理成 96*96 的 torch.Tensor 的格式
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# 图像预处理
transform = transforms.Compose([
transforms.Resize(96), # 缩放到 96*96 大小
transforms.ToTensor(), # 将图像转换成 Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
# 数据集准备
train_dataset = datasets.MNIST(root='./data/',
train=True,
transform=transform,
download=True)
# 数据集加载
train_loader = DataLoader(dataset=train_dataset,
batch_size=8,
shuffle=True)
print(len(train_dataset))
print(len(train_loader))
---------
60000
7500