手动实现 torchvision.transforms 图像增强(一)

简介

PyTorch 的视觉工具包 torchvision 提供了大量的图像增强操作(torchvision.transforms 模块), 其主要针对 PIL.Image 对象和 torch.Tensor 对象

对于 PIL.Image 对象, transforms 中包含大量的类, 其内部实现调用了 PIL 包中的方法, 使用时先创建特定操作的实例, 然后将该实例视为函数去调用 PIL.Image 对象, 返回增强后的 PIL.Image 对象

对于较新版本的 torchvision(0.8.0 以上), 很多适用于 PIL.Image 对象的 transform 操作同样适用于 torch.Tensor 对象


在本文中, 主要展示如何使用 PIL 包手动实现 PIL.Image 图像的 transforms 操作, 并与官方库的结果比较

以下示例使用 COCO 数据集中的一张图片, 大小为 640*480

Conversion Transforms

ToTensor

PIL.Imagenp.ndarray 转换为 torch.Tensor

首先调整维度顺序, 从 H*W*C 转变为 C*H*W, 然后将图像矩阵的每一个数值都除以 255 转为浮点类型


导入所需模块, 较新版本的 torchvision 使用 torch 模块而不是 random 来产生随机数, 以下程序建议使用 jupyter-notebook 环境下运行

import math

import PIL
import numpy as np
import torch
from PIL import Image, ImageEnhance
from torchvision import transforms
# 下载图片 http://images.cocodataset.org/test-stuff2017/000000000001.jpg, 默认位于当前目录
img_origin = Image.open("000000000001.jpg")
img_width, img_height = img_origin.size

# 调用官方库
img = transforms.ToTensor()(img_origin)

# 手动实现
img_2 = torch.from_numpy(np.array(img_origin))
# 调整维度顺序后, 使 Tensor 连续
img_2 = img_2.permute((2,0,1)).contiguous()
img_2 = img_2.float().div(255)

# tensor(True)
# 两种结果相同
torch.all(img == img_2)

Geometric Distortion

CenterCrop

中心裁剪

CenterCrop 相对于图像中心裁剪出一块区域

参数

torchvision.transforms.CenterCrop(size=(h, w))

size 可以是一个 int, 表示裁剪正方形区域

示例

如果要在原始图像中心裁剪出 200*200 的区域, 需要确定裁剪区域左上角和右下角的坐标, 然后调用 PIL.Image 中的 crop 方法

# 调用库方法
img_cropped = transforms.CenterCrop(200)(img_origin)

# 计算裁剪位置
x1 = (img_width - 200) // 2
y1 = (img_height - 200) // 2
x2 = x1 + 200
y2 = y1 + 200
# 使用 PIL 手动实现裁剪
img_cropped_2 = img_origin.crop((x1, y1, x2, y2))

两种方法的结果一致

# True
img_cropped == img_cropped_2

左侧为原图

手动实现 torchvision.transforms 图像增强(一)_第1张图片

RandomCrop

随机裁剪

RandomCrop 在图像的随机位置裁剪出一块区域

参数

torchvision.transforms.RandomCrop(size=(h, w))

size 可以为一个 int, 表示裁剪正方形区域

实现

随机选取裁剪区域左上角的坐标位置

# 摘自源代码
# h  w  为原始图像的宽高
# th tw 为裁剪区域的宽高

# 低版本 torchvision 实现
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
# 新版本 torchvision 实现
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
示例
# 设定随机数种子, 保证随机结果一致
_ = torch.manual_seed(1)
# 调用库方法
img_rand_cropped = transforms.RandomCrop(200)(img_origin)

# 手动实现
_ = torch.manual_seed(1)
# 官方实现中先计算纵坐标, 这里的顺序与官方一致
y1 = torch.randint(0, img_height - 200 + 1, size=(1,)).item()
x1 = torch.randint(0, img_width - 200 + 1, size=(1,)).item()
y2 = y1 + 200
x2 = x1 + 200

img_rand_cropped_2 = img_origin.crop((x1, y1, x2, y2))

# True
img_rand_cropped == img_rand_cropped_2

手动实现 torchvision.transforms 图像增强(一)_第2张图片

填充

RandomCrop 是可以先填充后随机裁剪, 这样裁剪区域可能会包含填充的区域, 其参数为 RandomCrop(size, padding=None, fill=0, padding_mode='constant')

