前几天在看代码时遇到制作数据集的一条代码:
train_datasets = datasets.ImageFolder(train_dir, transform = train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size = batch_size, shuffle = True)
我的路径是dir=r’D:\PycharmProjects\DenseNet\image\raw_img\train’ train下有五个文件夹[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’],代表五种类别,每种类别下是图片
主要是对Torchvision.datasets中的ImageFolder函数的不理解通过查该函数的源代码,慢慢理解,以下是自己的探索历程。
首先ImageFolder函数的源代码是:
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root) #classes返回[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’] class_to_idx返回{‘black’: 0, ‘break’: 1, ‘hide’: 2, ‘shade’: 3, ‘snap’: 4}
imgs = make_dataset(root, class_to_idx)#生成一个列表,该列表中保存着每一张图片和对应的标签形成的元组
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader
当然,在我的代码中root=train_dir,是一个你所放数据集的绝对路径,transforms=none,target_transform=None,说明默认情况下这两个功能不执行,那么它的作用我们一会再进行探索,loader=default_loader,我们再来看default_loader的源代码:
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
get_image_backend用于获取加载图像的包的名称,有pil和accimage两种,accimage包使用Intel IPP库。速度比 PIL快,但不支持多操作,此段代码也就是选择选择加载图像的工具,于是,我查到三个加载图像的函数:
1、pil_loader
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')
该函数旨在打开路径中的图片并转化成‘RGB’格式
2、accimage_loader
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
3、default_loader
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
我给出的代码中选择的默认方式对图片进行加载
对于classes, class_to_idx = find_classes(root)中的find_classes函数源代码是:
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
返回的两个值classes是一个列表,包括你的dir打开后的几种类别,我的是[‘black’, ‘break’, ‘hide’, ‘shade’, ‘snap’],class_to_idx返回一个字典,我的是{‘black’: 0, ‘break’: 1, ‘hide’: 2, ‘shade’: 3, ‘snap’: 4}即给你的几种类别加上标签
再来讲讲源代码中的make_dataset函数,源代码为:
def make_dataset(dir, class_to_idx): #dir=r'D:\PycharmProjects\DenseNet\image\raw_img\train'
images = []
dir = os.path.expanduser(dir) #D:\PycharmProjects\DenseNet\image\raw_img\train
for target in sorted(os.listdir(dir)): #target=black break hide shade sna五种类别
d = os.path.join(dir, target) #此时d输出包含五种类别的五个绝对路径
if not os.path.isdir(d): #此句在判断d是否是一个路径
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if is_image_file(fname):
path = os.path.join(root, fname)#该路径已经具体到每张图片
item = (path, class_to_idx[target])#将每张图片与对应的标签放到一个元组
images.append(item)#将每张图片与对应的标签形成的元组,加入到image列表中
return images
对于其中的os.path.expanduser,查询如下:
os.path.expanduser(path)
在Unix和Windows平台上,返回参数,参数中开头的或者user被替换成user的主/家目录。
在Unix上,开头的~被替换成环境变量HOME,如果它被设置的话;否则,通过内建模块pwd在密码目录查询当前用户的家目录。如果开头是 ~user ,则直接在密码目录中查询(user的家目录)。
在Windows上,将使用HOME和USERPROFILE,如果它们被设置的话;否则使用HOMEPATH和HOMEDRIVE的组合。如果开头是~user,首先按上述方式得到user路径,然后移除最后的目录部分。
把把path中包含的""和"user"转换成用户目录
如果扩展失败或者参数path不是以~打头,则直接返回参数(path)。
对于多种os.path的常用路径名操作可见
os.walk的函数声明为:
walk(top, topdown=True, οnerrοr=None, followlinks=False)
top为所要遍历的地址
topdowm为真,优先遍历top目录,否则优先遍历top的子目录,(默认为开启)
onerror需要一个callable的对象,当walk需要异常时会调用
followlinks为真,则会遍历目录下的快捷方式,实际所指的目录(默认关闭)
os.walk的返回值是一个生成器,也就是说我们需要不断地遍历它,来获得所有内容。
每次遍历都返回一个三元组(root, dirs,files)
root指的是当前正在遍历的文件夹所在地址;
dirs是一个列表,内容是该文件夹所有的目录的名字(不包含子目录)
files也是一个列表,内容是该文件夹中所有的文件(不包含子目录)
为了方便理解,尝试打印如下:
root:
D:\PycharmProjects\DenseNet_pytorch\venv\Scripts\python.exe D:/PycharmProjects/DenseNet_pytorch/test.py
D:\PycharmProjects\DenseNet\image\raw_img\train\black
D:\PycharmProjects\DenseNet\image\raw_img\train\break
D:\PycharmProjects\DenseNet\image\raw_img\train\hide
D:\PycharmProjects\DenseNet\image\raw_img\train\shade
D:\PycharmProjects\DenseNet\image\raw_img\train\snap
Process finished with exit code 0
dirs:
D:\PycharmProjects\DenseNet_pytorch\venv\Scripts\python.exe D:/PycharmProjects/DenseNet_pytorch/test.py
[]
[]
[]
[]
[]
Process finished with exit code 0
files:
生成五个列表,每个列表存放root五个地址打开后的多个图片