【Pytorch学习笔记】数据增强

前言

torchvision 模块包含常用的数据集,模型建构,图像变换算法,分别是torchvision.datasets,torchvision.models,torchvision.transforms。本次主要学习torchvision.transforms对数据集进行预处理。

torchvision.transforms.Compose()

transforms.Compose() 用于整合一系列的图像变换函数,将图片按照 Compose() 中的顺序依次处理。torch.nn.Sequential() 与 transforms.Compose() 起到相同的功能。torch.nn.Sequential() 和 torch.jit.script() 结合来导出模型。

'''Compose'''
transform1 = transforms.Compose([
	transforms.CenterCrop(10),
	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

'''Sequential'''
transform2 = torch.nn.Sequential(
	transforms.CenterCrop(10),
	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

1. Transforms on PIL Image and torch.Tensor

大多数的图像变换函数图像可以是 PILtensor 类型,部分函数只接受 PIL 或只接受 tensor 。可以使用 transforms.ToPILImage()transforms.ToTensor 进行类型转换。

# 读取图片
from PIL import Image
from torchvision import transforms as T
img = Image.open('C:\\Users\\myt\\Desktop\\test.png')	# (400 * 300)
img.show()

【Pytorch学习笔记】数据增强_第1张图片

1.1 transforms.CenterCrop(size)

以图像中心为中心点,将图片裁剪成指定的大小。如果输入图像尺寸小于指定的输出的大小,则在图像的边界进行”填0“,之后再进行裁剪。如果是 torch.Tensor 类型则大小为 […, H, W]
size(sequence or int): 裁剪后的尺寸大小
输入为一个 int 型,输出 (size, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)

# CenterCrop
imgCC = T.CenterCrop((128, 256)).forward(img)
imgCC.show()
imgCC.save('C:\\Users\\myt\\Desktop\\imgcc.png')
print(imgCC.size)

结果:(256, 128)
【Pytorch学习笔记】数据增强_第2张图片

1.2 transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

随机改变图像的亮度,对比度,饱和度,色调。如果是 torch.Tensor 类型则大小为 […, H, W]
brightness (float or tuple of python:float (min, max)) :[max(0, 1-brightness), 1+brightness][min, max] 随机选择的非负数。
contrast (float or tuple of python:float (min, max)) :[max(0, 1-contrast), 1+contrast][min, max] 随机选择的非负数。
saturation (float or tuple of python:float (min, max)) :[max(0, 1-saturation), 1+saturation][min, max] 随机选择的非负数。
hue (float or tuple of python:float (min, max)) :[-hue, hue][min, max] 随机选择,且 0<=hue<=0.5或-0.5<=min<=max<=0.5

# ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
imgCJ = T.ColorJitter(brightness=0.5, hue= 0.3).forward(img)
imgCJ.show()
imgCJ.save('C:\\Users\\myt\\Desktop\\imgcj.png')
print(imgCJ.size)

结果:(400, 300)
【Pytorch学习笔记】数据增强_第3张图片

1.3 transforms.FiveCrop(size)

分别裁剪图像的四个角和中心。size(sequence or int): 裁剪的尺寸大小,输入为一个 int 型,输出 (size, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)

# FiveCrop
imgFC1, imgFC2, imgFC3, imgFC4, imgFC5 = T.FiveCrop(size=(64, 64))(img)
plot([imgFC1, imgFC2, imgFC3, imgFC4, imgFC5])
print(imgFC1.size)

结果:(64, 64)
在这里插入图片描述

1.4 transforms.TenCrop(size,vertical_flip=False)

分别裁剪图像的四个角和中心,再将五张裁剪的图片进行翻转,默认是水平翻转。

  • size(sequence or int) : 裁剪的尺寸大小,输入为一个 int 型,输出 (size, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)
  • vertical_flip(bool): True 对应垂直翻转,False 对应水平翻转
# Resize
(TC0, TC1, TC2, TC3, TC4, TC5, TC6, TC7, TC8, TC9) = T.TenCrop(size=64, vertical_flip=True)(img)
plot([TC0, TC1, TC2, TC3, TC4])
plot([TC5, TC6, TC7, TC8, TC9])

TenCrop()
TenCrop()1

1.5 transforms.Grayscale(num_output_channels=1)

将图像转换为灰度图像。如果是 torch.Tensor 类型则大小为 […, 3, H, W]
num_output_channels(int)-(1 or 3): 输出图像通道数,当”3“时则三通道R=G=B。

# Grayscale(num_output_channels)
imgGS = T.Grayscale(1)(img)
imgGS.show()
imgGS.save('C:\\Users\\myt\\Desktop\\imgcj.png')
print(imgGS.size)

结果:(400, 300)
【Pytorch学习笔记】数据增强_第4张图片

1.6 transforms.Pad(padding,fill=0,padding_mode=‘constant’)

将图片所有的边界填充上 padding 像素值。如果是 torch.Tensor 类型则大小为 […, H, W]

  • padding(int or sequence): 表示填充大小,为一个 int 型时,每个边界填充大小都一样;长度为2的 sequence 时,左、右边界填充大小 Padding[0] ,上、下边界填充大小 Padding[1];长度为4的 sequence 时,依次是左、上、右、下边界填充大小。
  • fill (number or str or tuple):表示填充的颜色,只用于常数填充模式,默认是0。如果是一个数则表示填充的灰度值,如果是长度为3的 tuple 类型则分别对应R,G,B。对于 tensor 类型,fiill 只能是 number 类型;对于 PIL 类型 intstrtuple 类型均可。
  • padding_mode(str): 填充模式分为:constant, edge, reflect, symmetric 模式。
    其中 constant: 填充常数;edge: 把边界的像素值复制到填充区域;reflect: 原始 [1,2,3,4],不重复边界 [3, 2, 1, 2, 3, 4, 3, 2]symmetric: 重复边界 [2, 1, 1, 2, 3, 4, 4, 3]
# Pad
imgPd = T.Pad(padding=5, fill=125 ,padding_mode='constant')(img)
imgPd.show()
imgPd.save('C:\\Users\\myt\\Desktop\\imgPd.png')
print(imgPd.size)

结果:(410, 310)
【Pytorch学习笔记】数据增强_第5张图片

1.7 transforms.RandomAffine()

transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, interpolation=, fill=0, fillcolor=None, resample=None)
图像随机仿射变换,保持图像中心不变。如果是 torch.Tensor 类型则大小为 […, H, W]

  • degrees(sequence or number): 旋转角度,输入为 (min, max)(-degree, +degree)
  • translate(tuple, optional): 水平、垂直平移。以 (a, b) 为例,水平方向上 -img_width * a < dx< img_width * a,垂直方向上 -img_high * b < dy
  • scale(tuple, optional): 缩放因子区间。以 (a, b) 为例,a <= scale <= b
  • shear(sequence or number, optional):一个数的情况和一个长度为2的 tuple 时,对 X轴 也就是宽上加上一个 (-shear, +shear)(shear[0], shear[1]) 的随机数;长度为4的tuple 时,X轴 加上**(shear[0], shear[1])** 的随机数,Y轴 加上**(shear[2], shear[3])** 的随机数。
  • interpolation(InterpolationMode): 插值方法:最近邻或双线性插值
  • fill: 变换后图像外的像素点填充值,默认为0,如果是一个数则表示填充的灰度值,如果是长度为3的 tuple 类型则分别对应R,G,B。
  • fillcolor: v0.10.0后移除,使用 fill 即可。
  • resample: v0.10.0后移出,使用 interpolation 即可。
# RandomAffine
affine_transfomer = T.RandomAffine(degrees=(30,70), translate=(0.1, 0.3), scale=(0.5, 0.75) ,shear=30, fill=(255, 255, 0))
imgAF = affine_transfomer(img)
imgAF.show()
imgAF.save('C:\\Users\\myt\\Desktop\\imgAF.png')
imgAF.size

结果:(400, 300)
【Pytorch学习笔记】数据增强_第6张图片

1.8 transforms.RandomApply(transforms, p=0.5)

按照给定的概率,随机应用一系列的图像变换函数。

  • transforms(sequence or torch.nn.Module): 用列表存储一系列的图像变换函数
  • p(float): 进行图像变换的概率。
# RandomApply
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64)), T.ColorJitter(brightness=0.5, hue=0.3), T.CenterCrop((128, 256))], p=0.25)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)

