PyTorch深度学习笔记(十)torchvision中的数据集使用

课程学习笔记,课程链接

目录

一、torchvision

二、CIFAR数据集

下载数据集

数据集的使用

transforms的使用


目的:

如何把数据集和 Transforms 结合在一起;

介绍科研中使用的一些标准数据集和下载、查看、使用方法

一、torchvision

PyTorch官网

PyTorch深度学习笔记(十)torchvision中的数据集使用_第1张图片

 点进 torchvision

PyTorch深度学习笔记(十)torchvision中的数据集使用_第2张图片

torchvision 文档列出了很多科研或者毕设常用的一些数据集,如入门数据集 MNIST,用于手写文字。这些数据集位于 torchvision.datasets 模块,可以通过该模块对数据集进行下载,转换等操作。

torchvision 还有 io模块,但不常用;torchvision.models 会提供一些训练好的神经网络模型,在之后会用到;torchvision.transforms 之前已经学习过了,主要提供一些数据处理的工具。

接下来主要讲解如何联合使用 torchvision.datasets 和 torchvision.transforms

二、CIFAR数据集

用到的数据集是 CIFIAR,点击官网文档进行查看。

PyTorch深度学习笔记(十)torchvision中的数据集使用_第3张图片

数据集官网,该数据集包含了 60000 张 32*32 像素的 10 个类别的彩色图片,每个种类 6000 张图片,其中 60000 张中 50000 张是训练图片,10000 张是测试图片。

下载路径,非代码下载

下载数据集

import torchvision
'''
若有 ssl 报错,添加如下代码
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
'''
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

然后运行,torchvision 就会自动进行 CIFAR10 数据集的下载。

这里分别下载训练集和测试集,下载好后会放到所设置的路径下,这里下载的数据集会被放带当前目录的 dataset目录下。

数据集的使用

查看下测试数据集中每个数据包含什么:

import torchvision
import ssl
​
ssl._create_default_https_context = ssl._create_unverified_context
​
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0])
print(test_set.classes)
​
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()

PyTorch深度学习笔记(十)torchvision中的数据集使用_第4张图片

 输出:

PyTorch深度学习笔记(十)torchvision中的数据集使用_第5张图片

 即一个数据单元里包含输入图片和对应的 tag,这里用数字进行映射,数字 3 也就是表示 cat ,可用 img.show() 查看下图片。因为这个数据集比较小,只有百 MB,图片像素只有 32*32,所以模糊,这里是将其分类为猫。

transforms的使用

由于数据集的图片类型是 PIL Image,torch 无法直接使用,所以要先转为 tensor,通过 transforms 实现。这里用 transforms 将图片转为 tensor 类型后,并用 tensorboard 进行查看。

import torchvision
import ssl
from torch.utils.tensorboard import SummaryWriter
​
ssl._create_default_https_context = ssl._create_unverified_context
​
# 目的将设置成 ToTensor 的 transforms 应用到数据集的每张图片,都转成 tensor 数据类型
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
​
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])
​
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()

结果,这里查看了索引为 [0]~[9] 的十张图片

PyTorch深度学习笔记(十)torchvision中的数据集使用_第6张图片

你可能感兴趣的:(PyTorch,pytorch,深度学习,python)