pytorch —— transforms图像增强(一)

1、数据增强(data augmentation)

数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。

在中学阶段就已经接触过数据增强的概念,看一个例子,高中的五年高考三年模拟,假设学生是一个模型,五年高考真题是一个训练集,当年高考题是一个验证集,用来验证学习模型的学习能力和效果。对于这个例子怎么做数据增强呢?就是对历年的高考题的知识点进行分析和提炼,设计出三年的模拟试题用来给学生进行学习。当做了很多模拟试题的时候,学生的学习能力自然得到了提高,从而在高考的时候分数得到提高,这就是数据增强的一个概念。

如果模拟题的某一些试题恰恰出现在当年高考题当中,这就可以直接提高学生的成绩,即使没有完完整整的题出现在高考中,只要有类似的题型出现在高考题当中,这样也可以提高学生的成绩,这就是数据增强。

看一下图片中的数据增强是怎么样的。下图是一张原始图片,对这张图片进行一系列的操作变换得到64张增强样本。64张图片中的第一张图片是对原始图片进行旋转,第二张图片是对原始图片进行颜色变换,第三张图片是进行镜像操作。对图片进行一系列操作可以得到大量增强样本提供给模型进行训练,让模型见过更多的样本,从而提升模型的泛化能力,使得模型在验证集上的表现更好。下面开始学习具体的数据增强方法。

2、transforms——裁剪(crop)

2.1 transforms.CenterCrop

  • 功能:从图像中心裁剪图片;
  • size:所需裁剪图片尺寸;

看一个例子,如下左图为一张 224 ∗ 224 224*224 224224的图片,对图片进行 196 ∗ 196 196*196 196196的centercrop,图片从中心点开始计算,左右宽196,上下高196的一个裁剪区域,就得到到右边的图片。下面从代码中学习centercrop。

transforms方法的演示还是采用人民币二分类训练的主代码,这里我们只关心数据模块以及训练模块中取出数据那一部分,看一下代码的结构。

import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tools.my_dataset import RMBDataset
from PIL import Image
from matplotlib import pyplot as plt


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # 设置随机种子

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}