RandomApply()

1.9 transforms.RandomCrop()

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’): 随机位置裁剪图像,如果是 torch.Tensor 类型则大小为 […, H, W]

  • size(sequence or int): 裁剪后的尺寸大小,输入为一个 int 型,输出 (size, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)
  • padding(int or sequence): 表示填充大小,为一个 int 型时,每个边界填充大小都一样;长度为2的 sequence 时,左、右边界填充大小 Padding[0] ,上、下边界填充大小 Padding[1];长度为4的 sequence 时,依次是左、上、右、下边界填充大小。
  • pad_if_needed(boolean): 如果图像小于所需的尺寸,它就会填充,以避免引发异常。由于裁剪是在填充之后进行的,填充似乎是在一个随机的偏移量上进行的。
  • fill: 变换后图像外的像素点填充值,默认为0,如果是一个数则表示填充的灰度值,如果是长度为3的 tuple 类型则分别对应R,G,B。
  • padding_mode(str): 填充模式分为:constant, edge, reflect, symmetric 模式。
    其中 **constant:**填充常数;**edge:**把边界的像素值复制到填充区域;**reflect:**原始 [1,2,3,4],不重复边界 [3, 2, 1, 2, 3, 4, 3, 2];**symmetric:**重复边界 [2, 1, 1, 2, 3, 4, 4, 3]
