pytorch初学笔记(五):torchvision中数据集的使用

目录

一、torchvision介绍

1. 作用与结构

2. torchvision中常用数据集

二、CIFAR10的介绍

1.  数据集简介

2. 使用该数据集的所需参数 

3. 数据集下载

3.1 pycharm在线下载(下载速度较快时) 

3.2 第三方下载

3.3 数据库的下载总结 

三、 CIFAR10的具体使用

1. 数据集对象的显示(PIL型)

2. 把数据集中的图片对象转换为tensor型

2.1 转换所需transform的定义

2.2 使用tensorboard进行图片显示

四、练习:MNIST数据集的下载和使用

1. 可能的报错和修改 

2. 代码实现

2.1 PIL对象实现

2.2 tensor对象实现

3. 运行结果 


一、torchvision介绍

1. 作用与结构

torchvision — Torchvision main documentation

torchvision是pytorch下的一个包,主要由计算机视觉中的流行数据集、模型体系结构和常见图像转换等模块组成。

 常用的包:

  • Transforming and augmenting images:进行图片变换等。
  • Models and pre-trained weights:提供一些预训练好的神经网络或权重参数等。
  • Dataset :提供常用的数据集。

pytorch初学笔记(五):torchvision中数据集的使用_第1张图片

2. torchvision中常用数据集

Datasets — Torchvision main documentation

 Datasets模块提供了需要常用的数据集以及其具体的使用方法,比如下图所示的图像分类中常用的CIFAR10数据集,图像检测中常用的COCO数据集等。

pytorch初学笔记(五):torchvision中数据集的使用_第2张图片

 pytorch初学笔记(五):torchvision中数据集的使用_第3张图片

 下面具体说明如何对CIFAR10进行下载和使用。

二、CIFAR10的介绍

1.  数据集简介

CIFAR-10 and CIFAR-100 datasets (toronto.edu)

pytorch初学笔记(五):torchvision中数据集的使用_第4张图片

  •     CIFAR-10是一个更接近普适物体的彩色图像的小型数据集
  •     一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
  •      每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

2. 使用该数据集的所需参数 

CIFAR10 — Torchvision main documentation

pytorch初学笔记(五):torchvision中数据集的使用_第5张图片

需要设定的5个参数:

1.   root(字符串型):把数据集下载到的位置路径。

2.   train(布尔型):是否把该数据集作为训练数据集使用。

  • True: 作为训练数据集创建
  • False:不作为训练数据集,作为测试数据集创建

3.   transform:图像需要进行的变换操作,一般使用compose把所需的transforms结合起来。

4.   target_transform:对于标签需要做的变换

5.   download(布尔型):是否下载数据集。

  • True:把数据集下载到root指定的对应位置;如果数据集以及进行过下载,则不会再一次下载
  • False:不下载数据集

3. 数据集下载

3.1 pycharm在线下载(下载速度较快时) 

    1. 导入torchvision包,然后依次创建训练数据集和测试数据集。

注意:训练数据集的train参数要设置为True,测试数据集的train设置为False

import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)

    2. 点击运行,等待一段时间后显示下载成功 pytorch初学笔记(五):torchvision中数据集的使用_第6张图片

    3. 观察项目包目录,可以发现自动创建了名为dataset3的文件夹,下载的解压文件和解压好的数据集都在其中。

pytorch初学笔记(五):torchvision中数据集的使用_第7张图片

3.2 第三方下载

    如果在pycharm中下载速度很慢的话,可以找到pycharm所用的下载链接,然后自己使用迅雷等下载软件进行快速下载。

如何找到下载链接?

  1. 把鼠标移动到想要下载的数据集名称上,然后Ctrl+C,进入该数据集的帮助文档。 

pytorch初学笔记(五):torchvision中数据集的使用_第8张图片

      2. 可以看到对应的下载文件名和下载链接。 

pytorch初学笔记(五):torchvision中数据集的使用_第9张图片

    3. 使用迅雷或者浏览器下载,然后把下载过后的压缩文件按照root中定义的路径创建文件夹,然后把文件放入文件夹中,注意,自己创建的文件夹一定要和root中定义的文件夹姓名相同才行,否则后期扫描不到该数据集

    4. 运行上面在线下载中定义的语句,可以发现程序不会再次下载数据集文件,而是会帮你解压好数据集。

3.3 数据库的下载总结 

无论是否需要在线下载数据集,都推荐把download参数值设为True。

因为程序可以帮你自动完成下载解压工作,就算自己下载过文件,也可以提供解压功能,因此更加方便。

三、 CIFAR10的具体使用

1. 数据集对象的显示(PIL型)

import torchvision
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,download=True)

