pytorch图像预处理

记录一下最近看到的各类预处理,均在torchvision.transforms中:

1.翻转

# 水平随机翻转
torchvision.transforms.RandomHorizontalFlip()

# 垂直随机翻转
torchvision.transforms.RandomVertivalFlip()

2.裁剪

# 随机区域裁剪
# 在原图中随机裁剪原图10%~100%区域,且该区域长宽比介于0.5-2之间,并将区域缩放至(224,224)
torchvision.transforms.RandomReisizedCrop(size=224, scale=(0.1, 1), ratio=(0.5, 2))

# 中心区域裁剪
# 在原图中心裁剪大小为(224,224)的区域
torchvision.transforms.CenterCrop(size=224)

3.颜色

# 亮度brightness, 色调hue, 对比度contrast, 饱和度saturation
trorchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, contrast=.0.5, saturation=0.5)

4.其他

缩放至固定尺寸:

torchvision.transforms.Resize(size=224)

格式转换,将小批量图像转换至pytorch需要的格式:(Batch, c, h, w),且值介于(0,1)的32位浮点数

torchvision.transforms.ToTensor()

归一化:

# 对三个通道进行归一化
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

整合所有transform操作:

train_augs = transforms.Compose([
                                 transforms.RandomResizedCrop(size=224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(...)
                                 ])

test_augs = transforms.Compose([
                                 transforms.Resize(size=256),
                                 transforms.CenterCrop(size=224),
                                 transforms.ToTensor(),
                                 transforms.Normalize(...)
 ])

5.注意:在使用预训练模型时,图像的预处理必须与预训练模型一致。常用的预训练模型保存在torchvison.models中。

6.测试:

import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

image = Image.open(r'E:\...\....png')

key_word = ['Original', 'HorizontalFlip', 'VerticalFlip', 'brightness', 'contrast', 'saturation', 'hue', 'size']

trans1 = transforms.RandomHorizontalFlip()
trans2 = transforms.RandomVerticalFlip()
trans3 = transforms.ColorJitter(brightness=0.5)
trans4 = transforms.ColorJitter(contrast=1.5)
trans5 = transforms.ColorJitter(saturation=0.5)
trans6 = transforms.ColorJitter(hue=0.5)
trans7 = transforms.RandomResizedCrop(size=224)

function = [trans1, trans2, trans3, trans4, trans5, trans6, trans7]
# transform = transforms.Compose([trans1, trans2, trans3, trans4, trans5, trans6, trans7])

plt.figure(figsize=(40, 40))

for i in range(len(key_word)):
    plt.subplot(4, 2, i+1)
    if i == 0:
        img = image
    else:
        img = function[i-1](image)
    plt.imshow(img)
    plt.title(key_word[i])

plt.show()

输出图像测试结果:

pytorch图像预处理_第1张图片

你可能感兴趣的:(pytorch图像预处理)