# RandomCrop
imgRC = T.RandomCrop(size=(128, 128))(img)
imgRC.show()
imgRC.save('C:\\Users\\myt\\Desktop\\imgRC.png')
imgRC.size

结果:(128, 128)
【Pytorch学习笔记】数据增强_第7张图片

1.10 transforms.RandomGraycale(p=0.1)

按照给定的概率,随机将图片转化为灰度图像。如果是 torch.Tensor 类型则大小为 […, 3, H, W]

# RandomGrayscale
Grayscale = T.RandomGrayscale(p=0.25)
transformed_imgs = [Grayscale(orig_img) for _ in range(4)]
plot(transformed_imgs)
print(Grayscale(orig_img).size)

结果:(400, 300)
RandomGrayscale()

1.11 transforms.RandomHorizontalFlip(p=0.5)

按照给定的概率,随机水平翻转图像。如果是 torch.Tensor 类型则大小为 […, H, W]

# RandomHorizontalFlip
hfilpper = T.RandomHorizontalFlip(p=0.25)
transformed_imgs = [hfilpper(orig_img) for _ in range(4)]
plot(transformed_imgs)

RandomHorizontalFlip()

1.12 transforms.RandomVerticalFlip(p=0.5)

按照给定的概率,随机垂直翻转图像。如果是 torch.Tensor 类型则大小为 […, H, W]

# RandomVerticalFlip
vfilpper = T.RandomVerticalFlip(p=0.25)
transformed_imgs = [vfilpper(orig_img) for _ in range(4)]
plot(transformed_imgs)

RandomVerticalFlip()

1.13 transforms.RandomPerspective()

torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=0.5, interpolation=, fill=0): 按照给定的概率,进行随机透视变换。如果是 torch.Tensor 类型则大小为 […, H, W]

  • distortion_scale(float): 控制图像畸变的参数 (0, 1)。默认是0.5。
  • p(float): 进行透视变换的概率为 0.5
  • interpolation(interpolationMode): 插值方法:最近邻或双线性插值。
  • fill(sequence or number): 变换后图像外的像素点填充值,默认为0,如果是一个数则表示填充的灰度值,如果是长度为3的 tuple 类型则分别对应R,G,B。
