本博文是PyTorch的学习笔记,第17次内容记录,主要记录了torchvision.transforms的使用方法。
在读ResNet的应用代码时,遇到下面这一小段代码,这段代码出现在读取图片信息之前,这段代码的具体功能是什么呢?对于初学者来说很有必要弄清楚这段代码的具体含义
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。而上面这段代码就用到了torchvision.transforms
这个包。
这里用到的 torchvision 工具库是 pytorch 框架下常用的图像处理包,可以用来生成图片和视频数据集(torchvision.datasets),做一些图像预处理(torchvision.transforms),导入预训练模型(torchvision.models),以及生成图和保存图像(torchvision.utils)。
其中,transforms函数对图像做预处理可以是:归一化(normalize)
,尺寸剪裁(resize)
,翻转(flip)
等。
上面的这些步骤实际操作起来往往是一系列的,此时可以用compose将这些图像预处理操作连起来。
如上面的代码,这里做的操作是:
transforms.ToTensor() ,将一个PIL图像转换为tensor。即, ( H ∗ W ∗ C ) (H\ast W\ast C) (H∗W∗C)范围在[0,255]的PIL图像 转换为 ( C ∗ H ∗ W ) (C\ast H\ast W) (C∗H∗W)范围在[0,1]的torch.tensor。
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ,用均值[0.485, 0.456, 0.406]和标准差[0.229, 0.224, 0.225]对图像做归一化处理。
transforms函数另外的功能还包括:
Resize:把给定的图片resize到给定的尺寸。
ToPILImage: 将torch.tensor 转换为PIL图像。
CenterCrop:以输入图的中心点为中心做指定size的裁剪操作。
RandomCrop:以输入图的随机位置为中心做指定size的裁剪操作。
RandomHorizontalFlip:以0.5概率水平翻转给定的PIL图像。
RandomVerticalFlip:以0.5概率竖直翻转给定的PIL图像。
RandomResizedCrop:将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小(有一个参数n)。
Grayscale:将给定图像转换为灰度图像。
RandomGrayscale:将图像以指定的概率转换为灰度图像。
FiveCrop: 从一张输入图像中裁剪出5张指定size的图像,包括4个角的图像和一个中心。
TenCrop:剪出10张指定size的图像。做法是在FiveCrop的基础上,再将输入图像进行水平或竖直翻转,然后进行FiveCrop操作,这样一张图像可得到10张crop图像。
Pad:对给定图像的所有边用的“padding”个像素用“fill”值填充。
ColorJitter:修改图像的亮度,对比度,饱和度和色度。
Lambda:做其参数指定的变换。
上述四个包及其具体函数的详细介绍参考Pytorch的中文文档。
代码实现可以参考github的代码实现。
torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:
1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。
2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。
3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。
4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。
5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。
6)torchvision.utils:其他工具,比如产生一个图像网格等。
问题1:数据集为彩色图像,通道数为3,但是模型中输入通道数为1,也就是接收灰色图像,这时,训练模型时会报错,具体错误为:
RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels
解决输入通道数的问题,也就是要将3通道的彩色图像修改成1通道的灰色图像,这时的修改方式为:
修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()]),
download=True)
也就是增加一个torchvision.transforms.Grayscale()
的操作。
问题1: