PyTorch自定义数据加载:深究Dataset与DataLoader类

PyTorch自定义数据加载:深究Dataset与DataLoader类_第1张图片

PyTorch自定义数据加载:深究Dataset与DataLoader类

  • 写在文章开头
  • 数据加载步骤
    • 创建Dataset对象
    • 创建DataLoader对象
    • 循环获取数据用以训练

写在文章开头

本篇文章是我个人的学习笔记。在我看来,可以说PyTorch几乎占据了深度学习、强化学习科研领域,无论我查看什么样的文献人手皆Torch,尽管TensorFlow目前用起来比较称心,但也不得不迫使我转型PyTorch。希望自己能尽快掌握这个框架吧。

今天要深入学习的是torch.utils.data.DataLoadertorch.utils.data.Dataset类。

数据加载步骤

在PyTorch中,要加载自己的数据集需要执行以下的三个步骤:

  • 创建一个Dataset对象;
  • 创建一个DataLoader对象;
  • DataLoader对象进行循环,将数据和标签取出,用以训练模型;

创建Dataset对象

对于Dataset类的创建,我们首先需要继承torch.utils.data.Dataset,然后重写三个函数:

  • __init__:加载初始化数据,可以通过self为我们的数据类添加多个属性。比如:
  • __len__:该函数返回我们的数据总量,方便我们通过索引去访问数据集中的某一数据项,该函数可以配合__getitem__函数使用,并且通过Python内置函数len调用。
  • __getitem__:该函数返回一条训练数据,并将其转换成torch.Tensor类型数据。

接下来我们就创建一个Dataset对象:

class MyDataset(Dataset):

    def __init__(self, transform = None):
        self.transform = transform

    def __len__(self):
        return len(os.listdir('E:/Pycharm Project/ImgSegBase/CCNet/KITTI/dataset/Train/base'))

    def __getitem__(self, idx):

        image_name = os.listdir('E:/Pycharm Project/ImgSegBase/CCNet/KITTI/dataset/Train/base')[idx]
        image = cv2.imread('E:/Pycharm Project/ImgSegBase/CCNet/KITTI/dataset/Train/base/' + image_name, 1)
        mask = cv2.imread('E:/Pycharm Project/ImgSegBase/CCNet/KITTI/dataset/Train/semantic_RGB/'+ image_name, 0)
        label = torch.tensor(mask, dtype = torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label

创建DataLoader对象

PyTorch模型的训练基本都会使用DataLoader对象。对于torch.utils.data.DataLoader,其对应的参数如下:

torch.utils.data.DataLoader(dataset,
						    batch_size=1,
						    shuffle=False,
						    sampler=None,
						    batch_sampler=None,
						    num_workers=0,
						    collate_fn=None,
						    pin_memory=False,
						    drop_last=False,
						    timeout=0,
						    worker_init_fn=None,
						    prefetch_factor=2,
						    persistent_workers=False)

几个重要参数的具体含义如下:

  • dataset>> PyTorch已有的数据读取接口 1 ^1 1(比如torchvision.datasets.ImageFolder)或者自定义数据接口的输出 2 ^2 2并且该输出要么是torch.utils.data.Dataset类的对象、要么是继承自torch.utils.data.Dataset类的自定义类的对象。
  • batch_size>> 训练数据的小批量大小。
  • shuffle>> 是否要打乱数据,一般在训练数据中会采用。
  • collate_fn>> 对于这一函数,官网的解释是:merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset。该参数需要传入一个函数映射 f f f。我认为参数的大意是,我们首先从DataLoader中获取了一个批量的数据,随后我们用lamda函数或def函数来对该批量进行再处理,从而能够定制我们的训练数据。
  • drop_last>> set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • sampler>> defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.该参数的输入是一个迭代器,用以定义从数据集中采样的方式。具体的迭代器可参考下图
    PyTorch自定义数据加载:深究Dataset与DataLoader类_第2张图片
    根据源码,由于shuffle默认值为False,所以此时的samplerSequentialSampler,也就是按顺序取样;shuffleTrue时,samplerRandomSampler, 也就是随机取样。
  • batch_sampler>> like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last。如果传入了batch_size参数,该参数则不用初始化;反之,则需要传入一个torch.utils.data.BatchSampler类的实例。相关的源码如下:
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
self.sampler = sampler
self.batch_sampler = batch_sampler
# batch generation process of CLASS BatchSampler
def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx) # return indices
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
  • num_workers>> 数据导入的线程数量。如果设置为0,那么数据导入将只使用当前的主线程。
  • pin_memory>> If True, the data loader will copy Tensors into CUDA pinned memory(锁页内存) before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.如果设置为True,意味着生成的torch.Tensor数据将被拷贝到内存中的锁页内存,这样将内存中的数据转义到GPU的显存时速度会更快。具体可参考这篇文章。
  • timeout>> 用来设置数据读取时间的最大限度,超过该限度则会报错。

循环获取数据用以训练

第三个循环就是我们对DataLoader进行循环产出数据。

    for epo in range(epoch):
    
        # Each epoch records 5 cross validation 
        # trains_loss and val_loss, divide by 5
        
        train_loss = 0
        val_loss = 0
        val_acc = 0
        val_miou = 0
        
        for i in range(5):
            # train
            CCNet.train()
            for index, (image, label) in enumerate(train_dataloader[i]):
                # training

你可能感兴趣的:(Torch)