PyTorch学习笔记(二)图像数据增强

Environment

  • OS: macOS Mojave
  • Python version: 3.7
  • PyTorch version: 1.4.0
  • IDE: PyCharm

文章目录

  • 0. 写在前面
  • 1. 基本变换类
    • 1.1 填充
    • 1.2 擦除
    • 1.3 缩放
    • 1.4 裁剪
    • 1.5 旋转
    • 1.6 翻转
    • 1.7 颜色
    • 1.8 仿射变换和线性变换
    • 1.9 归一化和标准化
    • 1.10. Lambda
  • 2. 组合变换类
  • 3. 自定义图像数据增强


0. 写在前面

本文记录了使用 PyTorch 实现图像数据预处理的方法,包括数据增强和标准化。主要的工具为 torchvision.transform 模块中的各种类,要求传入的图片为 PIL.Image 类型的图片。

对图像数据进行变换,可以增加训练样本量。数据变换的方式很多,实际应用时,原则是使手头的训练数据与真实数据接近。例如,如果训练数据和真实数据的物体颜色有差别,那么可以采用色彩变换;如果训练数据和真实数据的物体位置常常差别,那么可以平移变换…

这里以经典的 Lenna 图片为例,选择一块大小为 299 x 299 的部分。

from PIL import Image

image = Image.open('Lenna.jpg')
print(type(image))
# 

print(np.array(img).shape)
# (299, 299, 3)

image.show()

PyTorch学习笔记(二)图像数据增强_第1张图片

1. 基本变换类

1.1 填充

  • Pad 类,对图片边缘进行填充

只有当 padding_mode='constant' 时,参数 fill 才接收用于填充的值

from PIL import Image
from torchvision.transforms import Pad

image = Image.open('Lenna.jpg')

pad = Pad(
    padding=(10, 100),  # 传入整数 a 时,上下左右分别填充 a 个像素;
                        # 传入元组 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素;
                        # 传入元组 (a, b, c, d) 时,左、上、右、下分别填充 a、b、c、d 个像素
    fill=(255, 64, 128),  # 每个通道用于填充的值,默认全为 0,黑色填充
    padding_mode='constant'  # constant, edge, reflect 或 symmetric
)
pad(image).show()

PyTorch学习笔记(二)图像数据增强_第2张图片

padding_mode='edge' 时,使用图像边界的像素值进行填充

PyTorch学习笔记(二)图像数据增强_第3张图片

padding_mode='reflect' 时,使用镜像进行填充,边界的像素值在填充中出现

PyTorch学习笔记(二)图像数据增强_第4张图片

padding_mode='symmetric' 时,使用镜像填充,边界的像素值将作为填充的第一个值

PyTorch学习笔记(二)图像数据增强_第5张图片

1.2 擦除

  • RandomErasing 类,随机对图像区域进行遮挡
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import RandomErasing

img_pil = Image.open('Lenna.jpg')
img_array = np.array(img_pil).transpose(2, 0, 1)  # transpose (H, W, C) -> (C, H, W)
img_tensor = torch.from_numpy(img_array)

random_erasing = RandomErasing(
    p=1.0,  # 概率值,执行该操作的概率,默认为 0.5
    scale=(0.02, 0.33),  # 按均匀分布概率抽样,遮挡区域的面积 = image * scale
    ratio=(0.3, 3.3),  # 遮挡区域的宽高比,按均匀分布概率抽样
    value='123',  # 遮挡区域的像素值,(R, G, B) or (Gray);传入字符串表示用随机彩色像素填充遮挡区域
    inplace=False
)
# 注意,随机遮挡是对 (c, h, w) 形状的 tensor 进行操作,一般在 ToTensor 之后进行

erased_img_tensor = random_erasing(img_tensor)
erased_img_array = erased_img_tensor.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
erased_img_pil = Image.fromarray(erased_img_array)
erased_img_pil.show()

PyTorch学习笔记(二)图像数据增强_第6张图片

1.3 缩放

  • Resize 类,将图像缩放到设定的尺寸
from torchvision.transforms import Resize
from PIL import Image

image = Image.open('Lenna.jpg')

resize = Resize(
    size=(299, 224),  # (height, width)
    interpolation=2  # 插值方法,一般保持默认就好
)
resize(image).show()

PyTorch学习笔记(二)图像数据增强_第7张图片

注意 Resize的小坑

1.4 裁剪

  • CenterCrop 类,将从图像中心裁剪出需要的尺寸
from torchvision.transforms import CenterCrop
from PIL import Image

image = Image.open('Lenna.jpg')

center_crop = CenterCrop(size=(224, 224))
center_crop(image).show()

PyTorch学习笔记(二)图像数据增强_第8张图片

  • RandomCrop 类,随机在图片某个位置裁剪出需要的尺寸
from torchvision.transforms import RandomCrop
from PIL import Image

image = Image.open('Lenna.jpg')

random_crop = RandomCrop(
    size=(224, 224),  # 裁剪后图片的尺寸
    padding=(75, 75),  # (左右填充多少,上下填充多少)
    pad_if_needed=False,  # 当 size 大于原始图片的尺寸时,必须将该参数设置为 True,否则会报错
    fill=(255, 0, 0),  # 同 Pad 对象的参数
    padding_mode='constant'  # 同 Pad 对象的参数
)
random_crop(image).show()

PyTorch学习笔记(二)图像数据增强_第9张图片

  • FiveCrop 类,在图像的左上、右上、左下、右下、中心裁出指定尺寸的5张图片
from PIL import Image
import torch
from torchvision.transforms import FiveCrop, Lambda, ToTensor

image = Image.open('Lenna.jpg')

five_crop = FiveCrop(size=(64, 64))
cropped_images = five_crop(image)
cropped_images[4].show()

# FiveCrop之后,往往需要将得到的图像重新拼接为 (B, C, H, W) 的格式
vision_lambda = Lambda(
    lambda crops: torch.stack([(ToTensor()(crop)) for crop in crops])
)

print(vision_lambda(cropped_images).size())
# torch.Size([5, 3, 64, 64])

在这里插入图片描述

  • TenCrop 类,在FiveCrop五张图片的基础上,进行水平(或加上垂直翻转)获得额外的5张图片,共计10张图片
from PIL import Image
from torchvision.transforms import TenCrop

image = Image.open('Lenna.jpg')

ten_crop = TenCrop(
    size=(64, 64),
    vertical_flip=True  # 是否加上垂直翻转(水平翻转是默认需要的)
)

cropped_images = ten_crop(image)
cropped_images[9].show()

在这里插入图片描述


  • RandomResizedCrop 类,先对图片按随机面积比例缩放,再调整到随机的宽高比,最后随机在某个位置按需要的尺寸裁剪图片
from PIL import Image
from torchvision.transforms import RandomResizedCrop

image = Image.open('Lenna.jpg')

random_resized_crop = RandomResizedCrop(
    size=(224, 224),  # 裁剪后图片的尺寸
    scale=(0.08, 1.0),  # 面积比的范围,将在此范围内按均匀分布采样
    ratio=(3/4, 4/3),  # 宽高比的范围,将在此范围内按均匀分布采样,50%概率交换
    interpolation=2  # 插值方法,PIL.Image.NEAREST、PIL.Image.BILINEAR、PIL.Image.BICUBIC
)
random_resized_crop(image).show()

PyTorch学习笔记(二)图像数据增强_第10张图片

1.5 旋转

  • RandomRotation 类,随机角度旋转图片
from PIL import Image
from torchvision.transforms import RandomRotation

image = Image.open('Lenna.jpg')

random_rotation = RandomRotation(
    degrees=(-45, 45),  # 传入为整数 d 时,旋转角度在 (-d, d);传入整数元组 (d1, d2)时,角度在 (d1, d2)
    resample=False,  # 旋转之后会要重采样,一般设置为默认就好
    expand=True,  # 传入 True 时,会 padding 图片,保证旋转后四个角的信息不丢失
    center=(0, 0)  # 以什么点为轴旋转,默认为中心,可传入 tuple 更改为任意坐标
)
# 注意,若对不止一张 image 进行处理,在传入 expand=True 时,每张图片的 size 应该要相同,否则报错
# expand=True 是针对以中心为轴旋转所需要 padding 的量的,因此当center=(0, 0),信息依旧会丢失

random_rotation(image).show()

PyTorch学习笔记(二)图像数据增强_第11张图片

1.6 翻转

包括水平翻转和垂直翻转

  • RandomHorizontalFlip 类,按概率水平(即左右)翻转图片
from PIL import Image
from torchvision.transforms import RandomHorizontalFlip

image = Image.open('Lenna.jpg')

random_horizontal_flip = RandomHorizontalFlip(p=1.0)  # p 为翻转的概率,默认为 0.5
random_horizontal_flip(image).show()

PyTorch学习笔记(二)图像数据增强_第12张图片

  • RandomVerticalFlip 类,按概率垂直(即上下)翻转图片
from PIL import Image
from torchvision.transforms import RandomVerticalFlip

image = Image.open('Lenna.jpg')

random_vertical_flip = RandomVerticalFlip(p=1.0) # p 为翻转的概率,默认为 0.5
random_vertical_flip(image).show()

PyTorch学习笔记(二)图像数据增强_第13张图片

1.7 颜色

  • RandomGrayscale 类,按概率将图像转换为灰度图
from PIL import Image
from torchvision.transforms import RandomGrayscale

image = Image.open('Lenna.jpg')

