Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)

文章目录

      • 一、什么是TorchVision
      • 二、以torchvision.datasets子模块下的CIFAR10数据集为例
        • 1、CIFAR10数据集参数
        • 2、返回参数
        • 3、代码中使用

一、什么是TorchVision

torchvision是pytorch的一个图形库,用来处理图像,主要用来构建计算机视觉模型。

从下面的官网截图可以看到torchvision有很多模块,下面以dataset模块进行举例。
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)_第1张图片

torchvision中datasets包:用来进行数据加载,主要有以下几个模块

CelebA
CIFAR
Cityscapes
COCO
Captions
Detection
DatasetFolder
EMNIST
FakeData
Fashion-MNIST
Flickr
HMDB51
ImageFolder
ImageNet
Kinetics-400
KMNIST
LSUN
MNIST
Omniglot
PhotoTour
Places365
QMNIST
SBD
SBU
STL10
SVHN
UCF101
USPS
VOC

二、以torchvision.datasets子模块下的CIFAR10数据集为例

Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)_第2张图片
从上图可知:

CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。有50000张训练图像和10000张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批包含从每个类随机选择的1000个图像。训练批次以随机顺序包含剩余的图像,但一些训练批次可能包含来自一个类的更多图像。在它们之间,训练批次包含来自每个类的5000张图像。

1、CIFAR10数据集参数

class CIFAR10(VisionDataset):
    """`CIFAR10 `_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

由上述代码可知,有如下5个参数
root :即指定数据集要下载在哪一个文件夹里面,如:root=“./dataset” 即将数据集下载到当前目录的dataset文件夹下
train :是否为训练集,布尔类型,如果train=True即为训练集,否则train=False则为非训练集。
transform :进行图像变换的各种操作,如RandomCrop、Compose等。
target_transform :对于标签进行transform 操作。
download :是否下载数据集,download = True表示下载数据集,download = False表示不下载数据集。(如果当前文件夹已经有需要下载的数据集,但是在程序编写中又把download属性值设定为True,此时不会再下载。)

2、返回参数

返回一个元组,其元组代表img, target

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

3、代码中使用

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 创建训练数据集
# step1 准备创建数据集需要的各种参数
trans_tool = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()  # 转为Tensor类型
    # torchvision.transforms.Resize((5, 5))  # 进行大小裁剪
])
# 第一个参数root表示下载的数据集需要放在哪一个文件夹里面,第二个参数tran表示是否是训练数据集,第三个参数transform表示进行变换操作,第四个参数download表示是否在线下载
tran_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=trans_tool,download=True)
# 创建测试数据集
test_dataset = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=trans_tool,download=True)
print(tran_dataset[0])  # 此时显示的是(, 6),即元组的形式,显示图片类别和标签
# step2 在tensorboard中显示

writer = SummaryWriter("logs")
for i in range(10):
    img, laber = tran_dataset[i]
    writer.add_image("CIFAR10",img,i)
writer.close()

Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)_第3张图片

你可能感兴趣的:(pytorch深度学习,深度学习,pytorch,人工智能)