PyTorch如何加载数据集(自定义数据集)

pytorch加载数据集主要分为两种方法:

1、所使用数据集已被集成在pytorch内,如:CIFAR-10,CIFAR-100,MNIST等等。对于这种数据集,可以直接使用pytorch内置函数:torchvision.datasets.CIFAR100来直接加载,比较方便。例程如下:

transform_train = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True,
												  download=True, transform=transform_train)
cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

2、所使用数据集为被集成,这个类别是本文的主要讲述内容。

加载自定义数据集(即未被集成在pytorch内)

对于自定义数据集pytorch实际上是有一个函数的:torchvision.datasets.ImageFolder(),但是此函数只能加载特定形式的数据集(图片已被分类好,并放在相应文件夹下了,其标签就是其上层目录的名称,在下面会解释为什么)。有时候我们需要使用的数据集可能不是这样的,其标签可能不在目录上,而在一个独立的文件里。这时候,直接使用ImageFolder会导致训练结果与预期结果毫无关系,这就需要我们自己重新构造一个类似于DatasetFolder类(ImageFolder就继承了DatasetFolder类)的新类别来加载数据集。

分析DatasetFolder类

pytorch中DatasetFolder类的官方实现如下:

class DatasetFolder(data.Dataset):
       
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
    classes, class_to_idx = find_classes(root)
    samples = make_dataset(root, class_to_idx, extensions)
    if len(samples) == 0:
        raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                           "Supported extensions are: " + ",".join(extensions)))

    self.root = root
    self.loader = loader
    self.extensions = extensions

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples

    self.transform = transform
    self.target_transform = target_transform

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return sample, target

def __len__(self):
    return len(self.samples)

def __repr__(self):
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '    Root Location: {}\n'.format(self.root)
    tmp = '    Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '    Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str

其中三个地方需要修改:find_classes(root),make_dataset(root, class_to_idx, extensions),和sample = self.loader(path)。
find_classes(root)源码如下:

def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]   
    # 遍历dir目录下的所有子目录名称并将其存在classes中
    classes.sort()
    # 由于Python版本的不同可能需要更换为sorted(classes)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    # 创建一个字典,将类别与数字对应
    return classes, class_to_idx

该函数的参数:dir = root,调用DatasetFolder类时使用的目录,多为travel,test等。通过上面的程序以及注释,可以看出该函数时默认我们将图片分好类放在相应的文件夹下,如果需要使用其他形式的数据集,首先就要修改该函数得到正确的类别对应表,当我们使用tiny-imagenet-200时,修改为:

def find_classes(class_file):  # 此时class_file要是对应的标签对应表
    with open(class_file) as r:
        classes = map(lambda s : s.strip(), r.readlines())
    
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}

    return classes, class_to_idx

当使用ImageNet时,修改为:

def find_classes(img_label):  # 此时img_label要是对应的标签对应表
    classes = scio.loadmat(img_label)['synsets']
    class_to_idx = {}
    for i in  range(1000):
        class_to_idx[label_array[i][0][1][0]] = i
    return classes, class_to_idx

可以看到只是在读取文件的路径和形式有些许区别,在使用不同的数据集时需要我们自己去调整具体细节。

make_dataset(root, class_to_idx, extensions)源码如下:

def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)   		# 把path中包含的"~"和"~user"转换成用户目录
    for target in sorted(os.listdir(dir)):  # os.listdir()函数返回一个包含dir目录下所有文件或目录的列表
        d = os.path.join(dir, target)       # 将dir和target连接形成新的路径  d为类别目录
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):   # 判断fnames的后缀是否正确(JPEG.JPG等等)
                    path = os.path.join(root, fname)                # 得到文件的路径和文件名
                    item = (path, class_to_idx[target])		        # 依据class_to_idx得到图片类别对应的数字
                    images.append(item)

    return images

sample = self.loader(path)注意对于loader,源码中使用的是default_loader,我在实现时,其相关函数一直无法导入。因此我使用img = Image.open(data).convert(‘RGB’)来替代。

以上就是加载数据集的方法,其他方法大同小异,可以依次类推。

参考链接:
https://github.com/budlbaram/tiny_imagenet/blob/master/tiny_imagenet.py
pytorch官方源码

你可能感兴趣的:(数据集处理)