# RandomPerspective
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=0.5)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(3)]
plot(perspective_imgs)

Rando,Perspective()

1.14 transforms.RandomResizedCrop()

torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=): 裁剪图像的一个随机部分,并将其调整到一个给定的尺寸。如果是 torch.Tensor 类型则大小为 […, H, W]。对原始图像进行裁剪:裁剪有一个随机的区域(H*W)和一个随机的长宽比。这个裁剪最后被调整到给定的尺寸。

  • size(sequence or int): 裁剪后的尺寸大小,输入为一个 int 型,输出 (size, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)
  • scale(tuple of python:float):resize 之前,指定裁剪的随机区域的下限和上限。 scale 是相对于原始图像的面积来定义的。
  • ratio(tuple of python:float): 调整大小前,裁剪的随机长宽比的上下限。
  • interpolation(interpolationMode): 插值方法:最近邻或双线性插值。
# RandomResizedCrop
resize_cropper = T.RandomResizedCrop(size=(64,64), scale=(0.08,1.0), ratio=(0.75, 1.33))
resize_cropp_imgs = [resize_cropper(orig_img) for _ in range(4)]
plot(resize_cropp_imgs)

RandomResizedCrop()

1.15 transforms.RandomRotation()

torchvision.transforms.RandomRotation(degrees, interpolation=, expand=False, center=None, fill=0, resample=None): 根据 degress 旋转图像,如果是 torch.Tensor 类型则大小为 […, H, W]

  • degrees(sequence or number): 旋转角度,输入为 (min, max)(-degree, +degree)
  • interpolation(interpolationMode): 插值方法:最近邻或双线性插值。
  • expand(bool,optional): 旋转后是否修改图像大小,True 则修改大小保证可以容纳图像;False 则维持原大小,对图像进行适当的舍弃。
  • center(sequence,optional):旋转的中心(x, y),左上角为坐标原点。默认为图像中心。
  • fill(sequence or number): 变换后图像外的像素点填充值,默认为0,如果是一个数则表示填充的灰度值,如果是长度为3的 tuple 类型则分别对应R,G,B。
  • resample(int,optional): v0.10.0后会移除,使用 interpolation 即可。
# RandomRotation
imgRR = T.RandomRotation(degrees=(0, 30), center=((int(img.size[0]/3), int(img.size[1]/3))), expand=True)(img)
imgRR.show()
imgRR.save('C:\\Users\\myt\\Desktop\\imgRR.png')
print(imgRR.size)

结果:(445, 363)
【Pytorch学习笔记】数据增强_第8张图片

1.16 transforms.Resize()

torchvision.transforms.Resize(size, interpolation=, max_size=None, antialias=None) 将输入图像的尺寸变换为指定尺寸。如果是 torch.Tensor 类型则大小为 […, H, W]

  • size(sequence or int):裁剪后的尺寸大小,输入为一个 int 型,则size对应短边,当height>weight(size*height/weight, size) 图像;输入为长度1的 sequence 时,输出 (size[0], size[0]) ;输入为 (h, w) ,输出为 (h, w)
  • interpolation(interpolationMode): 插值方法:最近邻或双线性插值。
  • max_size(int,optional): 调整后的图像的长边允许的最大值:如果图像的长边在根据尺寸调整后大于max_size,那么图像将再次调整,使长边等于max_size。
# Resize
imgRs = T.Resize(size=150, max_size=160)(img)
imgRs.show()
imgRs.save('C:\\Users\\myt\\Desktop\\imgRs.png')
imgRs.size

结果:(160, 120)
原图像 400×300,根据 size 变为 (150×400/300, 150) = (200, 150),其中长边 200 超过max_size,所以按比例修改为 (160, 160×150/200) = (160, 120)。
【Pytorch学习笔记】数据增强_第9张图片

1.17 transforms.GaussianBlur(kernel_size,sigma=(0.1, 2.0))