def transform_invert(img_, transform_train):
    """
    将data 进行反transfrom操作
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    img_ = np.array(img_) * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_


# ============================ step 1/5 数据 ============================
split_dir = os.path.join("E:/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 1 CenterCrop
    # transforms.CenterCrop(512),     # 512

    # 2 RandomCrop
    # transforms.RandomCrop(224, padding=16),
    # transforms.RandomCrop(224, padding=(16, 64)),
    # transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),
    # transforms.RandomCrop(512, pad_if_needed=True),   # pad_if_needed=True
    # transforms.RandomCrop(224, padding=64, padding_mode='edge'),
    # transforms.RandomCrop(224, padding=64, padding_mode='reflect'),
    # transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),

    # 3 RandomResizedCrop
    # transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),

    # 4 FiveCrop
    # transforms.FiveCrop(112),
    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 5 TenCrop
    # transforms.TenCrop(112, vertical_flip=False),
    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 1 Horizontal Flip
    # transforms.RandomHorizontalFlip(p=1),

    # 2 Vertical Flip
    # transforms.RandomVerticalFlip(p=0.5),

    # 3 RandomRotation
    # transforms.RandomRotation(90),
    # transforms.RandomRotation((90), expand=True),
    # transforms.RandomRotation(30, center=(0, 0)),
    # transforms.RandomRotation(30, center=(0, 0), expand=True),   # expand only for center rotation

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 训练 ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

        # bs, ncrops, c, h, w = inputs.shape
        # for n in range(ncrops):
        #     img_tensor = inputs[0, n, ...]  # C H W
        #     img = transform_invert(img_tensor, train_transform)
        #     plt.imshow(img)
        #     plt.show()
        #     plt.pause(1)

上面代码第五部分“训练”中有一个函数transform_invert(),这个函数是用来对transform进行逆操作,使得我们可以观察到模型输入的数据是长什么样的。因为数据经过transfrom,转换为张量的形式,可能是一些浮点的数据,没有办法将这些数据进行可视化,因此需要一个transform_invert()函数,对transform进行逆操作,将张量的数据变换成img,这样就可以进行可视化。

现在看一下transform_invert()函数中有什么操作。按ctrl键,鼠标左键点击该函数名就可以跳转到函数实现位置。

def transform_invert(img_, transform_train):
    """
    将data 进行反transfrom操作
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """

可以看到,这个函数接受一个img_和transform_train,返回PIL image,也就是可以直接plot将其格式化。

if 'Normalize' in str(transform_train):
    norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
    mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
    std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
    img_.mul_(std[:, None, None]).add_(mean[:, None, None])

在这个函数中,对normalize进行反操作,normalize是减去均值除于方差,因此反操作就是乘于方差再加上均值。

img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
img_ = np.array(img_) * 255

之后需要对通道进行变换,采用transpose,将通道的 C ∗ H ∗ W C*H*W CHW格式转换为 H ∗ W ∗ C H*W*C HWC,也就是将channel放到最后面,然后将0-1尺度上的数据转换到0-255。

if img_.shape[2] == 3:
    img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
elif img_.shape[2] == 1:
    img_ = Image.fromarray(img_.astype('uint8').squeeze())
else:
    raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

return img_

最后是将np_array的形式转换成PIL image,这里的代码会针对channel是3通道还是1通道,分别转换成“RGB”彩色图像和灰度图像,最后返回图像就可以对图像进行plot,对图像进行可视化。

下面看代码中的transform.CenterCrop()函数,经过裁剪之后图像会变成什么样。首先在第五部分训练中设置断点,观察input是什么样的数据形式,如下图所示:
pytorch —— transforms图像增强(一)_第1张图片
在transforms中,为了统一图片的尺寸,一开始会执行transforms.Resize((224,224)),把图片统一地缩放到 224 ∗ 224 224*224 224224的尺寸大小,然后执行transforms.CenterCrop(196)操作,裁剪出来一个196大小的图片。
pytorch —— transforms图像增强(一)_第2张图片
对程序进行debug,代码停在之前打断点的位置,如下图所示。观察一下代码中data的形式。
pytorch —— transforms图像增强(一)_第3张图片
将断点取消,点击step over功能键,到达代码img_tensor = inputs[0, …] 位置,点击console就会打开一个命令窗,如下图所示,这个命令窗的环境与当前代码调试的环境是完全一致的,可以在这个命令窗对变量进行更改或者查看。
pytorch —— transforms图像增强(一)_第4张图片
现在查看inputs的形状,如下图,inputs的形状是一个[1,3,196,196]的形式。第一个维度是size,因为在代码开始设置了BATCH_SIZE=1,所以inputs中的第一个维度为1,代表BITCH_SIZE;第二个维度是channel,也就是通道,由于是rgb图像,通道的长度为3;第三维和第四位分别是图像的高和宽。
pytorch —— transforms图像增强(一)_第5张图片
由于可视化图片是一个三通道的三维张量,所以需要对inputs进行操作,进行索引,索引出第一块区域,也就是接下来的一句代码“img_tensor = inputs[0, …]”,这段代码的意思是取四维张量中的第一个三维张量,这样就把四维张量变为三维张量了,其顺序为 C ∗ H ∗ W C*H*W CHW。将得到的三维张量img输入到函数transform_invert()函数中进行逆变换,就返回可以可视化的img,然后将img进行plt操作,得到裁剪图片如下所示:
pytorch —— transforms图像增强(一)_第6张图片
这个图片就是 196 ∗ 196 196*196 196196尺寸大小的图片,由于代码中transforms.CenterCrop设定的size是196,小于transforms.Resize((224,224))的尺寸。假如设定的size为大于(224,224)的,那么代码是否能够正确执行?下面观察一下,把代码中的196改为512,代码如下所示:
pytorch —— transforms图像增强(一)_第7张图片
修改代码之后,执行debug操作,代码并没有报错,输出图片为(512,512)大小的图片,对超出224的区域会自动填充为零的像素,也就是全黑的区域,如下所示:
pytorch —— transforms图像增强(一)_第8张图片

2.2 transforms.RandomCrop

  • 功能:从图片中随机裁剪出尺寸为size的图片(位置随机裁剪);
  • size:所需裁剪图片尺寸;
  • padding:设置填充大小(有三种模式);
    • 当为a时,上下左右均填充a个像素;
    • 当为(a,b)时,上下填充b个像素,左右填充a个像素;
    • 当为(a,b,c,d)时,左,上,右,下分别填充a,b,c,d;
  • pad_if_need:若图像小于设定size,则填充 ;
  • padding_mode:填充模式,有4种模式;
    • 1、constant:像素值由fill设定;
    • 2、edge:像素值由图像边缘像素决定;
    • 3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4]->[3,2,1,2,3,4,3,2,](由于最后一个像素不镜像,所以跳过1和4,分别从2和3开始进行镜像填充);
    • 4、symmetric:镜像填充,最后一个像素镜像, eg:[1,2,3,4]->[2,1,1,2,3,4,4,3](最后一个像素镜像,所以不会跳过1和4,分别从1和4开始进行镜像填充);
  • fill:constant时,设置填充的像素值;
transforms.RandomCrop(size,
                      padding=None,
                      pad_if_needed=False,
                      fill=0,
                      padding_mode='constant')

下面通过代码观察RandomCrop是怎样对图像进行裁剪的。和前面一样,对图像进行统一的尺寸变换,缩放为(224,224)。

第一步,对上下左右均进行16像素的padding,图片如下所示,裁剪出来的图片左边和上边都有一块黑色的填充区域。为什么右边和下边没有呢?这是因为经过填充之后的图片的尺寸应该是224+16+16,比224大32个像素。在这个大的图片上进行(224,224)的随机选取,由于图像选取左上角的这一部分,所以右边和下边是没有黑色的填充区域的。
pytorch —— transforms图像增强(一)_第9张图片
padding的第二种模型,分别对左右、上下设置不同的填充,其图片如下,可以看到左右的填充区域相比于上下是更小的。
pytorch —— transforms图像增强(一)_第10张图片
可以看到填充的区域都是黑色,默认填充的像素是0,如果想设置的填充区域是红色,或者是其它的彩色图,就可以对fill这个参数进行设置,代码中对fill设置一个长度为3的tuple,3个元素分别对应的是rgb通道,设定第一个红色通道为255,其它两个通道为0,可以看一下其padding出来的颜色是红色的,如下所示。当然也可以设定其它自定义的颜色,这就是fill参数的使用。
pytorch —— transforms图像增强(一)_第11张图片
接下来看一下pad_if_needed参数,当size大于图片尺寸的时候,pad_if_needed参数必须打开,否则会报错。可以看到在超出图片的范围全部填充上像素值为0的像素点,也就是黑色的。
pytorch —— transforms图像增强(一)_第12张图片
观察参数padding_mode的几种模式,padding_mode默认采用constant模式,在采用constant的时候,采用fill参数去设置填充的像素点的像素值。接下里看padding_mode的第二种模式,padding_mode=‘edge’,这种模式是采用图片的边界值对图片进行填充,设置padding的值大一点,padding=64,以便于更好地观察填充的效果,其图片如下所示:
pytorch —— transforms图像增强(一)_第13张图片
从上面这个图片可以看到,填充的区域是左边和上边,左边的每一个像素值,都是用图片的最边缘的像素点进行填充,上边也是。可以看一下下一张图片的效果,打开软件的debug功能区,使用run to sursor功能将代码运行到断点位置,代码如下图所示:
pytorch —— transforms图像增强(一)_第14张图片
点击功能键三次,得到三张不同的暑促图片,得到的输出图片如下所示:
pytorch —— transforms图像增强(一)_第15张图片
pytorch —— transforms图像增强(一)_第16张图片
pytorch —— transforms图像增强(一)_第17张图片
从图片中可以看到,填充区域都是采用边缘像素点的值进行填充的,这是padding_mode='edge’模型的作用。

接着看一下镜像模式,镜像模式像一个印钞机,其输出图片如下所示:
pytorch —— transforms图像增强(一)_第18张图片
从图片可以看出,padding_mode='reflect’就是对图片进行镜像操作,填充区域是对原始图片的边缘区域进行镜像。padding_mode='symmetric’和padding_mode='reflect’功能相差不多,只是相差一个像素值点。

把代码修改一下,RandomCrop()函数的参数size=1024,padding=1024,观察更大区域上的镜像。
pytorch —— transforms图像增强(一)_第19张图片
以上就是RandomCrop()函数的使用简介。

2.3 transforms.RandomResizedCrop

  • 功能:随机大小、长宽比裁剪图片;
  • size:所需裁剪图片尺寸;
  • scale:随机裁剪面积比例,默认(0.08,1)
  • ratio:随机长宽比,默认(3/4,4/3)
  • interpolation:插值方法(裁剪出来的图片尺寸可能小于size,所以需要进行插值处理,插值方法有三种)
    • PIL.Image.NEAREST
    • PIL.Image.BILINEAR
    • PIL.Image.BICUBIC
RandomResizedCrop(size,
                  scale=(0.08,1.0),
                  ratio=(3/4,4/3),
                  interpolation)

通过代码理解RandomResizedCrop()函数的操作,首先设置代码为

transforms.RandomResizedCrop(size=224, scale=(0.08, 0.1))

pytorch —— transforms图像增强(一)_第20张图片
输出结果如上图所示,所得图片比原始图片小得多,这个比例是在(0.08,1)之间随机选取得到的,选取得到一个比例之后,再根据ratio长宽比设定图像的长和宽,裁剪得到一个图片。裁剪得到图片之后,再resize到设定的size大小尺寸。

修改代码如下所示,意思是采取一半的面积,然后再进行长宽比的缩放,得到图片如下所示。

 transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),

pytorch —— transforms图像增强(一)_第21张图片
这个图片保持了原始图片的50%的面积,可以根据需求设置scale参数值。

2.4 FiveCrop

  • 功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片;
transforms.FiveCrop(size)

2.5 TenCrop

  • 功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,在这五张图片上进行水平或者垂直镜像获得10张图片;
  • size:所需裁剪图片尺寸大小;
  • vertical_flip:是否垂直翻转;
transforms.TenCrop(size,
                   vertical_flip=False)

下面通过代码学习这两个函数。看一下代码

transforms.FiveCrop(112),

由于FiveCrop()裁剪出来的是五张图片,返回的是一个tuple(元组),当尝试运行代码时,会报错,报错信息如下所示:
pytorch —— transforms图像增强(一)_第22张图片
报错为:pic should be PIL Image or ndarray. Got 。意思是pic这个参数应该是一个PIL Image或者是ndarray的,但是却得到了一个tuple。所以直接使用是不行的,需要对FiveCrop返回的tuple进行一定的操作,将tuple变换为张量的形式或者是PIL Image的形式。这里使用到Lambda方法,Lambda是匿名函数,可以对FiveCrop()的输出进行一系列的变换,使其输出可以变换为代码可以执行的数据格式。看一下lambda匿名函数的功能:

transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

代码中冒号之前的是函数的输入,冒号之后的整个语句是函数的返回值。由于输入是一个tuple格式的数据,需要将tuple中每一张图片,将其拼接为张量的形式,所以代码中采用了torch.stack()的形式,在讲常量的操作的时候,stack是对张量在某一维度上进行拼接,这里采用默认维度,也就是第0个维度。stack()函数中传入的是一个list,代码中采用了python的列表解析式,列表生成器。它的功能是对参数crops进行for循环,每一次提取出一个元素crop,每一次对这个元素crops进行一些操作得到列表的元素。

crops是FiveCrop()函数输出的一个tuple,然后对tuple的每一个元素进行for循环,每一次取出一个crop,也就是一张图片,对每一张图片进行一个ToTensor()的操作,将其转换为张量的形式,将其变为list的一个元素。通过不断的循环,把五张图片都转为张量的形式,然后得到一个长度为5的list,把这个list放到stack()当中,stack()就把这个长度为5的list拼接成一个张量。这样,通过lambda(),就把tuple转为张量的形式,这样就可以输入到模型中。

点击运行之后还是会报错,报错如下:
pytorch —— transforms图像增强(一)_第23张图片
由于图片的维度和代码不匹配,不能用原始方法可视化。因为得到的input不再是一个四维的张量,是一个五维的张量。这个五维张量的各个维度分别为batchs,ncrops,c,h,w,通过下面这个新的代码对每个crop进行可视化。

bs, ncrops, c, h, w = inputs.shape
for n in range(ncrops):
    img_tensor = inputs[0, n, ...]  # C H W
    img = transform_invert(img_tensor, train_transform)
    plt.imshow(img)
    plt.show()
    plt.pause(1)

设置断点,调试代码,打开命令输入窗,单击运行,得到一张图片的五维表示,代码要在五维张量中获取每一张图片,每一张图片应该是一个3维的张量,对ncrops进行循环,分别将五张图片进行可视化。看一下img_tensor的形状,通过命令输入窗,可以看到img_shape的形状为(3,112,112),可以直接进行可视化。
pytorch —— transforms图像增强(一)_第24张图片
其输出五张照片如下所示:
pytorch —— transforms图像增强(一)_第25张图片
pytorch —— transforms图像增强(一)_第26张图片
pytorch —— transforms图像增强(一)_第27张图片
pytorch —— transforms图像增强(一)_第28张图片
pytorch —— transforms图像增强(一)_第29张图片
下面看一下TenCrop()函数的使用方法,它是在FiveCrop()函数的基础上进行翻转得到的十张图片。设置vertical_flip=True,也就是进行垂直的翻转

transforms.TenCrop(112, vertical_flip=True),
    transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

3、transforms——翻转和旋转(flip and rotation)

3.1 transforms——Flip

3.1.1 RandomHorizontalFlip(水平)

3.1.2 RandomVerticalFlip(垂直)

  • 功能:依概率水平(左右)或垂直(上下)翻转图片
  • P:翻转概率(即有多大的概率将图片进行翻转)
RandomHorizontalFlip(p=0.5)
RandomVerticalFlip(p=0.5)

3.2 transforms——Rotation

3.3 RandomRotation

  • 功能:随机旋转图片;
  • degrees:旋转角度;
    • 当为a时,在(-a,a)之间选择旋转角度;
    • 当为(a,b)时,在(a,b)之间选择旋转角度;
  • resample:重采样方法;
  • expand:是否扩大图片,以保持原图信息;
  • center:旋转点设置,默认中心旋转;
RandomRotation(degrees,
                resample=False,
                expand=False,
                center=None

当使用expand扩大图片时,因为每张图片旋转的角度不同,最后得到的图片的大小是不一样的,最后拼接的时候可能出现报错的问题,所以在使用expand的时候,需要注意对图片进行缩放,将所有照片缩放到统一的尺寸。

你可能感兴趣的:(pytorch)