PyTorch: Loading Image Data (ImageFolder and DataLoader)

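Loading an image dataset in PyTorch takes two steps. First, read the images with **torchvision.datasets.ImageFolder()**. Then wrap the dataset in **torch.utils.data.DataLoader()** to iterate over it in batches.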

ImageFolder

torchvision.datasets.ImageFolder is a generic data loader for images arranged as follows, with one subdirectory per class:

root/dog/xxx.png
root/dog/xxy.png
root/dog/[…]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[…]/asd932_.png

The ImageFolder class is defined as follows:

class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)

Args:

  • root (string): Root directory path.
  • transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g. transforms.RandomCrop
  • target_transform (callable, optional): A function/transform that takes in the target and transforms it.
  • loader (callable, optional): A function to load an image given its path.
  • is_valid_file (callable, optional): A function that takes the path of an image file and checks whether it is a valid file (used to check for corrupt files).
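The function below sketches what an is_valid_file callable might look like. It is hypothetical: is_valid_png and PNG_SIGNATURE are not part of torchvision. It accepts only .png files whose first eight bytes match the PNG signature.

In torchvision's source, ImageFolder applies its default image-extension filter only when is_valid_file is None. A custom callable therefore has to check the extension itself, as this one does.

```python
# Hypothetical validity check to pass as ImageFolder(..., is_valid_file=is_valid_png).
# ImageFolder skips its built-in extension filter when is_valid_file is given,
# so the extension check is done here as well.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

def is_valid_png(path: str) -> bool:
    """Accept only files ending in .png whose first bytes are the PNG signature."""
    if not path.lower().endswith(".png"):
        return False
    try:
        with open(path, "rb") as f:
            return f.read(len(PNG_SIGNATURE)) == PNG_SIGNATURE
    except OSError:
        # Missing or unreadable files are treated as invalid
        return False
```

Checking only the header will not catch a file that is truncated partway through. Fully decoding each image would catch that, at a much higher cost per file.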

Attributes:

  • classes (list): List of the class names sorted alphabetically.
  • class_to_idx (dict): Dict with items (class_name, class_index).
  • imgs (list): List of (image path, class_index) tuples
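The sketch below shows how ImageFolder derives these three attributes from the directory layout shown earlier. It is a hypothetical stdlib re-implementation: scan_image_folder and the IMG_EXTENSIONS tuple are modeled on torchvision's folder.py, but this is not its exact code.

- Class names are the sorted subdirectory names.
- Indices follow that sorted order.
- Images in nested subfolders still get the label of their top-level class folder.

```python
import os

# Assumed extension list, modeled on torchvision's default image extensions.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

def scan_image_folder(root):
    """Hypothetical helper returning (classes, class_to_idx, imgs) like ImageFolder's attributes."""
    # classes: immediate subdirectory names, sorted alphabetically
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    # class_to_idx: class name -> integer label, following the sorted order
    class_to_idx = {name: i for i, name in enumerate(classes)}
    # imgs: (image path, class_index) for every image under each class folder,
    # including nested subfolders such as root/dog/[...]/xxz.png
    imgs = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(fnames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    imgs.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, imgs
```

For the cat/dog layout above, this yields classes == ['cat', 'dog'] and class_to_idx == {'cat': 0, 'dog': 1}. The labels are fixed by alphabetical order, not by the order in which the folders were created.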

The code below shows how to load data with ImageFolder and build an iterable data loader with DataLoader.

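import torch
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Hypothetical path: a folder laid out as root/<class_name>/<image>
root = "path/to/train"

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ])
}
# ImageFolder: the generic image-folder loader
dataset = torchvision.datasets.ImageFolder(root, transform=data_transforms['train'])
# Build an iterable data loader (batch_size and num_workers are example values;
# on Windows/macOS, run under `if __name__ == '__main__':` when num_workers > 0)
inputs = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=4)
for data, label in inputs:
    # data: float tensor [batch, 3, 224, 224]; label: int64 tensor [batch]
    print(data.shape, label.shape)
    break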
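Both pipelines end with the same two steps:

- transforms.ToTensor() converts a PIL image with values in [0, 255] into a float tensor in [0, 1].
- transforms.Normalize(mean, std) then standardizes each channel c.

With mean = std = 0.5 for all three channels, as in the code above, the [0, 1] range is mapped onto [-1, 1]:

$$
\text{output}_c = \frac{\text{input}_c - \text{mean}_c}{\text{std}_c},
\qquad \frac{0 - 0.5}{0.5} = -1,
\qquad \frac{1 - 0.5}{0.5} = 1
$$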

Sometimes you need the image file path as well as the image data and label. In that case, define a custom class that extends torchvision.datasets.ImageFolder, for example:

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    # Extends torchvision.datasets.ImageFolder so each sample also carries its image path
    def __getitem__(self, index):
        # ImageFolder normally returns (image, class_index)
        original_tuple = super().__getitem__(index)
        # Path of the image file
        path = self.imgs[index][0]
        # Build a new tuple: the original (image, class_index) plus the path
        tuple_with_path = original_tuple + (path,)
        return tuple_with_path

dataset = ImageFolderWithPaths(root, transform=data_transforms['train'])
inputs = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=4)
for data, labels, paths in inputs:
    # paths holds one file-path string per sample in the batch
    print(data.shape, labels.shape, paths[:2])
    break
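The loop above can unpack three values because of how the default collate_fn works. The DataLoader receives a list of per-sample tuples and regroups them field by field.

The stdlib sketch below (transpose_batch is a hypothetical helper) shows only that regrouping. The real default_collate additionally:

- stacks the image tensors,
- converts the integer labels into a tensor,
- passes strings such as paths through as a plain sequence.

```python
def transpose_batch(samples):
    """Hypothetical helper: regroup [(image, label, path), ...] into [images, labels, paths]."""
    # zip(*samples) yields one tuple per field across all samples in the batch
    return [list(field) for field in zip(*samples)]
```

For instance, two samples ("img0", 1, "root/dog/a.png") and ("img1", 0, "root/cat/b.png") become three per-field lists, which the loop unpacks as data, labels and paths.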

DataLoader

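torch.utils.data.DataLoader builds an iterable data loader. It combines a dataset and a sampler and provides a single-process or multi-process iterator over the dataset.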

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

Parameters:

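  • dataset (Dataset) – the dataset from which to load the data.
  • batch_size (int, optional) – how many samples to load per batch (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy for drawing samples from the dataset. If specified, shuffle must be left as False; the two options are mutually exclusive.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means the data is loaded in the main process (default: 0).
  • collate_fn (callable, optional) – merges a list of samples into a mini-batch.
  • pin_memory (bool, optional) – if True, the data loader copies Tensors into pinned (page-locked) memory before returning them, which speeds up host-to-GPU transfers (default: False).
  • drop_last (bool, optional) – set to True to drop the last incomplete batch when the dataset size is not divisible by the batch size. If False and the size is not divisible, the last batch will be smaller (default: False).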
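The interplay of batch_size, shuffle and drop_last can be sketched without PyTorch. make_batches below is a hypothetical helper, not PyTorch's sampler code. It splits n sample indices into batches, optionally shuffling first and optionally dropping the incomplete final batch.

```python
import random

def make_batches(n, batch_size=1, shuffle=False, drop_last=False, seed=None):
    """Hypothetical helper: split indices 0..n-1 into batches like DataLoader's sampling."""
    indices = list(range(n))
    if shuffle:
        # Shuffle the index order (DataLoader does this anew each epoch when shuffle=True)
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # drop_last=True discards the final incomplete batch
        batches.pop()
    return batches
```

With 10 samples and batch_size=4, this gives three batches, the last holding two samples. With drop_last=True it gives two. Likewise, len(loader) for a map-style dataset is ceil(n / batch_size), or n // batch_size with drop_last=True.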

