Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

 

Pytorch的数据读取主要包含三个类:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.

一. torch.utils.data.Dataset

是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  1. __getitem__()
  2. __len__()

第一个最为重要, 即每次怎么读数据. 以图片为例:

def __getitem__(self, index):
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)

        return img, label

值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms 里面, 本文中不多介绍, 常用的有Resize , RandomCrop , Normalize , ToTensor (这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()里面用PIL来读图片, 而不是用skimage.io).

 第二个比较简单, 就是返回整个数据集的长度:

def __len__(self):
        return len(self.data)

二. torch.utils.data.DataLoader

类定义为:

class torch.utils.data.DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None,
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=, 
    pin_memory=False, 
    drop_last=False
)

可以看到, 主要参数有这么几个:

  1. dataset : 即上面自定义的dataset.
  2. collate_fn: 这个函数用来打包batch
  3. num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据

这个类其实就是下面将要讲的DataLoaderIter的一个框架, 一共干了两件事:

  1. 定义了一堆成员变量, 到时候赋给DataLoaderIter,
  2. 然后有一个__iter__() 函数, 把自己 "装进" DataLoaderIter 里面.
def __iter__(self):
        return DataLoaderIter(self)

三. torch.utils.data.dataloader.DataLoaderIter

上面提到, DataLoader就是DataLoaderIter的一个框架, 用来传给DataLoaderIter 一堆参数, 并把自己装进DataLoaderIter 里。其实到这里就可以满足大多数训练的需求了, 比如

class CustomDataset(Dataset):
   # 自定义自己的dataset

dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)

for data in dataloader:
   # training...

在for 循环里, 总共有三点操作:

  1. 调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter
  2. 反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等.
  3. 当数据读完后, __next__()抛出一个StopIteration异常, for循环结束, dataloader 失效.

四. 又一层封装

其实上面三个类已经可以搞定了, 仅供参考

class DataProvider:
    def __init__(self, batch_size, is_cuda):
        self.batch_size = batch_size
        self.dataset = Dataset_triple(self.batch_size,
                                      transform_=transforms.Compose(
                                     [transforms.Scale([224, 224]),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])]),
                                      )
        self.is_cuda = is_cuda  # 是否将batch放到gpu上
        self.dataiter = None
        self.iteration = 0  # 当前epoch的batch数
        self.epoch = 0  # 统计训练了多少个epoch

    def build(self):
        dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
        self.dataiter = DataLoaderIter(dataloader)

    def next(self):
        if self.dataiter is None:
            self.build()
        try:
            batch = self.dataiter.next()
            self.iteration += 1

            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch

        except StopIteration:  # 一个epoch结束后reload
            self.epoch += 1
            self.build()
            self.iteration = 1  # reset and return the 1st batch

            batch = self.dataiter.next()
            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch

感谢以下链接提供的参考:

https://zhuanlan.zhihu.com/p/30934236

https://blog.csdn.net/u014380165/article/details/79167753

你可能感兴趣的:(pytorch学习笔记,pytorch,学习笔记)