在pytorch中,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。
在学习Pytorch的教程时,加载数据许多时候都是直接调用torchvision.datasets
里面集成的数据集,直接在线下载,然后使用torch.utils.data.DataLoader
进行加载。
那么,我们怎么使用我们自己的数据集,然后用DataLoader
进行加载呢?
常见的两种形式的导入:
1.1、一种是整个数据集都在一个文件下,内部再另附一个label文件,说明每个文件的状态。这种存放数据的方式可能更时候在非分类问题上得到应用。下面就是我们经常使用的数据存放方式。
1.2、一种则是更适合在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。这样,我们可以从文件夹或文件名得到label。使用torchvision.datasets.imageFolder函数生成数据集。这种方式没有用过,暂时不介绍了
官方:torch.utils.data.Dataset
是一个抽象类,
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
用户想要加载自定义的数据只需要继承这个类(torch.util.data.Dataset),并且覆写__len__ 和 __getitem__两个方法, 不覆写这两个方法会直接返回错误。因此步骤如下:
len(dataset),
返回整个数据集的大小建立的自定义类如下:
# 加载数据集,自己重写DataSet类
class dataset(Dataset):
# image_dir为数据目录,label_file,为标签文件
def __init__(self, image_dir, label_file, transform=None):
super(dataset, self).__init__() # 添加对父类的初始化
self.image_dir = image_dir # 图像文件所在路径
self.labels = read(label_file) # 图像对应的标签文件, read label_file之后的结果
self.transform = transform # 数据转换操作
self.images = os.listdir(self.image_dir )#目录里的所有img文件
# 加载每一项数据
def __getitem__(self, idx):
image_index = self.images[index] #根据索引index获取该图片
img_path = os.path.join(self.image_dir, image_index) #获取索引为index的图片的路径名
labels = self.labels[index] # 对应标签
image = Image.open(img_name)
if self.transform:
image = self.transform(image)
# 返回一张照片,一个标签
return image, labels
# 数据集大小
def __len__(self):
return (len(self.images))
设置好数据类之后,我们就可以将其用torch.utils.data.DataLoader加载,并访问它。
if __name__=='__main__':
data = AnimalData(img_dir_path, label_file, transform=None)#初始化类,设置数据集所在路径以及变换
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
for i_batch,batch_data in enumerate(dataloader):
print(i_batch)#打印batch编号
print(batch_data['image'].size())#打印该batch里面图片的大小
print(batch_data['label'])#打印该batch里面图片的标签
其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
参数解释:
Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__
函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。
因为dataloader
是有batch_size
参数的,我们可以通过自定义collate_fn=myfunction
来设计数据收集的方式,意思是已经通过上面的Dataset
类中的__getitem__
函数采样了batch_size
数据,以一个包的形式传递给collate_fn
所指定的函数。
dataloader 对于数据的读取延迟主要取决于num_workers和pin_memory这两个参数。首先,我先介绍一下比较简单的 pin_memory 参数。
所谓的 pin_memory 就是锁页内存的意思。
计算机为了运行进程会先将进程和数据读到内存里。一般来说,计算机的内存都是比较小的,很难存的下太多的数据。但是,某个进程在某个时间段所需的进程和数据往往是比较少的,也就是说在某个时间点我们不需要将一个进程所需要的所有资源都放在内存里。我们可以将这些暂时用不到的数据或进程存放在硬盘一个被称为虚拟内存的地方。在进程运行的时候,我们可以不断交换内存和虚拟内存的数据以减少内存所需存储的数据。而且这些交换往往是通过某些规律预测下个时刻进程会用到的数据和代码并提前交换至内存的,这些规律的使用以及预测的准确性将会影响到进程的速度。
所谓的锁页内存就是说,我们不允许系统将某些内存里的数据交换至虚拟内存,毋庸置疑这将会提升进程的运行速度。但是也会是内存的存储占用消耗很多。
pin_memory 为 true 的时候速度的提升会有多大
Dataloader 多进程读取数据的参数是通过num_workers指定的,num_workers 为 0 的话就用主进程去读取数据,num_workers 为 N 的话就会多开 N 个进程去读取数据。这里的多进程是通过 python 的 multiprocessing module 实现的(其实 pytorch 在 multiprocessing 又加了一个 wraper 以实现shared memory)。
关于 num_workers的工作原理:
所以:
Dataloader 读数据的整个流程:
可以看出,dataloader 只会在每次迭代成功的时候才会放入新的 index 到 index_queue 里面。因为上面写了在初始化 dataloader 的时候,我们一共放了 2 x self.num_workers 个 batch 的 index 到 index_queue。读了一个 batch 才会放新的 batch,所以这所有的 worker 进程最多缓存的 batch 数量就是 2 x self.num_workers 个。
以上流程的如果想看代码可以参考:Pytorch Dataloader 学习笔记 · 大专栏