PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用

torchvision

官网:https://pytorch.org/Python官网.
PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第1张图片
点进 torchvision:
PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第2张图片

  1. torchvision 文档列出了很多科研或者毕设常用的一些数据集,如入门数据集MNIST,用于手写文字。这些数据集位于 torchvision.datasets模块,可以通过该模块对数据集进行下载,转换等操作。
  2. torchvision还有 io模块,但不常用
  3. torchvision.models会提供一些训练好的神经网络模型,在之后会用到。
  4. torchvision.transforms之前已经学习过了,主要提供一些数据处理的工具。
    接下来主要讲解如何联合使用 torchvision.datasets 和 torchvision.transforms

CIFAR数据集

PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第3张图片

具有一些参数,如root,设置数据集所在目录路径等。
数据集官网:https://www.cs.toronto.edu/~kriz/cifar.html
PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第4张图片
这个数据集包含了 60000张32*32像素的10个类别的彩色图片,每个种类6000张图片,其中60000张中50000张是训练图片,10000张是测试图片。

2.1下载数据集

import torchvision

train_set = torchvision.datasets.CIFAR10(root="../dataset", train=True, download=True)
test_set  = torchvision.datasets.CIFAR10(root="../dataset", train=False, download=True)

2.2图片都转为ToTensor

# 每张图片都转为ToTensor类型
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=dataset_transform, download=True) # train参数表示是训练集还是测试集
test_set  = torchvision.datasets.CIFAR10(root="../dataset", train=False, transform=dataset_transform, download=True)

2.3一些测试

# print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# img, target = test_set[0]
# print(img) # 
# print(target) # 3
# print(test_set.classes[target]) # cat
# img.show()

# print(test_set[0])

2.4tensorboard进行查看

由于数据集的图片类型是PIL Image,torch无法直接使用,所以要先转为tensor,通过 transforms 实现。
这里用transforms将图片转为 tensor类型后,用tensorboard进行查看:

writer = SummaryWriter("../logs/P14")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()

(pytorch) E:\CodeCodeCodeCode\AI\Pytorch-study>tensorboard --logdir="logs/P14"
TensorFlow installation not found - running with reduced feature set.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.7.0 at http://localhost:6006/ (Press CTRL+C to quit)

PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第5张图片

2.5其他数据集的使用

使用其他数据集的方法也很简单,设置好所需参数就可以了:
PyTorch深度学习入门笔记(四)torchvision 中的DataSet使用_第6张图片
以COCO为例,一样的,设置存储路径,json文件存储路径,transform,target_transform,transforms 就可以了。
注:如果下载速度比较慢,可以通过将下载路径复制到迅雷中进行数据集下载,下载好后设置好存放目录的路径就行。
在这里插入图片描述
在这里插入图片描述

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能,python)