random_grayscale = RandomGrayscale(p=1.0)
random_grayscale(image).show()

PyTorch学习笔记(二)图像数据增强_第14张图片

  • Grayscale 类,将图片转换为灰度图,即 RandomGrayscale(p=0.1) 的情况
from PIL import Image
from torchvision.transforms import Grayscale

image = Image.open('Lenna.jpg')

grayscale = Grayscale(
    num_output_channels=3  # num_output_channels should be either 1 or 3
)
grayscale(image).show()

PyTorch学习笔记(二)图像数据增强_第15张图片

  • ColorJitter 类,调整亮度、对比度、饱和度、色相

调整亮度

from PIL import Image
from torchvision.transforms import ColorJitter

image = Image.open('Lenna.jpg')

color_jitter = ColorJitter(
    brightness=(1.5, 1.5),  # 亮度,当传入 a 时,从 [max(0, 1-a), 1+a] 中按均匀分布抽样
                            # 当传入 (a, b) 时,从 [a, b] 中按均匀分布抽样
    contrast=0,  # 对比度,传入格式同 brightness
    saturation=0,  # 饱和度,传入格式同 brightness
    hue=0  # 色相,当传入 a 时,从 [-a, a] 中按均匀分布抽样,注意,0 <= a <= 0.5
           # 当传入 (a, b) 时,从 [a, b] 中按均匀分布抽样,注意,-0.5  <= a <= b <= 0.5
)
color_jitter(image).show()

PyTorch学习笔记(二)图像数据增强_第16张图片

调整对比度

from PIL import Image
from torchvision.transforms import ColorJitter

image = Image.open('Lenna.jpg')

color_jitter = ColorJitter(
    brightness=0,
    contrast=(1.5, 1.5),
    saturation=0,
    hue=0
)
color_jitter(image).show()

PyTorch学习笔记(二)图像数据增强_第17张图片

调整饱和度

from PIL import Image
from torchvision.transforms import ColorJitter

image = Image.open('Lenna.jpg')

color_jitter = ColorJitter(
    brightness=0,
    contrast=0,
    saturation=(1.5, 1.5),
    hue=0
)
color_jitter(image).show()

PyTorch学习笔记(二)图像数据增强_第18张图片

调整色相

from PIL import Image
from torchvision.transforms import ColorJitter

image = Image.open('Lenna.jpg')

color_jitter = ColorJitter(
    brightness=0,
    contrast=0,
    saturation=0,
    hue=0.5
)
color_jitter(image).show()

PyTorch学习笔记(二)图像数据增强_第19张图片

1.8 仿射变换和线性变换

  • RandomAffine 类,对图像进行放射变换

放射变换是二维的线性变换,由旋转、平移、缩放、错切、翻转五种基本变换构成

from PIL import Image
from torchvision.transforms import RandomAffine

image = Image.open('Lenna.jpg')
random_affine = RandomAffine(
    degrees=45,  # 旋转的角度,同 RandomRotation 中的 degrees 设置
    translate=(0.5, 0.5),  # 平移区间设置,传入 (a, b) ,a 设置宽,b 设置高,
                           # 图像在宽的维度平移的区间为 -img.width * a < dx < img_width * a
    scale=None,  # 缩放比例
    shear=None,  # 设置错切角度,有水平错切和垂直错切,若为 a,则仅在 x 轴错切,错切角度在 (-a, a);
                 # 若为 (a, b),则 a 设置 x 轴角度,b 设置 y 轴角度;
                 # 若为 (a, b, c, d),则 x 轴角度为 (a, b),y 轴角度为 (c, d)
    resample=False,  # PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC
    fillcolor=0  # 设置填充颜色
)
# RandomAffine 中并没有翻转相关的设置

random_affine(image).show()

PyTorch学习笔记(二)图像数据增强_第20张图片

应用错切

from PIL import Image
from torchvision.transforms import RandomAffine

image = Image.open('Lenna.jpg')
random_affine = RandomAffine(
    degrees=0,
    translate=None,
    scale=None,
    shear=45,  # 设置错切角度,有水平错切和垂直错切,若为 a,则仅在 x 轴错切,错切角度在 (-a, a);
               # 若为 (a, b),则 a 设置 x 轴角度,b 设置 y 轴角度;
               # 若为 (a, b, c, d),则 x 轴角度为 (a, b),y 轴角度为 (c, d)
    resample=False,
    fillcolor=0
)

random_affine(image).show()

PyTorch学习笔记(二)图像数据增强_第21张图片

  • LinearTransformation 类,对图像进行线性变换,可对图像数据进行白化处理

1.9 归一化和标准化

  • ToTensor 类,将 PIL.Image 转换为 torch.Tensor 类型并归一化到 [ 0.0 , 1.0 ] [0.0, 1.0] [0.0,1.0] 之间
