torch_vision(一):数据增强和转换模块torchvision.transforms

文章目录

  • torchvision.transforms 学习笔记
    • 1. torchvision介绍
    • 2. TRANSFORMING AND AUGMENTING IMAGES
    • 3. Resize transform
    • 4. CenterCrop transform 中心裁剪
    • 5. RandomRotation 随机旋转
    • 6.RandomAffine 随机仿射变换
    • 7.RandomPerspective 随机透视变换
    • 8.RandomCrop 随机crop固定尺寸
    • 9.RandomResizedCrop 随机crop之后,再 resize到固定尺寸
    • 10.RandomPerspective 随机透视变换
    • 11. 其他常见操作
    • 12. transforms.Compose() 和 torch.nn.Sequential()
    • 13. 函数转换
    • 14. 利用函数定义转换
    • 15. 参考

torchvision.transforms 学习笔记

1. torchvision介绍

The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

torchvision包含了很多通用的数据集,模型架构,以及图像转换方法,配合pytorch使用更好搭建训练模型。

2. TRANSFORMING AND AUGMENTING IMAGES

图像转换和数据增强方法介绍

  1. 它们可以使用Compose链接在一起。
  2. 大多数转换类都有一个等效的函数:函数转换提供对转换的细粒度控制。
  3. 大多数变换同时接受PIL图像和张量图像,尽管有些变换只接受PIL图像,有些则只接受张量图像。
  4. 可以通过transform模块用于tensor与PIL图像之间的转换。

3. Resize transform

为了方便展示,首先定义一个画图函数

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / '24colormap.jpg')
print(np.array(orig_img).shape)
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)

def plot(imgs):
    num = len(imgs)
    if num > 2:
        num_rows = 2
    else:
        num_rows = 1

    
    num_cols = (num + 1) // num_rows
    # print('row, col:', num_rows, num_cols)
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    i = 0
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(imgs[i]))
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
            print(i, np.asarray(imgs[i]).shape)
            i += 1
            if i == num:
                break
            
    plt.tight_layout()

plt.imshow(orig_img)

原图如下:

torch_vision(一):数据增强和转换模块torchvision.transforms_第1张图片

resized_imgs = [T.Resize(size=size)(orig_img) for size in ((30,30), 50, 100, orig_img.size[::-1])]
plot(resized_imgs)

结果展示
torch_vision(一):数据增强和转换模块torchvision.transforms_第2张图片

4. CenterCrop transform 中心裁剪

center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 200, orig_img.size[::-1])]
plot(center_crops)

结果展示
torch_vision(一):数据增强和转换模块torchvision.transforms_第3张图片

5. RandomRotation 随机旋转

rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ldQXW6Xs-1666165747915)(20221019144813.png)] torch_vision(一):数据增强和转换模块torchvision.transforms_第4张图片

6.RandomAffine 随机仿射变换

affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs)

结果展示:
torch_vision(一):数据增强和转换模块torchvision.transforms_第5张图片

7.RandomPerspective 随机透视变换

perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)

结果展示:
torch_vision(一):数据增强和转换模块torchvision.transforms_第6张图片

8.RandomCrop 随机crop固定尺寸

cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)

结果展示:

torch_vision(一):数据增强和转换模块torchvision.transforms_第7张图片

9.RandomResizedCrop 随机crop之后,再 resize到固定尺寸

resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)

结果展示:
torch_vision(一):数据增强和转换模块torchvision.transforms_第8张图片

10.RandomPerspective 随机透视变换

perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)

结果展示:
torch_vision(一):数据增强和转换模块torchvision.transforms_第9张图片

11. 其他常见操作

数据变换

  1. 线性变换
torchvision.transforms.LinearTransformation(transformation_matrix,mean_vector)
  1. 标准化:减去均值,除以标准差
torchvision.transforms.Normalize(mean,std,inplace=False)

格式转换

  1. 最常用的就是 pil image 或者 np.ndarray 转换为tensor
torchvision.transforms.ToTensor

图像翻转

  1. 随机水平和随机垂直翻转
torchvision.transforms.RandomHorizontalFlip(p=0.5)
torchvision.transforms.RandomVerticalFlip(p=0.5)

更多数据增强方法,请参看
[1]https://pytorch.org/vision/0.13/transforms.html
[2]https://zhuanlan.zhihu.com/p/519919904

12. transforms.Compose() 和 torch.nn.Sequential()

transforms.Compose() 用于整合一系列的图像变换函数,将图片按照 Compose() 中的顺序依次处理。torch.nn.Sequential() 与 transforms.Compose() 起到相同的功能。torch.nn.Sequential() 可以和 torch.jit.script() 结合来导出模型。

#Compose
transform1 = transforms.Compose([
	transforms.CenterCrop(10),
	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

#Sequential
transform2 = torch.nn.Sequential(
	transforms.CenterCrop(10),
	transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

关于torchscript 可以参看文章 和 官网, 官网jit

13. 函数转换

除了以上的转换方法,一般都有对应的函数进行数据增强
比如:

import torchvision.transforms.functional as TF

TF.adjust_brightness(orig_img, 0.2)
TF.adjust_contrast(orig_img, 0.6)
TF.adjust_hue(orig_img, -0.4)
TF.adjust_saturation(orig_img, 0)
TF.adjust_sharpness(orig_img, 2)
TF.affine(orig_img, angle=0,translate=[150,150],scale=1, shear=0)
TF.crop(orig_img, 300, 300, 500, 600)
TF.erase(orig_img, 100, 200, 800, 600,0)
TF.gaussian_blur(orig_img, 21, 5)
TF.resize(orig_img, [400,800])
TF.rotate(orig_img, 60)
TF.vflip(orig_img)
TF.hflip(orig_img)
TF.crop(orig_img, 300, 300, 500, 600)
TF.erase(orig_img, 100, 200, 800, 600,0)
TF.gaussian_blur(orig_img, 21, 5)
TF.resize(orig_img, [400,800])
TF.rotate(orig_img, 60)
TF.vflip(orig_img)
TF.hflip(orig_img)

14. 利用函数定义转换

转换函数,可以处理多张图像

import torchvision.transforms.functional as TF
import random

def my_segmentation_transforms(image, segmentation):
    if random.random() > 0.5:
        angle = random.randint(-30, 30)
        image = TF.rotate(image, angle)
        segmentation = TF.rotate(segmentation, angle)
    # more transforms ...
    return image, segmentation

转换类

import torchvision.transforms.functional as TF
import random

class MyRotationTransform:
    """Rotate by one of the given angles."""

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])

15. 参考

https://pytorch.org/vision/0.13/transforms.html#functional-transforms

你可能感兴趣的:(图像处理算法,python,深度学习,人工智能)