PyTorch 的视觉工具包 torchvision 提供了大量的图像增强操作(torchvision.transforms 模块), 其主要针对 PIL.Image 对象和 torch.Tensor 对象
对于 PIL.Image 对象, transforms 中包含大量的类, 其内部实现调用了 PIL 包中的方法, 使用时先创建特定操作的实例, 然后将该实例视为函数去调用 PIL.Image 对象, 返回增强后的 PIL.Image 对象
对于较新版本的 torchvision(0.8.0 以上), 很多适用于 PIL.Image 对象的 transform 操作同样适用于 torch.Tensor 对象
在本文中, 主要展示如何使用 PIL 包手动实现 PIL.Image 图像的 transforms 操作, 并与官方库的结果比较
以下示例使用 COCO 数据集中的一张图片, 大小为 640*480
将 PIL.Image
或 np.ndarray
转换为 torch.Tensor
首先调整维度顺序, 从 H*W*C 转变为 C*H*W, 然后将图像矩阵的每一个数值都除以 255 转为浮点类型
导入所需模块, 较新版本的 torchvision 使用 torch 模块而不是 random 来产生随机数, 以下程序建议使用 jupyter-notebook 环境下运行
import math
import PIL
import numpy as np
import torch
from PIL import Image, ImageEnhance
from torchvision import transforms
# 下载图片 http://images.cocodataset.org/test-stuff2017/000000000001.jpg, 默认位于当前目录
img_origin = Image.open("000000000001.jpg")
img_width, img_height = img_origin.size
# 调用官方库
img = transforms.ToTensor()(img_origin)
# 手动实现
img_2 = torch.from_numpy(np.array(img_origin))
# 调整维度顺序后, 使 Tensor 连续
img_2 = img_2.permute((2,0,1)).contiguous()
img_2 = img_2.float().div(255)
# tensor(True)
# 两种结果相同
torch.all(img == img_2)
CenterCrop
相对于图像中心裁剪出一块区域
torchvision.transforms.CenterCrop(size=(h, w))
size
可以是一个 int, 表示裁剪正方形区域
如果要在原始图像中心裁剪出 200*200 的区域, 需要确定裁剪区域左上角和右下角的坐标, 然后调用 PIL.Image
中的 crop
方法
# 调用库方法
img_cropped = transforms.CenterCrop(200)(img_origin)
# 计算裁剪位置
x1 = (img_width - 200) // 2
y1 = (img_height - 200) // 2
x2 = x1 + 200
y2 = y1 + 200
# 使用 PIL 手动实现裁剪
img_cropped_2 = img_origin.crop((x1, y1, x2, y2))
两种方法的结果一致
# True
img_cropped == img_cropped_2
左侧为原图
RandomCrop
在图像的随机位置裁剪出一块区域
torchvision.transforms.RandomCrop(size=(h, w))
size
可以为一个 int, 表示裁剪正方形区域
随机选取裁剪区域左上角的坐标位置
# 摘自源代码
# h w 为原始图像的宽高
# th tw 为裁剪区域的宽高
# 低版本 torchvision 实现
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
# 新版本 torchvision 实现
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
# 设定随机数种子, 保证随机结果一致
_ = torch.manual_seed(1)
# 调用库方法
img_rand_cropped = transforms.RandomCrop(200)(img_origin)
# 手动实现
_ = torch.manual_seed(1)
# 官方实现中先计算纵坐标, 这里的顺序与官方一致
y1 = torch.randint(0, img_height - 200 + 1, size=(1,)).item()
x1 = torch.randint(0, img_width - 200 + 1, size=(1,)).item()
y2 = y1 + 200
x2 = x1 + 200
img_rand_cropped_2 = img_origin.crop((x1, y1, x2, y2))
# True
img_rand_cropped == img_rand_cropped_2
RandomCrop
是可以先填充后随机裁剪, 这样裁剪区域可能会包含填充的区域, 其参数为 RandomCrop(size, padding=None, fill=0, padding_mode='constant')
实例化时可传参 pad_if_needed=True
, 即便裁剪区域大于图像宽高, 也可自动填充, 如果原始图像尺寸为 180*120, 随机裁剪 200*200 的区域, 实际过程为将宽度和高度填充到 200 (在图像四周填充), 然后裁剪 img.crop((0, 0, 200, 200))
在底层的实现上, Resize
直接调用 PIL.Image
的 resize 方法
Scale 与 Resize 功能相同, 实际上 Scale 直接继承 Resize
torchvision.transforms.Resize(size=(h, w))
如果 size
为一个 int, 表示将最短的边缩放到该值, 而另一条边自动缩放 (保持宽高比)
# 最短边(高度)缩放一半(480->240), 为保持宽高比, 宽度也缩放一半(640->320)
img_resized = transforms.Resize(240)(img_origin)
# 手动实现缩放
img_resized_2 = img_origin.resize((320, 240), Image.BILINEAR)
# True
img_resized == img_resized_2
Image.BILINEAR
表示使用双线性插值, 实例化 Resize
时可传入第二个参数 interpolation
指定插值方法, 该参数默认为双线性插值(PIL.Image.BILINEAR)
RandomResizedCrop
随机选取一块区域, 然后将该区域缩放到指定的宽高
torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
scale
表示裁剪区域的面积占原始图形面积的比例, 随机选取面积后, 根据 ratio
随机生成的宽高比得到裁剪区域的宽高, 然后随机选取一个位置进行裁剪, 最后将该区域缩放到指定宽高
interpolation
表示插值方法
# 调用库方法
_ = torch.manual_seed(3)
img_rand_resize_cropped = transforms.RandomResizedCrop(200)(img_origin)
# 手动实现
_ = torch.manual_seed(3)
# 面积随机, 0.08 和 1.0 是 scale 参数的默认值
area = img_width * img_height
random_area = torch.empty(1).uniform_(0.08, 1.0).item() * area
# 宽高比随机
aspect_radio = math.exp(torch.empty(1).uniform_(math.log(3.0/4.0), math.log(4.0/3.0)))
# 裁剪区域的宽高
w = round(math.sqrt(random_area * aspect_radio))
h = round(math.sqrt(random_area / aspect_radio))
# 裁剪位置随机
y1 = torch.randint(0, img_height - h + 1, size=(1,)).item()
x1 = torch.randint(0, img_width - w + 1, size=(1,)).item()
y2 = y1 + h
x2 = x1 + w
# 先裁剪
img = img_origin.crop((x1, y1, x2, y2))
# 后缩放
img_rand_resize_cropped_2 = img.resize((200, 200), Image.BILINEAR)
img_rand_resize_cropped_2
两种方法的结果一致
# True
img_rand_resize_cropped == img_rand_resize_cropped_2
如果 scale 或 ratio 参数中的两个数值一样(例如 scale=(0.5, 0.5)), 则每一次都固定选取 0.5
center_crop = transforms.CenterCrop(480)(img_origin)
# 裁剪面积等于原图面积, 裁剪时宽度等于高度(宽高比为1), 计算边长 sqrt(640*480), 显然高度超过原始图像的高度
center_crop_2 = transforms.RandomResizedCrop(size=480, scale=(1.0, 1.0), ratio=(1.0, 1.0))(img_origin)
# True
center_crop == center_crop_2
如果裁剪时的宽度或高度超过原始图像, 则改用 CenterCrop
裁剪
torchvision.transforms.RandomHorizontalFlip(p=0.5)
torchvision.transforms.RandomVerticalFlip(p=0.5)
如果 random.random() < p
, 则进行翻转操作
水平翻转 img_origin.transpose(Image.FLIP_LEFT_RIGHT)
垂直翻转 img_origin.transpose(Image.FLIP_TOP_BOTTOM)
flip_h = transforms.RandomHorizontalFlip(p=1.0)(img_origin)
flip_v = transforms.RandomVerticalFlip(p=1.0)(img_origin)
在底层的实现上, RandomRotation
直接调用 PIL.Image
的 rotate 方法
RandomRotation(degrees, resample=False, expand=False, center=None, fill=None)
degrees
指定旋转范围, 如果是单个 int 或 float, 视为 (-int, int) 或 (-float, float)
resample
指定采样方法, 例如, 最近邻PIL.Image.NEAREST
, 双线性 PIL.Image.BILINEAR
expand
旋转后的图像尺寸是否与原始图像一致
center
指定旋转中心的坐标, 默认是图像中心
fill
旋转后空白区域的填充色
逆时针旋转 30 度, 使用紫色填充多余空白:
# 库方法
img_rotate = transforms.RandomRotation((30, 30), resample=Image.BILINEAR, expand=True, fill=(255, 0, 255))(img_origin)
# 手动实现
img_rotate_2 = img_origin.rotate(30, resample=Image.BILINEAR, expand=True, fillcolor=(255, 0, 255))
# True
img_rotate == img_rotate_2
指定 expand=True
后, 图像尺寸变大了, 但原始图像的信息被完整保存
顺时针旋转 30 度, 采样方式使用 PIL.Image.NEAREST
(False 视为 0)
# 库方法
img_rotate = transforms.RandomRotation((-30, -30), resample=False, expand=False, fill=0)(img_origin)
# 手动实现
img_rotate_2 = img_origin.rotate(-30, resample=False, expand=False, fillcolor=0)
# True
img_rotate == img_rotate_2
这一次, 旋转后的图像尺寸没有扩张, 但原始图像的四个角落被裁剪
torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
brightness 范围 [0.0, 1.0], 1.0 为原图, 0.0 为纯黑
参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - brightness), 1 + brightness]
使用 PIL.ImageEnhance.Brightness
实现
contrast 范围 [0.0, 1.0], 1.0 为原图, 0.0 为纯灰
参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - contrast), 1 + contrast]
使用 PIL.ImageEnhance.Contrast
实现
saturation 范围 [0.0, 1.0], 1.0 为原图, 0.0 为黑白图片
参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - saturation), 1 + saturation]
使用 PIL.ImageEnhance.Color
实现
范围 [-0.5, 0.5]
参数形式 [min, max], 如果是 float, 等价于 [-hue, hue]
每一次对图片调用 ColorJitter 时均使用 torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
的形式进行随机选取
如果设置了多个参数(例如设置了 brightness 和 contrast), 每一次调用时, 顺序也是随机的, 可能先调整亮度后调整对比度, 也可能先调整对比度后调整亮度
# 调用官方库
img_brightness_00 = transforms.ColorJitter(brightness = (0.00, 0.00))(img_origin)
img_brightness_25 = transforms.ColorJitter(brightness = (0.25, 0.25))(img_origin)
img_brightness_50 = transforms.ColorJitter(brightness = (0.50, 0.50))(img_origin)
img_brightness_75 = transforms.ColorJitter(brightness = (0.75, 0.75))(img_origin)
# 手动实现
enhancer = ImageEnhance.Brightness(img_origin)
img_brightness_50_2 = enhancer.enhance(0.50)
# True
# 结果一致
img_brightness_50 == img_brightness_50_2
img_contrast_00 = transforms.ColorJitter(contrast = (0.00, 0.00))(img_origin)
img_contrast_25 = transforms.ColorJitter(contrast = (0.25, 0.25))(img_origin)
img_contrast_50 = transforms.ColorJitter(contrast = (0.50, 0.50))(img_origin)
img_contrast_75 = transforms.ColorJitter(contrast = (0.75, 0.75))(img_origin)
img_saturation_00 = transforms.ColorJitter(saturation = (0.00, 0.00))(img_origin)
img_saturation_25 = transforms.ColorJitter(saturation = (0.25, 0.25))(img_origin)
img_saturation_50 = transforms.ColorJitter(saturation = (0.50, 0.50))(img_origin)
img_saturation_75 = transforms.ColorJitter(saturation = (0.75, 0.75))(img_origin)
最后更新时间: 2021-01-14