PyTorch学习笔记(17)--torchvision.transforms用法介绍

PyTorch学习笔记(17)–torchvision.transforms用法介绍

    本博文是PyTorch的学习笔记,第17次内容记录,主要记录了torchvision.transforms的使用方法。

目录

  • PyTorch学习笔记(17)--torchvision.transforms用法介绍
  • 1.问题来源
  • 2.torchvision.transforms具体用法
  • 3.torchvision.transforms其他的用法
  • 4.补充torchvision模块的其他功能
  • 5.运行错误解决

1.问题来源

    在读ResNet的应用代码时,遇到下面这一小段代码,这段代码出现在读取图片信息之前,这段代码的具体功能是什么呢?对于初学者来说很有必要弄清楚这段代码的具体含义

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

2.torchvision.transforms具体用法

    PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。而上面这段代码就用到了torchvision.transforms这个包。

    这里用到的 torchvision 工具库是 pytorch 框架下常用的图像处理包,可以用来生成图片和视频数据集(torchvision.datasets),做一些图像预处理(torchvision.transforms),导入预训练模型(torchvision.models),以及生成图和保存图像(torchvision.utils)。
    其中,transforms函数对图像做预处理可以是:归一化(normalize)尺寸剪裁(resize)翻转(flip) 等。
    上面的这些步骤实际操作起来往往是一系列的,此时可以用compose将这些图像预处理操作连起来。
    如上面的代码,这里做的操作是:
    transforms.ToTensor() ,将一个PIL图像转换为tensor。即, ( H ∗ W ∗ C ) (H\ast W\ast C) (HWC)范围在[0,255]的PIL图像 转换为 ( C ∗ H ∗ W ) (C\ast H\ast W) (CHW)范围在[0,1]的torch.tensor。
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ,用均值[0.485, 0.456, 0.406]和标准差[0.229, 0.224, 0.225]对图像做归一化处理。

3.torchvision.transforms其他的用法

    transforms函数另外的功能还包括:

    Resize:把给定的图片resize到给定的尺寸。

    ToPILImage: 将torch.tensor 转换为PIL图像。

    CenterCrop:以输入图的中心点为中心做指定size的裁剪操作。

    RandomCrop:以输入图的随机位置为中心做指定size的裁剪操作。

    RandomHorizontalFlip:以0.5概率水平翻转给定的PIL图像。

    RandomVerticalFlip:以0.5概率竖直翻转给定的PIL图像。

    RandomResizedCrop:将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小(有一个参数n)。

    Grayscale:将给定图像转换为灰度图像。

    RandomGrayscale:将图像以指定的概率转换为灰度图像。

    FiveCrop: 从一张输入图像中裁剪出5张指定size的图像,包括4个角的图像和一个中心。

    TenCrop:剪出10张指定size的图像。做法是在FiveCrop的基础上,再将输入图像进行水平或竖直翻转,然后进行FiveCrop操作,这样一张图像可得到10张crop图像。

    Pad:对给定图像的所有边用的“padding”个像素用“fill”值填充。

    ColorJitter:修改图像的亮度,对比度,饱和度和色度。

    Lambda:做其参数指定的变换。

    上述四个包及其具体函数的详细介绍参考Pytorch的中文文档。

    代码实现可以参考github的代码实现。

4.补充torchvision模块的其他功能

    torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:

    1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。

     2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。

     3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。

     4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。

     5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。

     6)torchvision.utils:其他工具,比如产生一个图像网格等。

5.运行错误解决

    问题1:数据集为彩色图像,通道数为3,但是模型中输入通道数为1,也就是接收灰色图像,这时,训练模型时会报错,具体错误为:

RuntimeError: Given groups=1, weight of size 32 3 3 3, expected input[1, 4, 416, 416] to have 3 channels

    解决输入通道数的问题,也就是要将3通道的彩色图像修改成1通道的灰色图像,这时的修改方式为:

修改前:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
修改后:
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True,
                                          transform=torchvision.transforms.Compose([
                                              torchvision.transforms.Grayscale(),
                                              torchvision.transforms.ToTensor()]),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False,
                                         transform=torchvision.transforms.Compose([
                                             torchvision.transforms.Grayscale(),
                                             torchvision.transforms.ToTensor()]),
                                         download=True)

    也就是增加一个torchvision.transforms.Grayscale()的操作。
    问题1:

你可能感兴趣的:(PyTorch学习笔记,pytorch,深度学习,人工智能)