from PIL import Image
from torchvision.transforms import ToTensor

image_pil = Image.open('Lenna.jpg')
to_tensor = ToTensor()

image_tensor = to_tensor(image_pil)
print(image_tensor)


  • Normalize 类,对 torch.Tensor 类型图像数据进行逐个通道的标准化(实际上调用了 torch.nn.functional.normalize 函数)
from PIL import Image
from torchvision.transforms import ToTensor, Normalize

image = Image.open('Lenna.jpg')

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

normalize = Normalize(
    mean=norm_mean,  # 各 channel 的均值
    std=norm_std,  # 各 channel 的标准差
    inplace=False
)

image = normalize(ToTensor()(image))
print(image.mean(), image.std())


1.10. Lambda

Lambda 类,类似于 Python 中的 lambda 匿名函数

# FiveCrop之后,将得到的图像拼接成 N x C x H x W 的格式
from PIL import Image
import torch
from torchvision.transforms import FiveCrop, ToTensor, Lambda

image = Image.open('Lenna.jpg')

five_crop = FiveCrop(size=(224, 224))
cropped_images = five_crop(image)  # list
vision_lambda = Lambda(
    lambda crops: torch.stack([(ToTensor()(crop)) for crop in crops])
)

print(vision_lambda(cropped_images).size())
# torch.Size([5, 3, 224, 224])

2. 组合变换类

  • Compose 类,组合一系列的 transform 操作
from PIL import Image
from torchvision.transforms import CenterCrop, Resize, Compose

image = Image.open('Lenna.jpg')

compose = Compose([
    CenterCrop(224),
    Resize(128)
])

compose(image).show()

PyTorch学习笔记(二)图像数据增强_第22张图片

  • RandomChoice 类,从一组 transforms 中随机选择一个
from PIL import Image
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomChoice

image = Image.open('Lenna.jpg')

# 从左右翻转和上下翻转中随机选择一个
random_choice = RandomChoice([
    RandomHorizontalFlip(p=1.),
    RandomVerticalFlip(p=1.)
])

random_choice(image).show()

PyTorch学习笔记(二)图像数据增强_第23张图片

  • RandomApply 类,按一定概率执行一组 transforms
from PIL import Image
from torchvision.transforms import CenterCrop, RandomHorizontalFlip, RandomApply

image = Image.open('Lenna.jpg')

random_apply = RandomApply([
    CenterCrop(224),
    RandomHorizontalFlip(p=1.)
], p=.5)  # 0.5 的概率执行该系列 transforms

random_apply(image).show()

PyTorch学习笔记(二)图像数据增强_第24张图片

  • RandomOrder 类,打乱一系列 transforms 的执行顺序
from PIL import Image
from torchvision.transforms import RandomRotation, Pad, RandomOrder

image = Image.open('Lenna.jpg')

random_order = RandomOrder([
    RandomRotation(45),
    Pad(32)
])  # 有可能先旋转,亦有可能先填充

random_order(image).show()

PyTorch学习笔记(二)图像数据增强_第25张图片

3. 自定义图像数据增强

通过重写类对象的 __call__ 魔法方法实现

以为图像添加噪声为例,使用 skimage.util.random_noise 函数

import random
from PIL import Image
import skimage.io
import skimage.util
import numpy as np


class RandomNoise:
    """按概率为图像添加噪声"""
    def __init__(self, modes, p=0.5):
        """
        Params:
        modes: list or tuple of strings
            添加噪声的类型,如 'gaussian', 'localvar', 'poisson', 'salt', 'pepper', 's&p', 'speckle'
        p: float
            执行该操作的概率
        """
        self.modes = modes
        self.p = p

    def __call__(self, image):
        """
        Param:
        image, PIL.Image
            image of PIL data type

        Returns:
        a PIL Image
        """
        if random.uniform(0, 1) < self.p:  # 按概率执行该操作
            img_arr = np.array(image)
            for mode in self.modes:
                img_arr = skimage.util.random_noise(img_arr, mode)

            img_pil = Image.fromarray((img_arr * 255).astype(np.uint8))

            return img_pil
        else:
            return image


image = Image.open('./Lenna.jpg')

# 添加一波噪音
modes = ['gaussian', 'pepper', 'speckle']
random_noise = RandomNoise(modes, p=1.)
noisy_image = random_noise(image)
noisy_image.show()

# 添加很多噪音
modes = ['gaussian', 'pepper', 'speckle'] * 10
random_noise = RandomNoise(modes, p=1.)
noisy_image = random_noise(image)
noisy_image.show()

PyTorch学习笔记(二)图像数据增强_第26张图片

PyTorch学习笔记(二)图像数据增强_第27张图片

你可能感兴趣的:(PyTorch学习笔记)