【pytorch源码赏析】Dataset in pytorch

1. 源码概览

pytorch是众多dl工具中,比较python风格化的一种,另一个完全python化的dl工具是chainer,它的构建语言中只有python,甚至cuda也是从python端调用的。python风格化的好处是,使用了很多python的语言特性,让代码更加简洁,更高效。《python高级编程》的第2、3章,描述了部分python的高级语言特性,比如:列表推导,迭代器和生成器,装饰器等。这些trick让代码更加python化,可读性更强,也更健壮。

pytorch的数据集部分,从源码可以看出,提供了2个主要的类:Dataset,DataLoader。

Dataset为抽象类,定义了两个行为:__getitem__和__len__。也就是任何数据集,都可以len(dataset)获得样本的数量,dataset[i]获得其中第i个样本。派生了两个类:TensorDataset,当x和y是pytorch的tensor时,可以方便地导入;另一个ConcatDataset,用于合并多个数据集(对于实际应用特别有用)。

DataLoader是更核心的类,用户用它来获得每次batch的训练数据。

dataloader.py中有2个类,DataLoader和DataLoaderIter。

DataLoader提供如下功能:
1. 保存了dataset
2. 具有sample行为
3. 提供单线程/多线程来获取数据集中的数据(代码主要实现的功能)

DataLoader有2个行为:__iter__和__len__。而__iter__这个迭代器,代码如下:

def __iter__(self):
    return DataLoaderIter(self)

返回的正是DataLoaderIter。DataLoaderIter的功能是,根据sample指定的方法,获取训练样本。sample方法有SequentialSampler, RandomSampler, BatchSampler这三种,其实是两种:SequentialSampler和RandomSampler。如果指定了shuffle,则是随机采样,否则是序列采样,然后都会使用BatchSample。

DataLoaderIter具有3个行为:__iter__,__len__和__next__。每次使用next(dataLoaderIter)来获得一个batch。

__iter__总是和__next__一起使用,__iter__表明这个类是可以迭代的,__next__表明每次迭代的具体行为,一个例子如下:

class Testing:
    def __init__(self,a,b):
        self.a = a
        self.b = b
    def __iter__ (self):
        print('itering')
        return self
    def next(self):
        print('nexting')
        if self.a <= self.b:
            self.a += 1
            return self.a-1
        else:
            raise StopIteration

myObj = Testing(1,5)           
for i in myObj:
    print i
itering
nexting
1
nexting
2
nexting
3
nexting
4
nexting
5
nexting

2. 使用方法

使用pytorch提供的方法操作数据集,一般分两步:
1. 继承Dataset,实现__getitem____len__方法。
2. 实例化DataLoader,一般需要指定自己的collate_fn方法。

而这正是代码优美的地方,把“读取数据集”这个任务完美地解耦和,用户只需要针对不同的数据集派生Dataset类,实现2个方法。DataLoader负责了如何读取训练样本的行为,只需要实例化即可,还可以通过设置collate_fn定制化自己的具体读取行为。

你可能感兴趣的:(pytorch)