Code Series: PyTorch torchvision.datasets.ImageFolder

Reference blog: https://blog.csdn.net/TH_NUM/article/details/80877435

ImageFolder reads a dataset from a folder on disk, treating each subfolder as one class. Source code: https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py

ImageFolder is a subclass of DatasetFolder and has the following attributes:

Attributes:

        classes (list): List of the class names.

        class_to_idx (dict): Dict with items (class_name, class_index).

        samples (list): List of (sample path, class_index) tuples

        targets (list): The class_index value for each image in the dataset

self.classes and self.class_to_idx come from find_classes; self.samples and self.targets come from make_dataset:

classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
targets = [s[1] for s in samples]

The code for find_classes() and make_dataset() is as follows:

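import os

def find_classes(dir):
    # Every immediate subdirectory of dir is treated as one class.
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # Sort so that class indices are assigned in a deterministic order.
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx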

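# Older two-argument version; the call shown earlier passes extensions and
# is_valid_file, which newer releases of folder.py accept as extra parameters.
def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    # Visit class folders in sorted order so the sample order is deterministic.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        # Walk the class folder recursively, including nested subfolders.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                # is_image_file is a helper in the same file that checks the extension.
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images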
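
To see what these two functions produce, here is a minimal standard-library sketch that rebuilds them in simplified form and runs them on a throwaway directory. It is not the torchvision implementation: the extension tuple is an assumed subset, and the class names (cat, dog) and file names are hypothetical, chosen only for illustration.

```python
import os
import tempfile

# Assumed subset of accepted image extensions; illustrative only.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")


def is_image_file(fname):
    # Case-insensitive extension check.
    return fname.lower().endswith(IMG_EXTENSIONS)


def find_classes(root):
    # One class per immediate subdirectory, sorted so indices are deterministic.
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(root, class_to_idx):
    # Collect (path, class_index) pairs for every image under each class folder.
    samples = []
    for target in sorted(class_to_idx):
        class_dir = os.path.join(root, target)
        for dirpath, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[target]))
    return samples


def demo():
    # Hypothetical layout: root/cat/{1.png, 2.jpg} and root/dog/{3.png, notes.txt}
    layout = {"dog": ["3.png", "notes.txt"], "cat": ["1.png", "2.jpg"]}
    with tempfile.TemporaryDirectory() as root:
        for cls, names in layout.items():
            os.makedirs(os.path.join(root, cls))
            for name in names:
                open(os.path.join(root, cls, name), "w").close()
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx)
        # Use paths relative to root so the result does not depend on the temp dir.
        rel = [(os.path.relpath(p, root).replace(os.sep, "/"), i) for p, i in samples]
        return classes, class_to_idx, rel


if __name__ == "__main__":
    print(demo())
```

Note that notes.txt is skipped because its extension is not in the accepted list, and that cat gets index 0 because class names are sorted before indices are assigned.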

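self.classes is the sorted list of subdirectory names under dir (that is, the list of all class names). self.class_to_idx encodes each class name as an int and stores the mapping in a dict. self.samples is a list of (image path, class index) tuples covering the images under every class folder (in ImageFolder, self.imgs = self.samples). self.targets holds the class index of each entry in self.samples, in the same order.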
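
The relationship between these four attributes can be sketched with plain Python data. The values below are hypothetical, shaped like what ImageFolder would hold for three class folders named s1, s2 and s3; no torchvision call is made.

```python
from collections import Counter

# Hypothetical values shaped like ImageFolder's attributes.
classes = ["s1", "s2", "s3"]
class_to_idx = {name: i for i, name in enumerate(classes)}
samples = [
    ("training/s1/1.png", 0),
    ("training/s1/2.png", 0),
    ("training/s2/1.png", 1),
    ("training/s3/1.png", 2),
]

# targets is the second element of every sample, in the same order.
targets = [idx for _, idx in samples]

# Invert class_to_idx to map an index back to its class name.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Count how many images each class contributes.
per_class = Counter(idx_to_class[t] for t in targets)
```

Inverting class_to_idx this way is handy when a model outputs an index and you want the folder name back, and the per-class count is a quick check for class imbalance.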

A concrete example:

import torchvision
import torchvision.datasets as dset

training_dir = r"F:\Face-Reco-master\data\faces\training"
folder_dataset = dset.ImageFolder(root=training_dir)

The contents of the training_dir directory are shown in the figure below:

[Figure 1: contents of the training_dir directory]

Each folder s## represents one class and holds the images belonging to that class. The output is as follows:

[Figure 2: output of the example]