使用高斯滤波器进行图像模糊,高斯滤波器的参数随机。如果是 torch.Tensor 类型则大小为 […, C, H, W]

  • kernel_size(int or sequence): 高斯滤波器的大小。
  • sigma(float or tuple of python:float(min, max)):(min, max) 随机数作为标准差。
# GaussianBlur
imgGB = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
imgGB_imgs = [imgGB(orig_img) for _ in range(4)]
plot(imgGB_imgs)

GaussianBlur()

1.18 transforms.RandomInvert(p=0.5)

以给定的概率随机地插入给定图像的颜色。如果是 torch.Tensor 类型则大小为 […, 1or3, H, W]

  • p(float): 图像转换的概率
# RandomInvert
imgRI = T.RandomInvert(p=0.3)
imgRI_imgs = [imgRI(orig_img) for _ in range(3)]
plot(imgRI_imgs)

RandomInvert()

1.19 transforms.RandomPosterize(bits, p=0.5)

通过减少每个颜色通道的比特数,以给定的概率随机地对图像进行贴图。如果图像是torch Tensor,它应该是 torch.uint8 类型的大小为 […, 1or3, H, W],如果img是PIL图像,预计它的模式是 "L "或 “RGB”。

  • bits(int): 每个通道保留的比特数。
  • p(float): 修该通道比特数的概率,默认0.5。
# RandomPosterize(bits,p=0.5)
imgRP = T.RandomPosterize(bits=2, p=0.6)
imgRP_imgs = [imgRP(orig_img) for _ in range(3)]
plot(imgRP_imgs)

RandomPosterize

1.20 transforms.RandomSolarize(threshold,p=0.5)

通过反转所有高于阈值的像素值,以给定的概率随机地晒出图像。如果图像是torch Tensor,它的大小为 […, 1or3, H, W],如果img是PIL图像,预计它的模式是 "L "或 “RGB”。

  • threshold(float): 阈值,所有大于等于阈值的像素点进行反转。
  • p(float): 发生反转的概率,默认为0.5。
# RandomSolarize
imgRS = T.RandomSolarize(threshold = 192.0)
imgRS_imgs = [imgRS(orig_img) for _ in range(2)]
plot(imgRS_imgs)

RandomSolarize()

1.21 transforms.RandomAdjustSharpness(sharpness_factor,p=0.5)

以给定的概率随机调整图像的锐度。如果图像是torch Tensor,它的大小为 […, 1or3, H, W]

  • sharpness_factor(float): 调整锐度的程度。可以是任何非负数。0给出一个模糊的图像,1给出原始图像,而2将清晰度提高2倍。
  • p(float): 调整锐度的概率,默认为0.5。
# RandomAdjustSharpness
imgRAS = T.RandomAdjustSharpness(sharpness_factor=2,p=1)(img)
imgRAS.show()
imgRAS.save('C:\\Users\\myt\\Desktop\\imgRAS.png')

【Pytorch学习笔记】数据增强_第10张图片

1.22 transforms.RandomAutocontrast(p=0.5)

以一个给定的概率随机地对给定图像的像素进行最大化对比度操作。如果图像是torch Tensor,它的大小为 […, 1or3, H, W]

  • p(float): 调整对比度的概率,默认为0.5。
# RandomAutocontrast
imgRA = T.RandomAutocontrast(p=1)(img)
imgRA.show()
imgRA.save('C:\\Users\\myt\\Desktop\\imgRA.png')

【Pytorch学习笔记】数据增强_第11张图片

1.23 transforms.RandomEqualize(p=0.5)

以给定的概率随机对图像进行直方图均衡化。如果图像是torch Tensor,它的大小为 […, 1or3, H, W]

# RandomEqualize
orig_img = Image.open('C:\\Users\\myt\\Desktop\\test3.jpg')
imgRE = T.RandomEqualize(p=1)
imgRE_imgs = [imgRE(orig_img) for _ in range(1)]
plot(imgRE_imgs)

【Pytorch学习笔记】数据增强_第12张图片

