Pytorch源码(一)—— 简析torchvision的ImageFolder

一、所使用的函数介绍

1. find_classes

def find_classes(dir):
    # 得到指定目录下的所有文件,并将其名字和指定目录的路径合并
    # 以数组的形式存在classes中
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    # 使用sort()进行简单的排序
    classes.sort()
    # 将其保存的路径排序后简单地映射到 0 ~ [ len(classes)-1] 的数字上
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    # 返回存放路径的数组和存放其映射后的序号的数组
    return classes, class_to_idx

2. has_file_allowed_extension

def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # 将文件的名变成小写
    filename_lower = filename.lower()

    # endswith() 方法用于判断字符串是否以指定后缀结尾
    # 如果以指定后缀结尾返回True,否则返回False
    return any(filename_lower.endswith(ext) for ext in extensions)

3. make_dataset

def make_dataset(dir, class_to_idx, extensions):
    images = []
    # expanduser把path中包含的"~"和"~user"转换成用户目录
    # 主要还是在Linux之类的系统中使用,在不包含"~"和"~user"时
    # dir不变
    dir = os.path.expanduser(dir)
    # 排序后按顺序通过for循环dir路径下的所有文件名
    for target in sorted(os.listdir(dir)):
        # 将路径拼合
        d = os.path.join(dir, target)
        # 如果拼接后不是文件目录,则跳出这次循环
        if not os.path.isdir(d):
            continue
        # os.walk(d) 返回的fnames是当前d目录下所有的文件名
        # 注意:第一个for其实就只循环一次,返回的fnames 是一个数组
        for root, _, fnames in sorted(os.walk(d)):
            # 循环每一个文件名
            for fname in sorted(fnames):
                # 文件的后缀名是否符合给定
                if has_file_allowed_extension(fname, extensions):
                    # 组合路径
                    path = os.path.join(root, fname)
                    # 将组合后的路径和该文件位于哪一个序号的文件夹下的序号
                    # 组成元祖
                    item = (path, class_to_idx[target])
                    # 将其存入数组中
                    images.append(item)

    return images


注意:

下面三个函数都是加载图像的函数,用于ImageFolder类中

4. pil_loader

def pil_loader(path):
# open path as file to avoid ResourceWarning
#  (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

5. accimage_loader

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

6. default_loader

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

用于定义读入文件的格式

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


二、关键类

1. DatasetFolder

class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # 得到root下的文件路径数组和文件映射后的序号数组
        classes, class_to_idx = find_classes(root)
        # 得到所有文件的路径和其所在文件夹的序号所组成的集合的数组
        samples = make_dataset(root, class_to_idx, extensions)
        # 如果在指定的路径上没有得到文件那么便抛出一个异常
        if len(samples) == 0:
            raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions)))
        
        self.root = root
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

     # 如果在类中定义了__getitem__()方法,那么其实例对象(假设为P)
     # 就可以这样P[key]取值。
     # 当实例对象做P[key]运算时,就会调用类中的__getitem__(key)方法
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        # 得到文件路径和其所属的文件夹的序号
        path, target = self.samples[index]
        # 加载数据
        sample = self.loader(path)
        # 是否对读入的数据进行处理
        # 主要包括转化张量和一些数据增强的方法
        if self.transform is not None:
            sample = self.transform(sample)
        # 是否对所属的文件夹的序号进行处理
        # 由于torchvision是按文件夹来得到数据的标签值
        # 所以这里的序号其实就是分类的标签
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    # 数据数量
    def __len__(self):
        return len(self.samples)

    # 生成报告,报告一些必要信息
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


2. ImageFolder

# 继承自DatasetFolder,只是在DatasetFolder基础上将加载的文件格式
# 数据加载函数给定义了
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """
    # default_loader 数据加载函数
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # root 路径
        # IMG_EXTENSIONS 定义了读取的文件类型
        # transform和target_transform 数据和标签处理
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                          transform=transform,
                                          target_transform=target_transform)
        self.imgs = self.samples

你可能感兴趣的:(Pytorch源码(一)—— 简析torchvision的ImageFolder)