pytorch ImageFolder

参考官方文档:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/#imagefolder

简单使用

ImageFolder是一个很好用的数据加载器
所需要的文件结构如下所示,每一类的图片都在各自类的文件夹下(狗的图片在dog文件夹下,猫的图片在cat文件夹的,而这些类的文件夹都在同一个根目录下)

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

调用起来也十分简单
root:数据集根文件夹路径
transform:一个函数,对输入图片的转换,可不写
target_transform:一个函数,对输出图片的转换,可不写

import torchvision.datasets as dset
import torchvision.transforms as transforms
dataSet = dset.ImageFolder(root="root folder path", [transform, target_transform])

他有以下成员变量:

self.classes - 用一个list保存 类名
self.class_to_idx - 类名对应的 索引
self.imgs - 保存(img-path, class) tuple的list

上述都内容都可以在官方文档上找到,接下来是一些常用语法

通过下标获取图片的数据和标签

要想遍历图片十分简单,写个for循环就好了

dataSet[i][0]#返回第i张图片的PIL Image对象
dataSet[i][0]#返回第i张图片的标签(整数)

注意到那个标签是一个整数,如果我们想要获得标签的字符串,我们可以用如下语法

dataSet.classes[dataSet[i][0]]]

通过下标获取图片的名称

ImageFolder通过下标获取图片十分简单,那么如果我们相用下标获取图片的名字呢?
查了好久,官方文档上面好像也没有具体说,通过查找ImageFolder这个类的方法(例如print(dir(dataSet)),其中dataSet是一个ImageFolder对象),我找到了samples这个方法,说明如下

dataSet.samples[i][0]#返回第i张图片的名称
dataSet.samples[i][0]#返回第i张图片的标签(整数)

你可能感兴趣的:(pytorch)