官方数据集的下载方法

目录

PyTorch官网介绍

下载数据集

Transform变换操作


PyTorch官网介绍

首先进入PyTorch官网,

其界面如下:

官方数据集的下载方法_第1张图片

点击torchvision进入,我这里以0.9.0的版本为例,若要改变版本号,则点击左上角的版本号选择想要的版本即可。里面常用模块包含torchvision.datasets、torchvision,models和torchvision.transforms,这里说明一下models模块,该模块表示已经训练好的一部分神经网络模型,可直接调用,其结构如下:

官方数据集的下载方法_第2张图片

  • Classification:分类模型
  • Semantic Segmentation:语义分割模型
  • Object Detection:目标检测模型
  • Video classification:视频分类模型

下载数据集

本次数据集选用  作为训练样本,该数据集用于图像分类,里面包含各种事物的图像。通过官方文档查看数据集调用方法如下:

官方数据集的下载方法_第3张图片

  • root:数据集存放路径
  • train:如果布尔值为True,创建训练集;若为False,创建测试集
  • transform:对数据集中的数据进行变换
  • target_transform:对于指定目标数据进行变换
  • download:如果布尔值为True,自动下载该数据集

在pycharm中运行下面代码会自动下载 CIFAR 数据集并在该.py文件目录下新建dataset文件夹用于存放训练集和测试集。若下载速度过慢,可复制链接到迅雷进行下载,下载完毕后,在.py文件目录下手动建立一个同名的文件夹,将压缩包复制到该文件夹下,再次运行该程序会自动将压缩包解压并存放到当前目录下。

当想要下载的数据集运行后没有显示出链接时,可通过Ctrl+点击鼠标左键进入其.py文件,向上拉动即可找到该数据集的下载链接。

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False,  download=True)

Transform变换操作

下载的数据集为PIL.Image类型,而需要传入进行训练的图像应为tensor格式,因此需要使用transform遍历数据集使图像类型转换为tensor格式后才能进行相应的图像预处理工作。转换后我们使用Tensorboard看下显示的结果,能正常显示则说明转换成功。Tensorboard 的使用方法:PyTorch深度学习快速入门

相关代码如下:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR10(root="./dataset",transform = dataset_transform, train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()

显示结果如下:通过拉动 step 可以查看各阶段时期的图像

官方数据集的下载方法_第4张图片

你可能感兴趣的:(PyTorch深度学习快速入门,pytorch,深度学习,python)