pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)

核心代码

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪
    Cutout(0.5),#参数是遮挡的概率
    transforms.RandomHorizontalFlip(),#随机左右翻转
    transforms.ToTensor()#必不可少的数据转换
])

效果图

pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)_第1张图片
填充、随机裁剪效果
pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)_第2张图片
pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)_第3张图片

Cutout() 的遮挡效果
pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)_第4张图片
pytorch 数据预处理(填充、随机裁剪、随机遮挡、随机左右翻转)_第5张图片

完整代码

import torch as t
import numpy as np
import torchvision as tv
import matplotlib.pyplot as plt
from torchvision import transforms
from torchtoolbox.transform import Cutout
 
ROOT = '../pytorch/cifar-10'
BATCH_SIZE=128

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),#每边填充4,把32^*32填充至40*40,再随机裁剪
    Cutout(0.5),#参数是遮挡的概率
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=train_transform)
train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

for i,data in enumerate(train_load):
    img,label = data
    if i<10:
        image = img[i].numpy()
        tag = label[i].numpy()
        #print(tag)
        print('img is a ', classes[tag])
        show = np.zeros((32,32,3))
        show[:,:,0]=image[0,:,:]
        show[:,:,1]=image[1,:,:]
        show[:,:,2]=image[2,:,:]
        #print(show.shape)
        #print(img[i])
        plt.figure()
        plt.imshow(show)
        plt.show()
        print("------------------------------------")
    else:
        break

你可能感兴趣的:(机器学习,深度学习,python,pytorch,数据预处理,数据增强)