Pytorch加载图像数据集需要两步,首先需要使用**torchvision.datasets.ImageFolder()读取图像,然后再使用torch.utils.data.DataLoader()**加载数据集。
torchvision.datasets.ImageFolder,一个通用的数据加载器,数据集中的数据以以下方式组织。
root/dog/xxx.png
root/dog/xxy.png
root/dog/[…]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[…]/asd932_.png
ImageFolder类的定义如下:
class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
Args:
Attributes:
下面代码展示了如何用ImageFolder去加载数据,用Dataloader构建可迭代的数据装载器。
import torch
import torchvision
from torch.utils.data import Dataloader
import torchvision.transforms as transforms
data_transforms ={
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([.5, .5, .5],[.5, .5, .5])
]),
'test': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([.5, .5, .5],[.5, .5, .5])
])
}
# ImageFolder 通用的加载器
dataset = torchvision.datasets.ImageFolder(root, transform=data_trainsforms['train'])
# 构建可迭代的数据装载器
inputs = DataLoader(dataset=dataset, batch_size, shuffle=True, num_workers)
for data, label in inputs:
.......
有时候,不仅仅加载图像数据和label,还需要加载图像的路径,那么需要自定义类,扩展torchvision.datasets.ImageFolder类,代码示例如下所示。
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
# 扩展torchvision.datasets.ImageFolder,自定义数据集使其包含图像路径
def __getitem__(self, index):
# ImageFolder normally returns
original_tuple = super(ImageFolderWithPaths, self).get__item__(index)
# 图像路径
path = self.imgs[index][0]
# 构造一个新的tuple使其包括origin和图像路径
tuple_with_path = (original_tuple + (path,))
return tuple_with_path
dataset = ImageFolderWithPaths(root, transform=data_trainsforms['train'])
inputs = DataLoader(dataset=dataset, batch_size, shuffle=True, num_workers)
for datas, label, paths in inputs:
.......
torch.utils.data.Dataset, 构建可迭代的数据装载器。组合数据集和采样器,并在数据集上提供单线程或多进程迭代器。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
参数: