Pytorch框架学习记录4——数据集的使用(torchvision.dataset)

Pytorch框架学习记录4——数据集的使用(torchvision.dataset)

1. 数据集

在pytorch官网中我们可以看到pytorch自身所配有的数据集的情况,以及该数据集的类型、使用方法等。在这里,我们选择数据集较小的CIFAR10作为我们的示例数据集。

该数据集的调用和使用使用代码如下:

torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

参数说明:

  • root ( string ) – 数据集的根目录, cifar-10-batches-py如果下载设置为 True,则该目录存在或将保存到该目录。
  • train ( bool , optional ) – 如果为真,则从训练集创建数据集,否则从测试集创建。
  • transform ( callable , optional ) – 一个函数/转换,它接受 PIL 图像并返回转换后的版本。例如,transforms.RandomCrop
  • target_transform ( callable , optional ) – 接收目标并对其进行转换的函数/转换。
  • download ( bool , optional ) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。

2. 使用实例

下载CIFAR10数据集后,将其类型转换为tensor类型,并在tensorboard中进行展示。

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms

dataset_transform = transforms.Compose([
    transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)

writer = SummaryWriter('logs')

for i in range(10):
    img, label = train_set[i]
    writer.add_image('train10', img, i)

writer.close()

此外,还可以直接通过链接使用浏览器下载,下载完毕后,在当前目录下也命名一个dataset文件夹并放入,上述代码不做任何改变,会自动将手动下载的数据集进行解压和修正。

Pytorch框架学习记录4——数据集的使用(torchvision.dataset)_第1张图片

你可能感兴趣的:(Pytorch学习笔记,pytorch,学习,深度学习)