实例化时可传参 pad_if_needed=True, 即便裁剪区域大于图像宽高, 也可自动填充, 如果原始图像尺寸为 180*120, 随机裁剪 200*200 的区域, 实际过程为将宽度和高度填充到 200 (在图像四周填充), 然后裁剪 img.crop((0, 0, 200, 200))

Resize & Scale

在底层的实现上, Resize 直接调用 PIL.Image 的 resize 方法

Scale 与 Resize 功能相同, 实际上 Scale 直接继承 Resize

参数

torchvision.transforms.Resize(size=(h, w))

如果 size 为一个 int, 表示将最短的边缩放到该值, 而另一条边自动缩放 (保持宽高比)

示例
# 最短边(高度)缩放一半(480->240), 为保持宽高比, 宽度也缩放一半(640->320)
img_resized = transforms.Resize(240)(img_origin)

# 手动实现缩放
img_resized_2 = img_origin.resize((320, 240), Image.BILINEAR)

# True
img_resized == img_resized_2

Image.BILINEAR 表示使用双线性插值, 实例化 Resize 时可传入第二个参数 interpolation 指定插值方法, 该参数默认为双线性插值(PIL.Image.BILINEAR)

手动实现 torchvision.transforms 图像增强(一)_第3张图片

RandomResizedCrop

RandomResizedCrop 随机选取一块区域, 然后将该区域缩放到指定的宽高

参数

torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)

scale 表示裁剪区域的面积占原始图形面积的比例, 随机选取面积后, 根据 ratio 随机生成的宽高比得到裁剪区域的宽高, 然后随机选取一个位置进行裁剪, 最后将该区域缩放到指定宽高

interpolation 表示插值方法

示例
# 调用库方法
_ = torch.manual_seed(3)

img_rand_resize_cropped = transforms.RandomResizedCrop(200)(img_origin)

# 手动实现
_ = torch.manual_seed(3)

# 面积随机, 0.08 和 1.0 是 scale 参数的默认值
area = img_width * img_height
random_area = torch.empty(1).uniform_(0.08, 1.0).item() * area
# 宽高比随机
aspect_radio = math.exp(torch.empty(1).uniform_(math.log(3.0/4.0), math.log(4.0/3.0)))
# 裁剪区域的宽高
w = round(math.sqrt(random_area * aspect_radio))
h = round(math.sqrt(random_area / aspect_radio))
# 裁剪位置随机
y1 = torch.randint(0, img_height - h + 1, size=(1,)).item()
x1 = torch.randint(0, img_width - w + 1, size=(1,)).item()
y2 = y1 + h
x2 = x1 + w

# 先裁剪
img = img_origin.crop((x1, y1, x2, y2))
# 后缩放
img_rand_resize_cropped_2 = img.resize((200, 200), Image.BILINEAR)
img_rand_resize_cropped_2

两种方法的结果一致

# True
img_rand_resize_cropped == img_rand_resize_cropped_2

手动实现 torchvision.transforms 图像增强(一)_第4张图片

补充

如果 scale 或 ratio 参数中的两个数值一样(例如 scale=(0.5, 0.5)), 则每一次都固定选取 0.5

center_crop = transforms.CenterCrop(480)(img_origin)

# 裁剪面积等于原图面积, 裁剪时宽度等于高度(宽高比为1), 计算边长 sqrt(640*480), 显然高度超过原始图像的高度
center_crop_2 = transforms.RandomResizedCrop(size=480, scale=(1.0, 1.0), ratio=(1.0, 1.0))(img_origin)

# True
center_crop == center_crop_2

如果裁剪时的宽度或高度超过原始图像, 则改用 CenterCrop 裁剪

手动实现 torchvision.transforms 图像增强(一)_第5张图片

RandomHorizontalFlip & RandomVerticalFlip

水平翻转

torchvision.transforms.RandomHorizontalFlip(p=0.5)

垂直翻转

torchvision.transforms.RandomVerticalFlip(p=0.5)

实现

如果 random.random() < p, 则进行翻转操作

水平翻转 img_origin.transpose(Image.FLIP_LEFT_RIGHT)

垂直翻转 img_origin.transpose(Image.FLIP_TOP_BOTTOM)

flip_h = transforms.RandomHorizontalFlip(p=1.0)(img_origin)
手动实现 torchvision.transforms 图像增强(一)_第6张图片
flip_v = transforms.RandomVerticalFlip(p=1.0)(img_origin)
手动实现 torchvision.transforms 图像增强(一)_第7张图片

RandomRotation

随机旋转

在底层的实现上, RandomRotation 直接调用 PIL.Image 的 rotate 方法

参数

