torchvision.datasets.ImageFolder

这篇博客讲解了如何自定义一个 Dataset类 返回训练数据与标签,但是对于简单的图像分类任务,并不需要自己定义一个 Dataset类,可以直接调用 torchvision.datasets.ImageFolder 返回训练数据与标签。

1. 数据集组织方式

既然是调用API,那么你的数据集必然得按照API的要求去组织, torchvision.datasets.ImageFolder 要求数据集按照如下方式组织:

A generic data loader where the images are arranged in this way:

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

注意:根目录 root 下存储的是类别文件夹(如cat,dog),每个类别文件夹下存储相应类别的图像(如xxx.png)。

2. torchvision.datasets.ImageFolder 介绍

torchvision.datasets.ImageFolder_第1张图片
可以从源码看出,torchvision.datasets.ImageFolder 有 root, transform, target_transform, loader四个参数,现在依次介绍这四个参数。

  1. root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是’./data/train/’。
  2. transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
  3. target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  4. loader:表示数据集加载方式,通常默认加载方式即可。

另外,该 API 有以下成员变量:

  1. self.classes:用一个 list 保存类别名称
  2. self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  3. self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def __getitem__(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值

3. torchvision.datasets.ImageFolder 实例

先看数据集组织结构:
torchvision.datasets.ImageFolder_第2张图片即根目录为 “./data/train/”,根目录下有三个类别文件夹,即Snowdrop、LilyValley、Daffodil,每个类别文件夹下有80个训练样本。

import torchvision

dataset = torchvision.datasets.ImageFolder('./data/train/') # 不做transform
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)

torchvision.datasets.ImageFolder_第3张图片
那么如何取一个图片数据呢?

# dataset[0] 表示取第一个训练样本,即(path, class_index)。
print(dataset[0][0]) # 返回的数据是PIL Image对象

在这里插入图片描述

你可能感兴趣的:(DeepLearning)