PyTorch学习笔记-Transform

1. Transform的概念与基本用法

transforms 在计算机视觉工具包 torchvision 下,包含了很多种对图像数据进行变换的类,这些都是在我们进行图像数据读入步骤中必不可少的。

transforms 主要使用的类为:transforms.ToTensor,该类能够将 PIL Image 或者 ndarray 转换为 tensor,并且归一化至[0-1]。注意归一化至[0-1]是直接除以255,若自己的 ndarray 数据尺度有变化,则需要自行修改。

为什么需要 tensor 数据类型?因为它是一个包装了反向传播神经网络所需要的一些基础的参数,因此在神经网络中需要将图片类型转换为 tensor 类型进行训练。

例如:

from PIL import Image
from torchvision import transforms
import cv2

img_path = 'dataset/hymenoptera_data/train/ants_image/0013035.jpg'
img_PIL = Image.open(img_path)  # 

tensor_trans = transforms.ToTensor()  # 创建 ToTensor 的实例对象
img_tensor1 = tensor_trans(img_PIL)  # 将 PIL Image 转换成 tensor

print(type(img_tensor1))  # 

img_cv = cv2.imread(img_path)  # 
img_tensor2 = tensor_trans(img_cv)  # 将 OpenCV Image 转换成 tensor
print(type(img_tensor2))

2. Transform的常用类

  • transforms.Compose:Compose 能够将多种变换组合在一起。例如下面的代码可以先将 PIL Image 中心裁切,然后再转换成 tensor:
img_path = 'dataset/hymenoptera_data/train/ants_image/0013035.jpg'
img_PIL = Image.open(img_path)

trans = transforms.Compose([
    transforms.CenterCrop(100),
    transforms.ToTensor()
])

img_trans = trans(img_PIL)
  • transforms.CenterCrop:需要传入参数 size,表示以 (size, size) 的大小从中心裁剪,参数也可以为 (height, width)。例如:
img_PIL.show()

trans_centercrop = transforms.CenterCrop((100, 150))
img_centercrop = trans_centercrop(img_PIL)
img_centercrop.show()
  • transforms.RandomCrop:需要传入参数 size,表示以 (size, size) 的大小随机裁剪,参数也可以为 (height, width)
  • transforms.Normalize(mean, std):对数据按通道进行标准化,即先减均值 mean,再除以标准差 std,注意是 HWC 格式,处理公式为:output[channel] = (input[channel] - mean[channel]) / std[channel],例如:
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img_PIL)

# 如果 input 的范围是[0, 1],那么用该参数归一化后的范围就变为[-1, 1]
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm)
  • transforms.Resize:需要传入参数 (height, width) interpolation,表示重置图像的分辨率为 (h, w),也可以传入一个整数 size,这样会将较短的那条边缩放至 size,另一条边按原图大小等比例缩放。interpolation 为插值方法选择,默认为 PIL.Image.BILINEAR,例如:
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img_PIL)

print(img_tensor.size())  # torch.Size([3, 512, 768]),tensor 图像使用 size() 获取大小,PIL 图像使用 size

trans_resize = transforms.Resize((256, 300))
img_resize = trans_resize(img_tensor)
print(img_resize.size())  # torch.Size([3, 256, 300]),修改比例

trans_resize = transforms.Resize(30)
img_resize = trans_resize(img_tensor)
print(img_resize.size())  # torch.Size([3, 30, 45]),与原图等比例
  • transforms.ToPILImage::将 tensor 或者 ndarray 的数据转换为 PIL Image 类型数据,参数 mode 默认为 None,表示1通道, mode=3 表示3通道,默认转换为 RGB,4通道默认转换为 RGBA。

你可能感兴趣的:(Artificial,Intelligence,pytorch,学习,计算机视觉,深度学习,opencv)