The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.
torchvision provides many common datasets, model architectures, and image transformation methods, and works together with PyTorch to make building and training models easier.
Image transforms and data augmentation
For ease of presentation, first define a plotting helper function:
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / '24colormap.jpg')
print(np.array(orig_img).shape)
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0)
def plot(imgs):
    num = len(imgs)
    if num > 2:
        num_rows = 2
    else:
        num_rows = 1
    num_cols = (num + 1) // num_rows
    # print('row, col:', num_rows, num_cols)
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    i = 0
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(imgs[i]))
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
            print(i, np.asarray(imgs[i]).shape)
            i += 1
            if i == num:
                break
    plt.tight_layout()
plt.imshow(orig_img)
The original image:
# note: PIL's .size is (width, height), while Resize/CenterCrop expect (h, w), hence the [::-1]
resized_imgs = [T.Resize(size=size)(orig_img) for size in ((30, 30), 50, 100, orig_img.size[::-1])]
plot(resized_imgs)
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 200, orig_img.size[::-1])]
plot(center_crops)
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
affine_transformer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transformer(orig_img) for _ in range(4)]
plot(affine_imgs)
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
The results are shown below:
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
Data transforms
torchvision.transforms.LinearTransformation(transformation_matrix, mean_vector)
torchvision.transforms.Normalize(mean, std, inplace=False)
Format conversion
torchvision.transforms.ToTensor
Image flipping
torchvision.transforms.RandomHorizontalFlip(p=0.5)
torchvision.transforms.RandomVerticalFlip(p=0.5)
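As a quick illustration of how these pieces fit together, here is a minimal sketch (reusing the T alias and orig_img defined above; the mean/std values are the usual ImageNet statistics, used here purely for illustration):
# flip randomly, convert the PIL image to a (C, H, W) float tensor in [0, 1],
# then normalize each channel with (x - mean) / std
augment = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
out = augment(orig_img)
print(out.shape, out.mean(dim=(1, 2)))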
For more data augmentation methods, see:
[1]https://pytorch.org/vision/0.13/transforms.html
[2]https://zhuanlan.zhihu.com/p/519919904
transforms.Compose() chains a series of image transforms and applies them to an image in the given order. torch.nn.Sequential() provides the same functionality, and a torch.nn.Sequential() pipeline can additionally be combined with torch.jit.script() to export the transforms as part of a scripted model.
from torchvision import transforms

# Compose
transform1 = transforms.Compose([
    transforms.CenterCrop(10),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Sequential (Normalize requires tensor inputs, so feed tensor images rather than PIL Images)
transform2 = torch.nn.Sequential(
    transforms.CenterCrop(10),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)

# script the nn.Sequential pipeline (Compose is not scriptable)
scripted_transforms = torch.jit.script(transform2)
For more on TorchScript, see the official torch.jit documentation.
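As a minimal sketch of that export path (assuming the scripted_transforms object from above; the file name is only illustrative), the scripted pipeline can be saved and reloaded without the original Python code:
# save the scripted transform pipeline and load it back
scripted_transforms.save("scripted_transforms.pt")   # illustrative file name
reloaded = torch.jit.load("scripted_transforms.pt")
dummy = torch.rand(3, 224, 224)                      # dummy tensor image for illustration
print(reloaded(dummy).shape)                         # torch.Size([3, 10, 10]) after CenterCrop(10)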
In addition to the transform classes above, most transforms have a functional counterpart in torchvision.transforms.functional. For example:
import torchvision.transforms.functional as TF
TF.adjust_brightness(orig_img, 0.2)   # brightness_factor < 1 darkens the image
TF.adjust_contrast(orig_img, 0.6)
TF.adjust_hue(orig_img, -0.4)         # hue_factor in [-0.5, 0.5]
TF.adjust_saturation(orig_img, 0)     # 0 gives a grayscale image
TF.adjust_sharpness(orig_img, 2)
TF.affine(orig_img, angle=0, translate=[150, 150], scale=1, shear=0)
TF.crop(orig_img, 300, 300, 500, 600)   # top, left, height, width
TF.erase(T.ToTensor()(orig_img), 100, 200, 800, 600, 0)   # erase expects a tensor image, not a PIL Image
TF.gaussian_blur(orig_img, 21, 5)       # kernel_size=21, sigma=5
TF.resize(orig_img, [400, 800])
TF.rotate(orig_img, 60)
TF.vflip(orig_img)
TF.hflip(orig_img)
Functional transforms can apply the same transformation, with the same parameters, to several images at once, e.g. an image and its segmentation mask:
import torchvision.transforms.functional as TF
import random

def my_segmentation_transforms(image, segmentation):
    if random.random() > 0.5:
        angle = random.randint(-30, 30)
        image = TF.rotate(image, angle)
        segmentation = TF.rotate(segmentation, angle)
    # more transforms ...
    return image, segmentation
Transforms can also be written as callable classes:
import torchvision.transforms.functional as TF
import random

class MyRotationTransform:
    """Rotate by one of the given angles."""
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

rotation_transform = MyRotationTransform(angles=[-30, -15, 0, 15, 30])
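Since the class is just a callable, it can be dropped into a normal pipeline next to built-in transforms; a minimal usage sketch, reusing the T alias and orig_img from above:
# hypothetical pipeline mixing the custom rotation with built-in transforms
pipeline = T.Compose([
    rotation_transform,              # custom callable defined above
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
])
out = pipeline(orig_img)
print(out.shape)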
https://pytorch.org/vision/0.13/transforms.html#functional-transforms