Torchvision.datasets中的ImageFolder函数详解

前几天在看代码时遇到制作数据集的一条代码:

train_datasets = datasets.ImageFolder(train_dir, transform = train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size = batch_size, shuffle = True)  

我的路径是dir=r’D:\PycharmProjects\DenseNet\image\raw_img\train’ train下有五个文件夹[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’],代表五种类别,每种类别下是图片
主要是对Torchvision.datasets中的ImageFolder函数的不理解通过查该函数的源代码,慢慢理解,以下是自己的探索历程。
首先ImageFolder函数的源代码是:

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root) #classes返回[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’]  class_to_idx返回{‘black’: 0, ‘break’: 1, ‘hide’: 2, ‘shade’: 3, ‘snap’: 4}
        imgs = make_dataset(root, class_to_idx)#生成一个列表,该列表中保存着每一张图片和对应的标签形成的元组
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

当然,在我的代码中root=train_dir,是一个你所放数据集的绝对路径,transforms=none,target_transform=None,说明默认情况下这两个功能不执行,那么它的作用我们一会再进行探索,loader=default_loader,我们再来看default_loader的源代码:

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

get_image_backend用于获取加载图像的包的名称,有pil和accimage两种,accimage包使用Intel IPP库。速度比 PIL快,但不支持多操作,此段代码也就是选择选择加载图像的工具,于是,我查到三个加载图像的函数:
1、pil_loader

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

该函数旨在打开路径中的图片并转化成‘RGB’格式
2、accimage_loader

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

3、default_loader

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

我给出的代码中选择的默认方式对图片进行加载
对于classes, class_to_idx = find_classes(root)中的find_classes函数源代码是:

def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

返回的两个值classes是一个列表,包括你的dir打开后的几种类别,我的是[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’],class_to_idx返回一个字典,我的是{‘black’: 0, ‘break’: 1, ‘hide’: 2, ‘shade’: 3, ‘snap’: 4}即给你的几种类别加上标签
再来讲讲源代码中的make_dataset函数,源代码为:

def make_dataset(dir, class_to_idx):  #dir=r'D:\PycharmProjects\DenseNet\image\raw_img\train'
    images = []
    dir = os.path.expanduser(dir)  #D:\PycharmProjects\DenseNet\image\raw_img\train
    for target in sorted(os.listdir(dir)):  #target=black break hide shade sna五种类别
        d = os.path.join(dir, target)    #此时d输出包含五种类别的五个绝对路径
        if not os.path.isdir(d):   #此句在判断d是否是一个路径
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)#该路径已经具体到每张图片
                    item = (path, class_to_idx[target])#将每张图片与对应的标签放到一个元组
                    images.append(item)#将每张图片与对应的标签形成的元组,加入到image列表中

    return images

对于其中的os.path.expanduser,查询如下:
os.path.expanduser(path)
在Unix和Windows平台上,返回参数,参数中开头的或者user被替换成user的主/家目录。

在Unix上,开头的~被替换成环境变量HOME,如果它被设置的话;否则,通过内建模块pwd在密码目录查询当前用户的家目录。如果开头是 ~user ,则直接在密码目录中查询(user的家目录)。

在Windows上,将使用HOME和USERPROFILE,如果它们被设置的话;否则使用HOMEPATH和HOMEDRIVE的组合。如果开头是~user,首先按上述方式得到user路径,然后移除最后的目录部分。
把把path中包含的""和"user"转换成用户目录
如果扩展失败或者参数path不是以~打头,则直接返回参数(path)。
对于多种os.path的常用路径名操作可见

os.walk的函数声明为:
walk(top, topdown=True, οnerrοr=None, followlinks=False)
top为所要遍历的地址
topdowm为真,优先遍历top目录,否则优先遍历top的子目录,(默认为开启)
onerror需要一个callable的对象,当walk需要异常时会调用
followlinks为真,则会遍历目录下的快捷方式,实际所指的目录(默认关闭)
os.walk的返回值是一个生成器,也就是说我们需要不断地遍历它,来获得所有内容。
每次遍历都返回一个三元组(root, dirs,files)
root指的是当前正在遍历的文件夹所在地址;
dirs是一个列表,内容是该文件夹所有的目录的名字(不包含子目录)
files也是一个列表,内容是该文件夹中所有的文件(不包含子目录)
为了方便理解,尝试打印如下:
root:

D:\PycharmProjects\DenseNet_pytorch\venv\Scripts\python.exe D:/PycharmProjects/DenseNet_pytorch/test.py
D:\PycharmProjects\DenseNet\image\raw_img\train\black
D:\PycharmProjects\DenseNet\image\raw_img\train\break
D:\PycharmProjects\DenseNet\image\raw_img\train\hide
D:\PycharmProjects\DenseNet\image\raw_img\train\shade
D:\PycharmProjects\DenseNet\image\raw_img\train\snap

Process finished with exit code 0

dirs:

D:\PycharmProjects\DenseNet_pytorch\venv\Scripts\python.exe D:/PycharmProjects/DenseNet_pytorch/test.py
[]
[]
[]
[]
[]

Process finished with exit code 0

files:
生成五个列表,每个列表存放root五个地址打开后的多个图片

你可能感兴趣的:(p'y'torch,python,pytorch)