1.24 transforms.Normalize(mean,std,inplace=False)

使用均值和标准差对 torch.tensor 图像进行归一化。这个图像变换不支持 PIL

  • mean(sequence):(mean[1], mean[2], …, mean[n]) 分别表示 n 个通道的均值。
  • std(sequence):(std[1], std[2], …, std[n]) 分别表示 n 个通道的标准差。
  • inplace(bool, optional): 是否原地操作。
# Normalize
img = Image.open('C:\\Users\\myt\\Desktop\\test.png')
img_tensorf = T.ToTensor()(img)
img_Norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img_tensorf)

2. Conversion Transforms

2.1 transforms.ToPILImage(mode=None)

torch.tensorC×H×W 图像或 ndarrayH×W×C 图像转换为 PIL图像。

  • mode(PIL.Image mode): 输入数据的颜色空间和像素深度(可选)。如果模式为无(默认),则会对输入数据做一些假设。- 如果输入的数据有4个通道,模式就被假定为 RGBA。- 如果输入有3个通道,则假设模式为 RGB。- 如果输入有2个通道,则假设模式为 LA。- 如果输入有1个通道,模式由数据类型决定(如int, float, short)。
import torchvision.io as TI
img = TI.read_image('C:\\Users\\myt\\Desktop\\test.jpg')
print(img.shape)
imgI = T.ToPILImage()(img)
imgI.show()
imgI.save('C:\\Users\\myt\\Desktop\\imgI.png')

结果:torch.Size([3, 300, 400])
【Pytorch学习笔记】数据增强_第13张图片

2.2 transforms.ToTensor

PIL 图像或 ndarrayH×W×C 图像转换为 torch.tensorC×H×W 图像。PIL 模式为 **(L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) **,ndarraynp.unit8

# ToTensor
img = Image.open('C:\\Users\\myt\\Desktop\\test.png')
print(img.size)
imgT = T.ToTensor()(img)
print(imgT.shape)

结果:(400, 300);torch.Size([3, 300, 400])

3. Automatic Augmentation Transforms

AutoAugment 是一种常见的数据增强技术,可以提高图像分类模型的准确性。虽然数据增强策略与其训练的数据集直接相关,但经验研究表明,ImageNet 策略在应用于其他数据集时能提供显著的改进。在 TorchVision 中,我们实现了在以下数据集上学习的3个策略。ImageNetCIFAR10SVHN。新的变换可以独立使用,也可以与现有的 transform 混合使用。

3.1 torchvision.transforms.AutoAugmentPolicy(value)

其中主要有三个成员 AutoAugmentPolicy.CIFAR10、AutoAugmentPolicy.IMAGENET、AutoAugmentPolicy.SVHNAutoAugment 联用。

# AutoAugmentPolicy
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)

【Pytorch学习笔记】数据增强_第14张图片

3.2 transforms.RandAugment()

RandAugment 是一种简单的高性能数据增强技术,可以提高图像分类模型的准确性。

# RandAugment
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)				

【Pytorch学习笔记】数据增强_第15张图片

3.3 transforms.TrivalAugmentWide()

TrivialAugmentWide 是一种独立于数据集的数据增强技术,可以提高图像分类模型的准确性。

# TrivalAugmentWide
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

【Pytorch学习笔记】数据增强_第16张图片

4. Functional Transforms

torchvision.transforms.functrional 模块提供的函数形式的图像变化,可以让用户自定图像变换类。

# Functional Transform
import torchvision.transforms.functional as TF
import random

class MyRotationTransform:
    """Rotate by one of the given angles."""

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])

transforms = T.Compose(
    [MyRotationTransform((30,60)),
     T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
     T.CenterCrop(size=(256, 128))
    ])

imgMT = transforms(img)
imgMT.show()
imgMT.save('C:\\Users\\myt\\Desktop\\imgMT.png')

【Pytorch学习笔记】数据增强_第17张图片

你可能感兴趣的:(PyTorch学习,pytorch,python,深度学习)