ImageFolder: Merging Datasets

ImageFolder reads images and their class labels from a directory tree and builds a map-style dataset that a DataLoader can iterate over. ImageFolder expects the following layout:

root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
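
For reference, after scanning such a root, ImageFolder exposes the four attributes that will have to be merged later. A minimal sketch (the "root" path refers to the hypothetical layout above, so the printed values are only illustrative):

from torchvision.datasets import ImageFolder

dataset = ImageFolder(root="root")   # the hypothetical directory sketched above
print(dataset.classes)               # e.g. ['cat', 'dog'] -- class names sorted alphabetically
print(dataset.class_to_idx)          # e.g. {'cat': 0, 'dog': 1}
print(dataset.samples[0])            # e.g. ('root/cat/123.png', 0) -- (path, class_index) tuples
print(dataset.targets[:3])           # the class_index of every image, aligned with samples
print(len(dataset))                  # total number of images found under root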

In practice, however, the data often ends up split across several root directories: the original storage format may dictate it, some operating systems (e.g. CentOS) limit the maximum number of files in a single directory (HDF5 is still the GOAT here), and so on. For example:

sub_root1/dog/xxx.png
sub_root1/dog/xxy.png
sub_root1/dog/[...]/xxz.png

sub_root2/cat/123.png
sub_root2/cat/nsdf3.png
sub_root2/cat/[...]/asd932_.png

In that case each root has to be read separately and the resulting ImageFolder datasets merged. The helper below does the merging:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

def merge_datasets(dataset, sub_dataset):
    '''
    Attributes that have to be merged:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample_path, class_index) tuples.
        targets (list): The class_index value for each image in the dataset.
    '''
    # Merge classes: union of both class lists, sorted alphabetically
    merged_classes = sorted(set(dataset.classes) | set(sub_dataset.classes))
    merged_class_to_idx = {name: idx for idx, name in enumerate(merged_classes)}

    # Remap every sample to the merged indices; without this, the indices of the
    # two datasets clash as soon as their class sets (and thus orderings) differ
    def remap(ds):
        return [(path, merged_class_to_idx[ds.classes[idx]]) for path, idx in ds.samples]

    # Merge samples and targets, then install the merged class bookkeeping
    dataset.samples = remap(dataset) + remap(sub_dataset)
    dataset.targets = [idx for _, idx in dataset.samples]
    dataset.classes = merged_classes
    dataset.class_to_idx = merged_class_to_idx
    dataset.imgs = dataset.samples  # ImageFolder keeps `imgs` as an alias of `samples`

To verify that this works:


transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

paths = [
    "E:\\datasets\\office31\\amazon\\asub1",
    "E:\\datasets\\office31\\amazon\\asub2",
    "E:\\datasets\\office31\\amazon\\asub3",
]

dataset = ImageFolder(root=paths[0], transform=transform)
for path in paths[1:]:
    # No transform needed here: only samples/targets are merged, and the merged
    # dataset loads every image with the first dataset's transform
    sub_dataset = ImageFolder(path)
    merge_datasets(dataset, sub_dataset)

dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=False)
for data in dataloader:
    images, targets = data
    print(targets)
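
As an extra sanity check (a minimal sketch, reusing the paths list above), the merged length can be compared against the individual roots, and targets can be checked against samples:

expected = sum(len(ImageFolder(p)) for p in paths)
assert len(dataset) == expected, "merged dataset lost or duplicated samples"
assert dataset.targets == [idx for _, idx in dataset.samples], "targets out of sync with samples"
print(f"merged {len(dataset)} images across {len(dataset.classes)} classes")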