#1. 查看数据集的图片
#输出所有类别
print(test_set.classes)
#输出数据集第一张图片的类型
print(test_set[0])
#输出图片的PIL型格式和标签
img,label = test_set[0]
print(label,test_set.classes[label])
img.show()

        1.  数据集所有类别的查看

        图片有十个类,对应的类别名称存储在dataset.classes列表中。

        2. 数据集中单个具体对象的查看

        想要输出数据集中具体的某一张图片,使用下标调用方式dataset[x]即可显示第x+1张图片;输出的对象类型为一个元组,里面第一项是PIL类型的图片,第二项是图片的标签。

        3. 数据集中图片对象和标签的定义

        可以使用  img,label = dataset[x] 的方式接收对象中的图片和label,然后可以用print进行对label的输出,也可以用 dataset. classes[label]的格式进行对该类别名称的显示。

        4. 数据集中图片的可视化

        使用img.show()方法进行图片的可视化显示。

输出结果如下: 

pytorch初学笔记(五):torchvision中数据集的使用_第10张图片

        打开的对应图片如下图所示,由于数据集中的图片较小,所以不清晰,但是可以看出来是一只小猫的图片。 

pytorch初学笔记(五):torchvision中数据集的使用_第11张图片

2. 把数据集中的图片对象转换为tensor型

2.1 转换所需transform的定义

        因为需要完成数据集中所有图片类型从PIL到tensor的转换,我们需要用到transforms工具,也需要设定数据集中的transform参数。

        我们在数据集定义的语句之前定义我们需要的transform,由于一般需要对图像做的变换不止一个,所以我们使用compose来对多个transforms进行组合在这里我们只需要一个ToTensor即可。

        下面代码给出使用compose定义transform和不使用compose的两个版本,都可以完成成功运行。

  •  使用compose:
import torchvision
#定义transforms
dataset_transform = torchvision.transforms.Compose([
    #定义totensor
    torchvision.transforms.ToTensor()
])
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=dataset_transform,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=dataset_transform,download=True)

  • 不使用compose:
import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter

trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)
、

2.2 使用tensorboard进行图片显示

        完成了transform和数据集的定义后,即可使用add_image()方法完成图片显示。在这里我们使用for循环进行10张图片的显示。

import torchvision
#定义transforms
from torch.utils.tensorboard import SummaryWriter

trans_totensor_tool = torchvision.transforms.ToTensor()
#创建训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset3",train=True,transform=trans_totensor_tool,download=True)
#创建测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset3",train=False,transform=trans_totensor_tool,download=True)

#使用tensorboard进行显示
writer = SummaryWriter("logs")
#for循环完成10张图片的显示
for i in range(10):
    img,label=test_set[i]
    writer.add_image("dataset",img,i)

writer.close()

        结果如下所示。可以看到一共step=9,成功显示了数据集中第1-10张图片。

pytorch初学笔记(五):torchvision中数据集的使用_第12张图片          pytorch初学笔记(五):torchvision中数据集的使用_第13张图片

 pytorch初学笔记(五):torchvision中数据集的使用_第14张图片         pytorch初学笔记(五):torchvision中数据集的使用_第15张图片

四、练习:MNIST数据集的下载和使用

1. 可能的报错和修改 

        使用上面做过的练习对MNIST数据集进行相同的操作,注意在下载数据集后可能会爆“UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors.” 的错误,按照博文的方法修改即可。 

(4条消息) Pytorch | 报错The given NumPy array is not writeable,and PyTorch does not support non-writeable tensor_软耳朵DONG的博客-CSDN博客

2. 代码实现

对于PIL对象:

  • 完成数据集所有类别的输出(classes)
  • 输出数据集中的第一个对象
  • 完成前10张图片对应类别的输出
  • 完成第10张图片的显示(show方法)

对于tensor对象:

  • 把数据集中所有图片类型从PIL型转换为tensor型,重定义图片大小为10*10(使用Compose,ToTensor和Resize)
  • 输出前10张图片

2.1 PIL对象实现

import torchvision
from torch.utils.tensorboard import SummaryWriter


train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,download=True)

#pil型对象显示
print(test_set.classes)
print(test_set[0])
for i in range(10):
    img,label=test_set[i]
    print(test_set.classes[label])
img.show()

2.2 tensor对象实现

import torchvision
from torch.utils.tensorboard import SummaryWriter

trans_tool = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((10,10))
])

train_set = torchvision.datasets.MNIST(root="./MNIST_test",train=True,transform=trans_tool,download=True)
test_set = torchvision.datasets.MNIST(root="./MNIST_test",train=False,transform=trans_tool,download=True)

#tensor型对象显示
writer = SummaryWriter("logs")
for i in range(10):
    img,label=test_set[i]
    writer.add_image("MNIST",img,i)
print(img.shape)
writer.close()

3. 运行结果 

 数据集下载并创建成功:

pytorch初学笔记(五):torchvision中数据集的使用_第16张图片 

 显示第10张图片:

pytorch初学笔记(五):torchvision中数据集的使用_第17张图片

 print的显示结果:pytorch初学笔记(五):torchvision中数据集的使用_第18张图片

在未改变大小之前的维度是(1,28,28),resize后可见tensor的维度变成了(1,10,10 )

,pytorch初学笔记(五):torchvision中数据集的使用_第19张图片

 tensoeboard显示结果: 

pytorch初学笔记(五):torchvision中数据集的使用_第20张图片

你可能感兴趣的:(pytorch,人工智能,python,transformer,计算机视觉)