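Transforms is PyTorch's image-processing toolkit: a collection of classes in the torchvision module that convert, crop, resize, rotate, and otherwise transform images or data. It is used in most deep learning projects that work with images. This post gives a brief walkthrough of the most common classes in Transforms.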
from torchvision import transforms
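ToTensor converts image data (a PIL Image or an ndarray) into a tensor. For a PIL Image or a uint8 ndarray it also scales the values from [0, 255] to [0, 1] and reorders the dimensions from (H, W, C) to (C, H, W).
# img is a PIL Image (or an H x W x C ndarray) loaded beforehand
trans = transforms.ToTensor()
img_tensor = trans(img)
print(img_tensor)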
tensor([[[0.3137, 0.3137, 0.3137, ..., 0.3176, 0.3098, 0.2980],
[0.3176, 0.3176, 0.3176, ..., 0.3176, 0.3098, 0.2980],
[0.3216, 0.3216, 0.3216, ..., 0.3137, 0.3098, 0.3020],
...,
[0.3412, 0.3412, 0.3373, ..., 0.1725, 0.3725, 0.3529],
[0.3412, 0.3412, 0.3373, ..., 0.3294, 0.3529, 0.3294],
[0.3412, 0.3412, 0.3373, ..., 0.3098, 0.3059, 0.3294]],
[[0.5922, 0.5922, 0.5922, ..., 0.5961, 0.5882, 0.5765],
[0.5961, 0.5961, 0.5961, ..., 0.5961, 0.5882, 0.5765],
[0.6000, 0.6000, 0.6000, ..., 0.5922, 0.5882, 0.5804],
...,
[0.6275, 0.6275, 0.6235, ..., 0.3608, 0.6196, 0.6157],
[0.6275, 0.6275, 0.6235, ..., 0.5765, 0.6275, 0.5961],
[0.6275, 0.6275, 0.6235, ..., 0.6275, 0.6235, 0.6314]],
[[0.9137, 0.9137, 0.9137, ..., 0.9176, 0.9098, 0.8980],
[0.9176, 0.9176, 0.9176, ..., 0.9176, 0.9098, 0.8980],
[0.9216, 0.9216, 0.9216, ..., 0.9137, 0.9098, 0.9020],
...,
[0.9294, 0.9294, 0.9255, ..., 0.5529, 0.9216, 0.8941],
[0.9294, 0.9294, 0.9255, ..., 0.8863, 1.0000, 0.9137],
[0.9294, 0.9294, 0.9255, ..., 0.9490, 0.9804, 0.9137]]])
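Normalize standardizes a tensor image channel by channel, using a mean and a standard deviation.
Following the formula in the official documentation: output[channel] = (input[channel] - mean[channel]) / std[channel], where mean is the per-channel mean and std is the per-channel standard deviation.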
trans_norm = transforms.Normalize([1, 3, 5], [9, 2, 1])
img_norm = trans_norm(img_tensor)
Resize changes an image (a PIL Image or a tensor) to the specified size. It is defined as:
class Resize(torch.nn.Module):
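The size argument is usually given in one of two ways:
1. A (height, width) tuple, which sets both sides explicitly.
2. A single integer x, which scales the shorter side to x and keeps the aspect ratio.
For example: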
trans_resize = transforms.Resize((200, 200))
img_resize = trans_resize(img)
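Compose chains several transforms together and applies them to the data in the order given.
trans_random = transforms.RandomCrop((400, 400))  # img must be at least 400 x 400
# trans is the ToTensor instance created earlier
trans_compose_2 = transforms.Compose([trans_random, trans])
img_compose = trans_compose_2(img)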
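Grayscale converts an image to grayscale. Its argument is the number of output channels, 1 or 3. With 3, the result still has three channels, all holding the same values.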
trans_grayscale = transforms.Grayscale(3)
img_grayscale = trans_grayscale(img_tensor)