Notes on the preprocessing operations I have come across recently. All of them live in torchvision.transforms:
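1. Flipping
# Random horizontal flip (applied with probability p=0.5 by default)
torchvision.transforms.RandomHorizontalFlip()
# Random vertical flip (applied with probability p=0.5 by default)
torchvision.transforms.RandomVerticalFlip()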
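To make the "random" part concrete, here is a minimal pure-Python sketch of what a random flip does to an image stored as a list of rows. The helper names random_horizontal_flip and random_vertical_flip are made up for this illustration; they imitate the behavior described above and are not torchvision's own code.

```python
import random


def random_horizontal_flip(image, p=0.5, rng=random):
    """Mirror each row left-to-right with probability p (hypothetical helper)."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


def random_vertical_flip(image, p=0.5, rng=random):
    """Reverse the row order (top-to-bottom) with probability p (hypothetical helper)."""
    if rng.random() < p:
        return image[::-1]
    return image


image = [[1, 2, 3],
         [4, 5, 6]]

# p=1.0 forces the flip so the effect is visible every time.
print(random_horizontal_flip(image, p=1.0))  # [[3, 2, 1], [6, 5, 4]]
print(random_vertical_flip(image, p=1.0))    # [[4, 5, 6], [1, 2, 3]]
```

With the default p=0.5, roughly half of the training images come out flipped, which is why these transforms only make sense at training time.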
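2. Cropping
# Random-region crop
# Crop a random region covering 10%~100% of the original image's area, with an aspect ratio between 0.5 and 2, then resize that region to (224, 224)
torchvision.transforms.RandomResizedCrop(size=224, scale=(0.1, 1), ratio=(0.5, 2))
# Center crop
# Crop a (224, 224) region from the center of the image
torchvision.transforms.CenterCrop(size=224)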
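The scale and ratio arguments are easier to picture with a small sketch of how a crop box could be sampled from them. This is pure Python and only illustrates the idea: the function name sample_crop_box, the retry count, and the fallback to the whole image are my own assumptions, not the library's implementation.

```python
import math
import random


def sample_crop_box(width, height, scale=(0.1, 1.0), ratio=(0.5, 2.0),
                    attempts=10, rng=random):
    """Return (top, left, crop_height, crop_width) for a random crop (hypothetical helper)."""
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        # Pick the fraction of the image area to keep.
        target_area = area * rng.uniform(scale[0], scale[1])
        # Pick the aspect ratio (width / height) uniformly in log space,
        # so that 0.5 and 2 are equally likely.
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Simplified fallback (an assumption of this sketch): keep the whole image.
    return 0, 0, height, width


top, left, h, w = sample_crop_box(640, 480)
print(top, left, h, w)  # a different box on every call
```

Whatever box is sampled, the cropped region is then resized to the requested size, so the output is always (224, 224) in the example above.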
3. Color
# brightness, hue, contrast, saturation
torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, contrast=0.5, saturation=0.5)
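When a single float is passed, ColorJitter samples the actual adjustment factor from a range derived from that float. The sketch below computes those ranges in plain Python; factor_range and hue_range are names invented for this note.

```python
def factor_range(value):
    """Range used for brightness, contrast and saturation: [max(0, 1 - value), 1 + value]."""
    return (max(0.0, 1.0 - value), 1.0 + value)


def hue_range(value):
    """Range used for hue: [-value, value], where value must lie in [0, 0.5]."""
    if not 0.0 <= value <= 0.5:
        raise ValueError("hue must be in [0, 0.5]")
    return (-value, value)


print(factor_range(0.5))  # (0.5, 1.5)
print(hue_range(0.5))     # (-0.5, 0.5)
```

So brightness=0.5 means the image is scaled by a random factor between 0.5 and 1.5, where 1 leaves it unchanged, and hue=0.5 is the largest hue shift allowed.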
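4. Other
Resizing to a fixed size (with an int, the shorter side is scaled to 224 and the aspect ratio is kept; pass (224, 224) for an exact output size):
torchvision.transforms.Resize(size=224)
Format conversion: turns a single PIL Image (or a uint8 numpy.ndarray of shape (H, W, C)) with values in [0, 255] into the layout PyTorch expects, a 32-bit float tensor of shape (C, H, W) with values in [0.0, 1.0]. The batch dimension of (Batch, C, H, W) is added later, when images are stacked into a mini-batch (for example by a DataLoader):
torchvision.transforms.ToTensor()
Normalization:
# Normalize each of the three channels: output = (input - mean) / std
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])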
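As a worked example of the arithmetic behind ToTensor followed by Normalize, the sketch below pushes one RGB pixel through both steps in plain Python. The helpers to_unit_range and normalize are stand-ins written for this note; only the mean and std values come from the call above.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def to_unit_range(pixel):
    """Scale 8-bit channel values from [0, 255] to [0.0, 1.0] (the value part of ToTensor)."""
    return [channel / 255 for channel in pixel]


def normalize(values, mean=MEAN, std=STD):
    """Apply (value - mean) / std independently to each channel."""
    return [(v - m) / s for v, m, s in zip(values, mean, std)]


pixel = (255, 128, 0)          # one RGB pixel, 8 bits per channel
scaled = to_unit_range(pixel)
normalized = normalize(scaled)

print([round(v, 3) for v in scaled])      # [1.0, 0.502, 0.0]
print([round(v, 3) for v in normalized])  # [2.249, 0.205, -1.804]
```

Note that the normalized values are no longer confined to [0, 1], which is expected: the mean and std are meant for inputs that have already been scaled to [0, 1], so Normalize has to come after ToTensor.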
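Combining all the transforms:
from torchvision import transforms

# Training: random crop and flip for augmentation, then convert and normalize
train_augs = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(...)
])
# Testing: deterministic resize and center crop, then convert and normalize
test_augs = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(...)
])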
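Conceptually, Compose just calls each transform in order and feeds the output of one into the next. A toy version makes that explicit; SimpleCompose is a made-up stand-in for illustration, not the torchvision class.

```python
class SimpleCompose:
    """Apply a list of callables in order (toy stand-in for a Compose-style pipeline)."""

    def __init__(self, steps):
        self.steps = list(steps)

    def __call__(self, x):
        for step in self.steps:
            x = step(x)
        return x


add_one_then_double = SimpleCompose([lambda x: x + 1, lambda x: x * 2])
double_then_add_one = SimpleCompose([lambda x: x * 2, lambda x: x + 1])

print(add_one_then_double(3))  # 8
print(double_then_add_one(3))  # 7
```

The two results differ, which is the practical point: order matters, so the crop and flip steps go first and ToTensor goes before Normalize.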
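5. Note: when using a pretrained model, the image preprocessing must match the preprocessing the model was trained with. Common pretrained models are available in torchvision.models.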
6. Test:
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

# Drop any alpha channel so the color transforms see a 3-channel image
image = Image.open(r'E:\...\....png').convert('RGB')
key_word = ['Original', 'HorizontalFlip', 'VerticalFlip', 'brightness', 'contrast', 'saturation', 'hue', 'size']
trans1 = transforms.RandomHorizontalFlip()
trans2 = transforms.RandomVerticalFlip()
trans3 = transforms.ColorJitter(brightness=0.5)
trans4 = transforms.ColorJitter(contrast=1.5)
trans5 = transforms.ColorJitter(saturation=0.5)
trans6 = transforms.ColorJitter(hue=0.5)
trans7 = transforms.RandomResizedCrop(size=224)
function = [trans1, trans2, trans3, trans4, trans5, trans6, trans7]
# transform = transforms.Compose([trans1, trans2, trans3, trans4, trans5, trans6, trans7])
plt.figure(figsize=(40, 40))
# Panel 0 shows the original; panel i shows the result of the (i-1)-th transform
for i in range(len(key_word)):
    plt.subplot(4, 2, i+1)
    if i == 0:
        img = image
    else:
        img = function[i-1](image)
    plt.imshow(img)
    plt.title(key_word[i])
plt.show()
Test output: