Pytorch、MxNet 的DataLoader,DataSet设计

Pytorch 和MxNet(gluon) 的DataLoader以及DataSet设计得比较巧妙,简单记录一下,怕自己忘了。

以MxNet为例介绍,因为我现在屏幕上的代码是MxNet的代码;但是Pytorch里好像是一样的,因为我发现gluon很多东西都和pytorch一样。

主要是理解一下设计思路,以及一些以前我不常用的python小技巧。

 

DataLoader和DataSet都是可迭代的对象。

DataSet每次返回一个元素(比如一张图片),而DataLoader每次返回一个batch。

python中创建可迭代对象大概有如下几种方法(我只写了目前我记得的):

1, 创建一个Class, 这个class中必须要实现  如下两个函数。这样这个class调用起来就类似一个list.

def __len__(self):
   pass

def __getitem__(self, idx):
   pass

2, 创建一个Class, 这个class要有如下函数;这个class就变成了一个迭代器,每次迭代先调用__iter__,然后内部调用next()。

def __iter__(self):
    return self

def next(self):
    return ...

3, 创建一个Calss,这个class可以没有 next()函数,但是可以在__iter__()函数内部采用yield关键字返回。yield关键字一百度就知道了。这样class就变成了一个迭代器。

 

言归正传,MxNet中DataSet中被设计成了第1种(__getitem__)形式,因为我们经常采用索引的方式来取image之类的数据;

DataLoader被设计成了第3众(yield)的形式,因为我们在训练的时候往往也不需要知道具体是哪个batch。

 

DataSet这个类中还可以传入transformer这个callable的类,transformer可以采用compose来堆叠,很方便。

这些transformer一般包括将图片颜色抖动,裁剪为固定大小等。

 

DataLoader这个类就将DataSet的每个item形成batch,形成batch的过程可以采用了多进程的方式,加快速度,多进程的代码也在这里实现。在形成batch的时候也可以采用shuffle参数,shuffle整个Dataset的id,然后通过id来调用DataSet的__getitem__(id)函数,这也从侧面体现了DataSet被设计成__getitem__形式的迭代器的合理性。

此外,对于一般的分类任务,图片大小已经一样,形成batch很直接。但是对于Faster-RCNN等模型,其getitem出来的img大小不一样,需要对batch中的某个图片进行pad之类的操作,这个操作也是在DataLoader中完成的,MxNet中采用的是_batchify_fn这个函数来实现的。很方便。

 

简而言之,DataSet是一个可迭代的类,主要每次返回一张图片,并作一些预处理。

DataLoader是将返回的每张图片形成一个batch,并做一些额外的预处理。

emmm,,写得太乱,主要是时间太短,主要能让我自己看一下随时记忆起来。

 

你可能感兴趣的:(机器学习,python)