【torchvision】 torchvision.datasets.ImageFolder

torchvision.datasets.ImageFolder(root, transform, target_transform, loader)

参数:

  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是 “…/input/data/”
  • transform:对图片进行预处理操作(函数),原始图片作为输入,返回一个转换后的图片。
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可

另外,该 API 有以下成员变量:

  • self.classes:用一个 list 保存类别名称
  • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  • self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def getitem(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值

举例:

数据存储结构如下

【torchvision】 torchvision.datasets.ImageFolder_第1张图片

import torchvision
import torchvision.transforms as transforms
from torch.utils import data

trans = transforms.Compose([transforms.RandomCrop(224), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('../input/data', transform=trans)
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)
print('\n')

train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
for (img, label) in train_loader:
    print(img.shape)
    print(label)
    break

【torchvision】 torchvision.datasets.ImageFolder_第2张图片

你可能感兴趣的:(torchvision,python,深度学习,人工智能)