PyTorch之torchvision.transforms

参考:【pytorch】图像基本操作

PIL
  • transforms.ToTensor()
from torchvision import transforms

# ToTensor-only pipeline: PIL.Image / numpy.ndarray (H x W x C) -> torch
# tensor (C x H x W); uint8 input is additionally rescaled from [0, 255]
# to [0.0, 1.0].
transform_tensor = transforms.Compose([
    transforms.ToTensor(),
])  # the original closed with ")" only -- the missing "]" was a SyntaxError

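注意:transforms.ToTensor() 内含归一化至 [0.0, 1.0] 的操作。只有当输入为 uint8 类型的 numpy.ndarray,或 L、RGB 等常见模式的 PIL.Image 时,才会除以 255。同时,它会把 H×W×C 的数据转换为 C×H×W 的张量。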

from PIL import Image

# Preprocess: load the sample image in 3-channel RGB mode and turn it into
# a float tensor via the ToTensor pipeline defined above.
img_path = "snorlax.png"
img = Image.open(img_path)
img = img.convert('RGB')
img_tensor = transform_tensor(img)  # pixel values now in [0.0, 1.0]
print(img_tensor.shape)  # torch.Size([3, 959, 959]) for the sample image

# Reverse direction: tensor -> PIL image, then display it.
to_pil = transforms.ToPILImage()
img_PIL = to_pil(img_tensor)
img_PIL = img_PIL.convert('RGB')
print(img_PIL.size, img_PIL.mode)  # (959, 959) RGB
img_PIL.show()
  • transforms.Normalize()
from torchvision import transforms

# ToTensor scales uint8 pixels into [0.0, 1.0]; Normalize then computes
# (x - mean) / std per channel, which with mean = std = 0.5 maps
# [0.0, 1.0] -> [-1.0, 1.0].
norm_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform_norm = transforms.Compose(norm_steps)

用 Normalize(mean=0.5, std=0.5) 将 [0.0, 1.0] 映射到 [-1.0, 1.0] 后,如果不先反归一化(x * std + mean)就直接显示,图像颜色会出现偏差 ⤵️

  • transforms.RandomCrop()
from torchvision import transforms

# Random 450 x 450 crop. The input is first converted to a tensor and then
# back to a PIL image, so the crop operates on (and returns) a PIL image.
# pad_if_needed is not set, so inputs smaller than 450 px on either side
# are rejected by RandomCrop.
crop_steps = [
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.RandomCrop((450, 450)),
]
transform_rc = transforms.Compose(crop_steps)
from PIL import Image

# Load the sample image as RGB, run the random-crop pipeline and show it.
img_path = "snorlax.png"
img = Image.open(img_path)
img = img.convert('RGB')
img_PIL = transform_rc(img)  # 450 x 450 PIL image
img_PIL.show()
numpy.ndarray
from torchvision import transforms
import cv2
import numpy as np

# NOTE: reuses `img_path` and `transform_tensor` from the earlier snippets.
# cv2.imread yields an H*W*C uint8 array with channels in BGR order (not RGB),
# so the resulting tensor is BGR as well.
img = cv2.imread(img_path)  # img: (959, 959, 3)  H*W*C
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None instead
    # of raising; fail here rather than with a confusing error in ToTensor.
    raise FileNotFoundError(f"cannot read image: {img_path}")
img_tensor = transform_tensor(img)  # uint8 -> float32 in [0.0, 1.0]
print(img_tensor.shape)  # torch.Size([3, 959, 959])  C*H*W

# Convert back to a displayable numpy.ndarray.
# Round before casting: astype('uint8') truncates, so float error that lands
# just below an integer (e.g. 254.99998) would otherwise lose one level.
# Clip keeps out-of-range values from wrapping around in the uint8 cast.
img_arr = np.clip(np.rint(img_tensor.numpy() * 255), 0, 255).astype('uint8')
img_arr = np.transpose(img_arr, (1, 2, 0))  # C*H*W -> H*W*C
print(img_arr.shape)  # (959, 959, 3)  H*W*C
cv2.imshow('img_arr', img_arr)  # still BGR, which is what imshow expects
cv2.waitKey()
cv2.destroyAllWindows()

你可能感兴趣的:(PyTorch之torchvision.transforms)