RandomRotation(degrees, resample=False, expand=False, center=None, fill=None)

degrees 指定旋转范围, 如果是单个 int 或 float, 视为 (-int, int) 或 (-float, float)

resample 指定采样方法, 例如, 最近邻PIL.Image.NEAREST, 双线性 PIL.Image.BILINEAR

expand 旋转后的图像尺寸是否与原始图像一致

center 指定旋转中心的坐标, 默认是图像中心

fill 旋转后空白区域的填充色

示例

逆时针旋转 30 度, 使用紫色填充多余空白:

# 库方法
img_rotate = transforms.RandomRotation((30, 30), resample=Image.BILINEAR, expand=True, fill=(255, 0, 255))(img_origin)

# 手动实现
img_rotate_2 = img_origin.rotate(30, resample=Image.BILINEAR, expand=True, fillcolor=(255, 0, 255))

# True
img_rotate == img_rotate_2

指定 expand=True 后, 图像尺寸变大了, 但原始图像的信息被完整保存

手动实现 torchvision.transforms 图像增强(一)_第8张图片

顺时针旋转 30 度, 采样方式使用 PIL.Image.NEAREST(False 视为 0)

# 库方法
img_rotate = transforms.RandomRotation((-30, -30), resample=False, expand=False, fill=0)(img_origin)

# 手动实现
img_rotate_2 = img_origin.rotate(-30, resample=False, expand=False, fillcolor=0)

# True
img_rotate == img_rotate_2

这一次, 旋转后的图像尺寸没有扩张, 但原始图像的四个角落被裁剪

手动实现 torchvision.transforms 图像增强(一)_第9张图片

Photometric Distortion

ColorJitter

参数

torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

亮度

brightness 范围 [0.0, 1.0], 1.0 为原图, 0.0 为纯黑

参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - brightness), 1 + brightness]

使用 PIL.ImageEnhance.Brightness 实现

对比度

contrast 范围 [0.0, 1.0], 1.0 为原图, 0.0 为纯灰

参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - contrast), 1 + contrast]

使用 PIL.ImageEnhance.Contrast 实现

饱和度

saturation 范围 [0.0, 1.0], 1.0 为原图, 0.0 为黑白图片

参数形式 [min, max], 也可以是 float, 等价于 [max(0, 1 - saturation), 1 + saturation]

使用 PIL.ImageEnhance.Color 实现

色调

范围 [-0.5, 0.5]

参数形式 [min, max], 如果是 float, 等价于 [-hue, hue]

随机

每一次对图片调用 ColorJitter 时均使用 torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 的形式进行随机选取

如果设置了多个参数(例如设置了 brightness 和 contrast), 每一次调用时, 顺序也是随机的, 可能先调整亮度后调整对比度, 也可能先调整对比度后调整亮度

示例
亮度
# 调用官方库
img_brightness_00 = transforms.ColorJitter(brightness = (0.00, 0.00))(img_origin)

img_brightness_25 = transforms.ColorJitter(brightness = (0.25, 0.25))(img_origin)

img_brightness_50 = transforms.ColorJitter(brightness = (0.50, 0.50))(img_origin)

img_brightness_75 = transforms.ColorJitter(brightness = (0.75, 0.75))(img_origin)

# 手动实现
enhancer = ImageEnhance.Brightness(img_origin)
img_brightness_50_2 = enhancer.enhance(0.50)

# True
# 结果一致
img_brightness_50 == img_brightness_50_2

对比度
img_contrast_00 = transforms.ColorJitter(contrast = (0.00, 0.00))(img_origin)

img_contrast_25 = transforms.ColorJitter(contrast = (0.25, 0.25))(img_origin)

img_contrast_50 = transforms.ColorJitter(contrast = (0.50, 0.50))(img_origin)

img_contrast_75 = transforms.ColorJitter(contrast = (0.75, 0.75))(img_origin)

饱和度
img_saturation_00 = transforms.ColorJitter(saturation = (0.00, 0.00))(img_origin)

img_saturation_25 = transforms.ColorJitter(saturation = (0.25, 0.25))(img_origin)

img_saturation_50 = transforms.ColorJitter(saturation = (0.50, 0.50))(img_origin)

img_saturation_75 = transforms.ColorJitter(saturation = (0.75, 0.75))(img_origin)

手动实现 torchvision.transforms 图像增强(一)_第10张图片手动实现 torchvision.transforms 图像增强(一)_第11张图片


最后更新时间: 2021-01-14

你可能感兴趣的:(PyTorch,图像识别,计算机视觉)