Environment
本文记录了使用 PyTorch 实现图像数据预处理的方法,包括数据增强和标准化。主要的工具为 torchvision.transform
模块中的各种类,要求传入的图片为 PIL.Image
类型的图片。
对图像数据进行变换,可以增加训练样本量。数据变换的方式很多,实际应用时,原则是使手头的训练数据与真实数据接近。例如,如果训练数据和真实数据的物体颜色有差别,那么可以采用色彩变换;如果训练数据和真实数据的物体位置常常差别,那么可以平移变换…
这里以经典的 Lenna 图片为例,选择一块大小为 299 x 299 的部分。
from PIL import Image
image = Image.open('Lenna.jpg')
print(type(image))
#
print(np.array(img).shape)
# (299, 299, 3)
image.show()
Pad
类,对图片边缘进行填充只有当 padding_mode='constant'
时,参数 fill
才接收用于填充的值
from PIL import Image
from torchvision.transforms import Pad
image = Image.open('Lenna.jpg')
pad = Pad(
padding=(10, 100), # 传入整数 a 时,上下左右分别填充 a 个像素;
# 传入元组 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素;
# 传入元组 (a, b, c, d) 时,左、上、右、下分别填充 a、b、c、d 个像素
fill=(255, 64, 128), # 每个通道用于填充的值,默认全为 0,黑色填充
padding_mode='constant' # constant, edge, reflect 或 symmetric
)
pad(image).show()
padding_mode='edge'
时,使用图像边界的像素值进行填充
padding_mode='reflect'
时,使用镜像进行填充,边界的像素值在填充中出现
padding_mode='symmetric'
时,使用镜像填充,边界的像素值将作为填充的第一个值
RandomErasing
类,随机对图像区域进行遮挡from PIL import Image
import numpy as np
import torch
from torchvision.transforms import RandomErasing
img_pil = Image.open('Lenna.jpg')
img_array = np.array(img_pil).transpose(2, 0, 1) # transpose (H, W, C) -> (C, H, W)
img_tensor = torch.from_numpy(img_array)
random_erasing = RandomErasing(
p=1.0, # 概率值,执行该操作的概率,默认为 0.5
scale=(0.02, 0.33), # 按均匀分布概率抽样,遮挡区域的面积 = image * scale
ratio=(0.3, 3.3), # 遮挡区域的宽高比,按均匀分布概率抽样
value='123', # 遮挡区域的像素值,(R, G, B) or (Gray);传入字符串表示用随机彩色像素填充遮挡区域
inplace=False
)
# 注意,随机遮挡是对 (c, h, w) 形状的 tensor 进行操作,一般在 ToTensor 之后进行
erased_img_tensor = random_erasing(img_tensor)
erased_img_array = erased_img_tensor.numpy().transpose(1, 2, 0) # (C, H, W) -> (H, W, C)
erased_img_pil = Image.fromarray(erased_img_array)
erased_img_pil.show()
Resize
类,将图像缩放到设定的尺寸from torchvision.transforms import Resize
from PIL import Image
image = Image.open('Lenna.jpg')
resize = Resize(
size=(299, 224), # (height, width)
interpolation=2 # 插值方法,一般保持默认就好
)
resize(image).show()
注意 Resize的小坑
CenterCrop
类,将从图像中心裁剪出需要的尺寸from torchvision.transforms import CenterCrop
from PIL import Image
image = Image.open('Lenna.jpg')
center_crop = CenterCrop(size=(224, 224))
center_crop(image).show()
RandomCrop
类,随机在图片某个位置裁剪出需要的尺寸from torchvision.transforms import RandomCrop
from PIL import Image
image = Image.open('Lenna.jpg')
random_crop = RandomCrop(
size=(224, 224), # 裁剪后图片的尺寸
padding=(75, 75), # (左右填充多少,上下填充多少)
pad_if_needed=False, # 当 size 大于原始图片的尺寸时,必须将该参数设置为 True,否则会报错
fill=(255, 0, 0), # 同 Pad 对象的参数
padding_mode='constant' # 同 Pad 对象的参数
)
random_crop(image).show()
FiveCrop
类,在图像的左上、右上、左下、右下、中心裁出指定尺寸的5张图片from PIL import Image
import torch
from torchvision.transforms import FiveCrop, Lambda, ToTensor
image = Image.open('Lenna.jpg')
five_crop = FiveCrop(size=(64, 64))
cropped_images = five_crop(image)
cropped_images[4].show()
# FiveCrop之后,往往需要将得到的图像重新拼接为 (B, C, H, W) 的格式
vision_lambda = Lambda(
lambda crops: torch.stack([(ToTensor()(crop)) for crop in crops])
)
print(vision_lambda(cropped_images).size())
# torch.Size([5, 3, 64, 64])
TenCrop
类,在FiveCrop五张图片的基础上,进行水平(或加上垂直翻转)获得额外的5张图片,共计10张图片from PIL import Image
from torchvision.transforms import TenCrop
image = Image.open('Lenna.jpg')
ten_crop = TenCrop(
size=(64, 64),
vertical_flip=True # 是否加上垂直翻转(水平翻转是默认需要的)
)
cropped_images = ten_crop(image)
cropped_images[9].show()
RandomResizedCrop
类,先对图片按随机面积比例缩放,再调整到随机的宽高比,最后随机在某个位置按需要的尺寸裁剪图片from PIL import Image
from torchvision.transforms import RandomResizedCrop
image = Image.open('Lenna.jpg')
random_resized_crop = RandomResizedCrop(
size=(224, 224), # 裁剪后图片的尺寸
scale=(0.08, 1.0), # 面积比的范围,将在此范围内按均匀分布采样
ratio=(3/4, 4/3), # 宽高比的范围,将在此范围内按均匀分布采样,50%概率交换
interpolation=2 # 插值方法,PIL.Image.NEAREST、PIL.Image.BILINEAR、PIL.Image.BICUBIC
)
random_resized_crop(image).show()
RandomRotation
类,随机角度旋转图片from PIL import Image
from torchvision.transforms import RandomRotation
image = Image.open('Lenna.jpg')
random_rotation = RandomRotation(
degrees=(-45, 45), # 传入为整数 d 时,旋转角度在 (-d, d);传入整数元组 (d1, d2)时,角度在 (d1, d2)
resample=False, # 旋转之后会要重采样,一般设置为默认就好
expand=True, # 传入 True 时,会 padding 图片,保证旋转后四个角的信息不丢失
center=(0, 0) # 以什么点为轴旋转,默认为中心,可传入 tuple 更改为任意坐标
)
# 注意,若对不止一张 image 进行处理,在传入 expand=True 时,每张图片的 size 应该要相同,否则报错
# expand=True 是针对以中心为轴旋转所需要 padding 的量的,因此当center=(0, 0),信息依旧会丢失
random_rotation(image).show()
包括水平翻转和垂直翻转
RandomHorizontalFlip
类,按概率水平(即左右)翻转图片from PIL import Image
from torchvision.transforms import RandomHorizontalFlip
image = Image.open('Lenna.jpg')
random_horizontal_flip = RandomHorizontalFlip(p=1.0) # p 为翻转的概率,默认为 0.5
random_horizontal_flip(image).show()
RandomVerticalFlip
类,按概率垂直(即上下)翻转图片from PIL import Image
from torchvision.transforms import RandomVerticalFlip
image = Image.open('Lenna.jpg')
random_vertical_flip = RandomVerticalFlip(p=1.0) # p 为翻转的概率,默认为 0.5
random_vertical_flip(image).show()
RandomGrayscale
类,按概率将图像转换为灰度图from PIL import Image
from torchvision.transforms import RandomGrayscale
image = Image.open('Lenna.jpg')
random_grayscale = RandomGrayscale(p=1.0)
random_grayscale(image).show()
Grayscale
类,将图片转换为灰度图,即 RandomGrayscale(p=0.1) 的情况from PIL import Image
from torchvision.transforms import Grayscale
image = Image.open('Lenna.jpg')
grayscale = Grayscale(
num_output_channels=3 # num_output_channels should be either 1 or 3
)
grayscale(image).show()
ColorJitter
类,调整亮度、对比度、饱和度、色相调整亮度
from PIL import Image
from torchvision.transforms import ColorJitter
image = Image.open('Lenna.jpg')
color_jitter = ColorJitter(
brightness=(1.5, 1.5), # 亮度,当传入 a 时,从 [max(0, 1-a), 1+a] 中按均匀分布抽样
# 当传入 (a, b) 时,从 [a, b] 中按均匀分布抽样
contrast=0, # 对比度,传入格式同 brightness
saturation=0, # 饱和度,传入格式同 brightness
hue=0 # 色相,当传入 a 时,从 [-a, a] 中按均匀分布抽样,注意,0 <= a <= 0.5
# 当传入 (a, b) 时,从 [a, b] 中按均匀分布抽样,注意,-0.5 <= a <= b <= 0.5
)
color_jitter(image).show()
调整对比度
from PIL import Image
from torchvision.transforms import ColorJitter
image = Image.open('Lenna.jpg')
color_jitter = ColorJitter(
brightness=0,
contrast=(1.5, 1.5),
saturation=0,
hue=0
)
color_jitter(image).show()
调整饱和度
from PIL import Image
from torchvision.transforms import ColorJitter
image = Image.open('Lenna.jpg')
color_jitter = ColorJitter(
brightness=0,
contrast=0,
saturation=(1.5, 1.5),
hue=0
)
color_jitter(image).show()
调整色相
from PIL import Image
from torchvision.transforms import ColorJitter
image = Image.open('Lenna.jpg')
color_jitter = ColorJitter(
brightness=0,
contrast=0,
saturation=0,
hue=0.5
)
color_jitter(image).show()
RandomAffine
类,对图像进行放射变换放射变换是二维的线性变换,由旋转、平移、缩放、错切、翻转五种基本变换构成
from PIL import Image
from torchvision.transforms import RandomAffine
image = Image.open('Lenna.jpg')
random_affine = RandomAffine(
degrees=45, # 旋转的角度,同 RandomRotation 中的 degrees 设置
translate=(0.5, 0.5), # 平移区间设置,传入 (a, b) ,a 设置宽,b 设置高,
# 图像在宽的维度平移的区间为 -img.width * a < dx < img_width * a
scale=None, # 缩放比例
shear=None, # 设置错切角度,有水平错切和垂直错切,若为 a,则仅在 x 轴错切,错切角度在 (-a, a);
# 若为 (a, b),则 a 设置 x 轴角度,b 设置 y 轴角度;
# 若为 (a, b, c, d),则 x 轴角度为 (a, b),y 轴角度为 (c, d)
resample=False, # PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC
fillcolor=0 # 设置填充颜色
)
# RandomAffine 中并没有翻转相关的设置
random_affine(image).show()
应用错切
from PIL import Image
from torchvision.transforms import RandomAffine
image = Image.open('Lenna.jpg')
random_affine = RandomAffine(
degrees=0,
translate=None,
scale=None,
shear=45, # 设置错切角度,有水平错切和垂直错切,若为 a,则仅在 x 轴错切,错切角度在 (-a, a);
# 若为 (a, b),则 a 设置 x 轴角度,b 设置 y 轴角度;
# 若为 (a, b, c, d),则 x 轴角度为 (a, b),y 轴角度为 (c, d)
resample=False,
fillcolor=0
)
random_affine(image).show()
LinearTransformation
类,对图像进行线性变换,可对图像数据进行白化处理ToTensor
类,将 PIL.Image 转换为 torch.Tensor 类型并归一化到 [ 0.0 , 1.0 ] [0.0, 1.0] [0.0,1.0] 之间from PIL import Image
from torchvision.transforms import ToTensor
image_pil = Image.open('Lenna.jpg')
to_tensor = ToTensor()
image_tensor = to_tensor(image_pil)
print(image_tensor)
Normalize
类,对 torch.Tensor 类型图像数据进行逐个通道的标准化(实际上调用了 torch.nn.functional.normalize
函数)from PIL import Image
from torchvision.transforms import ToTensor, Normalize
image = Image.open('Lenna.jpg')
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
normalize = Normalize(
mean=norm_mean, # 各 channel 的均值
std=norm_std, # 各 channel 的标准差
inplace=False
)
image = normalize(ToTensor()(image))
print(image.mean(), image.std())
Lambda
类,类似于 Python 中的 lambda
匿名函数
# FiveCrop之后,将得到的图像拼接成 N x C x H x W 的格式
from PIL import Image
import torch
from torchvision.transforms import FiveCrop, ToTensor, Lambda
image = Image.open('Lenna.jpg')
five_crop = FiveCrop(size=(224, 224))
cropped_images = five_crop(image) # list
vision_lambda = Lambda(
lambda crops: torch.stack([(ToTensor()(crop)) for crop in crops])
)
print(vision_lambda(cropped_images).size())
# torch.Size([5, 3, 224, 224])
Compose
类,组合一系列的 transform 操作from PIL import Image
from torchvision.transforms import CenterCrop, Resize, Compose
image = Image.open('Lenna.jpg')
compose = Compose([
CenterCrop(224),
Resize(128)
])
compose(image).show()
RandomChoice
类,从一组 transforms 中随机选择一个from PIL import Image
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomChoice
image = Image.open('Lenna.jpg')
# 从左右翻转和上下翻转中随机选择一个
random_choice = RandomChoice([
RandomHorizontalFlip(p=1.),
RandomVerticalFlip(p=1.)
])
random_choice(image).show()
RandomApply
类,按一定概率执行一组 transformsfrom PIL import Image
from torchvision.transforms import CenterCrop, RandomHorizontalFlip, RandomApply
image = Image.open('Lenna.jpg')
random_apply = RandomApply([
CenterCrop(224),
RandomHorizontalFlip(p=1.)
], p=.5) # 0.5 的概率执行该系列 transforms
random_apply(image).show()
RandomOrder
类,打乱一系列 transforms 的执行顺序from PIL import Image
from torchvision.transforms import RandomRotation, Pad, RandomOrder
image = Image.open('Lenna.jpg')
random_order = RandomOrder([
RandomRotation(45),
Pad(32)
]) # 有可能先旋转,亦有可能先填充
random_order(image).show()
通过重写类对象的 __call__
魔法方法实现
以为图像添加噪声为例,使用 skimage.util.random_noise
函数
import random
from PIL import Image
import skimage.io
import skimage.util
import numpy as np
class RandomNoise:
"""按概率为图像添加噪声"""
def __init__(self, modes, p=0.5):
"""
Params:
modes: list or tuple of strings
添加噪声的类型,如 'gaussian', 'localvar', 'poisson', 'salt', 'pepper', 's&p', 'speckle'
p: float
执行该操作的概率
"""
self.modes = modes
self.p = p
def __call__(self, image):
"""
Param:
image, PIL.Image
image of PIL data type
Returns:
a PIL Image
"""
if random.uniform(0, 1) < self.p: # 按概率执行该操作
img_arr = np.array(image)
for mode in self.modes:
img_arr = skimage.util.random_noise(img_arr, mode)
img_pil = Image.fromarray((img_arr * 255).astype(np.uint8))
return img_pil
else:
return image
image = Image.open('./Lenna.jpg')
# 添加一波噪音
modes = ['gaussian', 'pepper', 'speckle']
random_noise = RandomNoise(modes, p=1.)
noisy_image = random_noise(image)
noisy_image.show()
# 添加很多噪音
modes = ['gaussian', 'pepper', 'speckle'] * 10
random_noise = RandomNoise(modes, p=1.)
noisy_image = random_noise(image)
noisy_image.show()