PyTorch 源码解读：torchvision 中的 folder.py（ImageFolder）

在使用 PyTorch 构建数据集时，常会用到 torchvision 提供的 ImageFolder 类来加载按文件夹组织的图片数据；了解其源码有助于快速开发。

import torch.utils.data as data
#PIL: Python Image Library缩写,图像处理模块
#     Image,ImageFont,ImageDraw,ImageFilter
from PIL import Image    
import os
import os.path

# Recognized image file extensions (image formats). Each base extension is
# listed twice: first in lower case, then in upper case.
IMG_EXTENSIONS = [
    spelling
    for base_ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for spelling in (base_ext, base_ext.upper())
]

# Decide whether a file name looks like an image file.
def is_image_file(filename, extensions=None):
    """Return True if *filename* ends with one of the accepted extensions.

    The comparison is case-insensitive, so mixed-case names such as
    ``photo.Jpg`` are accepted as well as ``photo.jpg`` / ``photo.JPG``.

    Args:
        filename (str): File name or path to check.
        extensions (iterable of str, optional): Extensions to accept
            (including the leading dot). Defaults to ``IMG_EXTENSIONS``.

    Returns:
        bool: True if the name ends with an accepted extension.
    """
    if extensions is None:
        extensions = IMG_EXTENSIONS
    # str.endswith accepts a tuple, so a single call tests every option;
    # lower-casing both sides makes the check case-insensitive.
    return filename.lower().endswith(tuple(ext.lower() for ext in extensions))

# Example: for a root containing sub-directories '0' .. '9' this returns
#   classes      = ['0', '1', ..., '9']
#   class_to_idx = {'0': 0, '1': 1, ..., '9': 9}
def find_classes(dir):
    """Discover the class sub-directories of a dataset root.

    Every immediate sub-directory of *dir* is one class; plain files directly
    under *dir* are ignored. Class names are sorted, so the name -> index
    assignment does not depend on the order ``os.listdir`` happens to return.

    Args:
        dir (str): Dataset root directory. A leading ``~`` / ``~user`` is
            expanded to the user's home directory (as ``make_dataset`` does).

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names and ``class_to_idx`` maps each name to
        its position in that list.
    """
    # Expand "~" so a root like "~/data" works here, not only in make_dataset.
    root = os.path.expanduser(dir)
    classes = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx


# Collect (path, class_index) pairs for every image file under the class folders.
def make_dataset(dir, class_to_idx, is_valid_file=None):
    """Index every accepted file under the class folders of *dir*.

    Only folders named in *class_to_idx* are scanned, each recursively via
    ``os.walk``. Sub-directories of *dir* that are not in the mapping are
    skipped (rather than raising ``KeyError``), as are mapping entries whose
    folder does not exist. Class names, walked directories and file names
    are all visited in sorted order, so the result is deterministic.

    Args:
        dir (str): Dataset root; a leading ``~`` / ``~user`` is expanded.
        class_to_idx (dict): Maps class-folder name to integer class index,
            e.g. the second value returned by ``find_classes``.
        is_valid_file (callable, optional): Predicate called with each bare
            file name; files for which it returns False are ignored.
            Defaults to ``is_image_file``.

    Returns:
        list: ``(file_path, class_index)`` tuples.
    """
    if is_valid_file is None:
        is_valid_file = is_image_file
    images = []
    # Expand "~" / "~user" to the user's home directory.
    dir = os.path.expanduser(dir)
    # Drive the scan from the mapping so a partial/custom class_to_idx works.
    for target in sorted(class_to_idx):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue
        class_index = class_to_idx[target]
        # os.walk yields (dirpath, dirnames, filenames); sorting the tuples
        # orders the walk by directory path.
        for root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    images.append((os.path.join(root, fname), class_index))
    return images

# Load the image at a path with PIL and convert it to RGB mode.
def pil_loader(path):
    """Open the image file at *path* with PIL and return an RGB copy.

    The file is opened by us in binary mode and handed to PIL as a file
    object, so it is closed deterministically when the ``with`` block exits
    (avoids ResourceWarning, see
    https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        # convert() builds a new image in the requested mode ('RGB').
        rgb_image = picture.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load *path* with the optional ``accimage`` backend.

    ``accimage`` is imported inside the function, so it is only required
    when this loader is actually called. If accimage raises ``IOError``
    (potentially a decoding problem), fall back to ``pil_loader``.
    """
    import accimage
    try:
        loaded = accimage.Image(path)
    except IOError:
        loaded = pil_loader(path)
    return loaded


def default_loader(path):
    """Load *path* using the loader that matches torchvision's image backend.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``; otherwise uses ``pil_loader``.
    """
    from torchvision import get_image_backend
    backend_is_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if backend_is_accimage else pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path. A leading ``~`` / ``~user`` is
            expanded before the directory is scanned.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        root: The ``root`` argument exactly as passed in.
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Raises:
        RuntimeError: If no image files are found under ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # Expand "~" once so class discovery and file indexing both scan the
        # real directory; self.root keeps the caller's original value.
        scan_root = os.path.expanduser(root)
        # Class names come from the sub-directory names of root.
        classes, class_to_idx = find_classes(scan_root)
        # (image path, class index) for every image file under those folders.
        imgs = make_dataset(scan_root, class_to_idx)
        if not imgs:
            # format() rather than "+" so non-str roots (e.g. pathlib.Path)
            # still produce this RuntimeError instead of a TypeError.
            raise RuntimeError(
                "Found 0 images in subfolders of: {}\n"
                "Supported image extensions are: {}".format(
                    root, ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class,
            after ``transform`` / ``target_transform`` have been applied if set.
        """
        # self.imgs stores only the file path; the image itself is read
        # lazily, one sample per call, via the configured loader.
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of indexed (image path, class index) samples."""
        return len(self.imgs)

你可能感兴趣的:(pytorch源码,caffe,pytorch,源码共读,项目分享)