Python-Pytorch框架-CIFAR10数据集加载与可视化

01

今天我们来一起认识一下CIFAR10数据集,使用Pytorch数据加载器进行加载,并对其可视化展示出来。

本篇文章的目标:

  1. 认识CIFAR10数据集;
  2. 能够使用Pytorch加载器加载数据集
  3. 能够可视化加载的数据集

line

01
认识CIFAR10数据集

CIFAR-10是一个更接近普适物体的彩色图像数据集。
Cifar-10 是由 Hinton 的学生 Alex Krizhevsky、Ilya Sutskever 收集的一个用于普适物体识别的计算机视觉数据集,一共包含10 个类别的RGB 彩色图片:
飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
它包含 60000 张 32 X 32 的 RGB 彩色图片,总共 10 个分类。
其中,包括 50000 张用于训练集,10000 张用于测试集。CIFAR10数据集的内容,如图所示:
Python-Pytorch框架-CIFAR10数据集加载与可视化_第1张图片

02
加载数据集

使用torchvision.datasets模块可以加载cifar10数据集,涉及函数为torchvision.datasets.CIFAR10(root, train, download)
root: cifar10数据集存放目录

train:
True,表示加载训练数据集
False,表示加载验证数据集

download:
True,表示cifar10数据集在root指定的文件夹不存在时,会自动下载
False,表示不管root指定文件夹是否存在cifar10数据集,都不会自动下载cifar10数据集。

加载数据,记得设置download = False。如果上一步不知道该把数据集放到哪里,可以先设置为True,然后看下载位置在哪,之后替换掉。
Dataset是一个包装类,可对数据进行张量(tensor)的封装,其可作为DataLoader的参数传入,进一步实现基于tensor的数据预处理。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

然后需要使用数据加载器加载创建好的数据集。深度学习是由数据支撑起来的,所以我们一般在做深度学习的时候往往伴随着大量、复杂的数据。
如果把所有的数据全部加载到内存上,容易把电脑的内存“撑爆”,所以要分批次一点点加载数据。每一种深度学习的框架都有自己所规定的数据格式,数据加载器就有了必要的作用。
数据加载器就是把大量的数据,分批次加载和处理成框架所需要的数据格式数据分批次加载.

使用PyTorch内置的模块 torch.utils.data.DataLoader()数据加载器:
dataset:数据集
batch_size: 每一批数据的总量
shuffle: True or False 为True的时候会将数据打乱再分批

trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)

03
可视化数据集

通过数据加载器后即可对数据集进行查看

可视化方法一:

def show(img):
    img = img / 2 + 0.5
    npimg = img
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
images, labels = iter(trainloader).next()
show(torchvision.utils.make_grid(images))

可视化方法二:

def imshow(trainloader):
    # 可视化数据集图片
    i = 0
    for batch in trainloader:
        if (i == 0):
            images, labels = batch
            print("[INFO] {} labels classes".format(BATCHSIZE * NUM_WORKERS),
                  [classes[label] for label in (labels.cpu().numpy())])
            i += 1
        else:
            continue
    print(images.shape)
    grid = torchvision.utils.make_grid(images, nrow=4)
    plt.imshow(np.transpose(grid.cpu().numpy().astype("uint8"), (1, 2, 0)))  # 交换维度,从GBR换成RGB
    plt.show()
imshow(trainloader)

Python-Pytorch框架-CIFAR10数据集加载与可视化_第2张图片

在可视化过程之中可能会报如下错:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

如果使用上述代码遇到上方的错误,可以忽略此报错提示,这里不会影响我们的显示。

04
总结

以上就是今天要讲的所有内容了,相信大家在使用数据集时经常遇到很多不同的问题,所以在写代码的过程中应细心一点,对于CIFAR10数据集就讲这么多,下期见。

line
end

点个关注不迷路
觉得孔哥写的对你有帮助?请分享给更多的人
欢迎一起学习!博客平台同步发布,请搜索——和孔哥一起学

dianzan

版权声明:
作者: 和孔哥一起学
导师:Fu Xianjun
本文版权归作者导师共有,欢迎转载,但未经作者同意必须在文章页面注明来源及原作者或原文链接,否则保留追究法律责任的权利。

你可能感兴趣的:(Pytorch,python,pytorch,python,深度学习)