可视化学习笔记10-pytorch cifar10批量数据预处理结果可视化

使用cifar10数据集,概率为0.5的随机遮挡批量结果可视化。

import torch as t
import numpy as np
import torchvision as tv
import matplotlib.pyplot as plt
from torchvision import transforms
from torchtoolbox.transform import Cutout

ROOT = './data/cifar-10'
BATCH_SIZE = 128

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 每边填充4,把32^*32填充至40*40,再随机裁剪
    Cutout(0.5),  # 参数是遮挡的概率
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=train_transform)
train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

for i, data in enumerate(train_load):
    img, label = data
    if i < 10:
        image = img[i].numpy()
        tag = label[i].numpy()
        # print(tag)
        print('img is a ', classes[tag])
        show = np.zeros((32, 32, 3))
        show[:, :, 0] = image[0, :, :]
        show[:, :, 1] = image[1, :, :]
        show[:, :, 2] = image[2, :, :]
        # print(show.shape)
        # print(img[i])
        plt.figure()
        plt.imshow(show)
        plt.show()
        print("------------------------------------")
    else:
        break


可视化学习笔记10-pytorch cifar10批量数据预处理结果可视化_第1张图片
可视化学习笔记10-pytorch cifar10批量数据预处理结果可视化_第2张图片

你可能感兴趣的:(可视化学习,pytorch,学习,深度学习)