一文读懂Dataset, DataLoader及collate_fn, Sampler等参数

数据预处理DataLoader及各参数详解

pytorch关于数据处理的功能模块均在torch.utils.data 中,pytorch输入数据PipeLine一般遵循一个“三步走”的策略,操作顺序是这样的:

继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。在实现自定义类时,一般需要对图像数据做增强处理,和标签处理,__getitem__返回图像和对应label,图像增强的方法可以使用pytorch自带的torchvision.transforms内模块,也可以使用自定义或者其他第三方增强库。

② 导入 DataLoader类,传入参数(上面自定义类的对象) 创建一个DataLoader对象。

③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练

dataset = MyDataset()           # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象

num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
    for img, label in dataloader:
        # 训练代码

pytorch内部默认的数据处理类有如下:


class Dataset(object):

class IterableDataset(Dataset):

class TensorDataset(Dataset): #  封装成tensor的数据集,每一个样本都通过索引张量来获得。

class ConcatDataset(Dataset): #  连接不同的数据集以构成更大的新数据集

class Subset(Dataset):  # 获取指定一个索引序列对应的子数据集

class ChainDataset(IterableDataset):

一般能用到的是ConcatDataset, Subset,其他不常用。

可迭代对象的创建方式

  • 方法一:在python中凡是具有__iter__的方法的类,都是可迭代的类。可迭代类创建的对象实现了__iter__方法,因此就是可迭代对象。

    from collections import Iterable, Iterator
    
    class Student(object):
        def __init__(self, score):
            self.score = score
    
        def __iter__(self):
            return iter(self.score)  # return 返回的是一个迭代器,
    
    test = Student([80, 90, 95])
    print(isinstance(test,  Iterable))
    print(isinstance(test,  Iterator))
    for i in test:  # test可迭代对象,但不是迭代器,所以不能next(test)
        print(i)
    for i in test:  # 重复遍历试试看,是否有结果
        print(i)
    print("============")
    test=iter(test)  # 对可迭代对象使用内建函数iter(), 使之成为为迭代器,此时可以next(test)
    print(isinstance(test,  Iterable))
    print(isinstance(test,  Iterator))
    for i in test:
        print(i)
    for i in test:  # 对迭代器重复遍历试试看,有结果过没,
        print(i)  # 没有结果
    """
    True
    False
    80
    90
    95
    80
    90
    95
    ============
    True
    True
    80
    90
    95
    """
    

    从本代码可看出Student类创建的对象是可迭代对象,但不是迭代器(因为没有实现__next__方法),且可以实现重复遍历,而迭代器是无法重复遍历的!

  • 方法二:用list、tuple等容器创建的对象,也都是可迭代对象。如:test=[1,2,3], test就是可迭代对象

迭代器的创建方式:迭代器对象必须同时实现__iter__和__next__方法才是迭代器

  • 方法一:自定义类实现__iter__和__next__方法,对于迭代器来说,__iter__ 返回的是它自身 self__next__ 则是返回迭代器中的下一个值。

    class Student(object):
        def __init__(self, score):
            self.score = score
    
        def __iter__(self):
            return self  # 对于迭代器来说,__iter__ 返回的是它自身self,也就是返回迭代器。
    
        def __next__(self):
            if self.score < 100:
                self.score += 1
                return self.score
            else:
                raise StopIteration()
    
    test = Student(95)
    print(isinstance(test,  Iterable))
    print(isinstance(test,  Iterator))
    print(next(test))  # 可用内建函数next(),每次获取下一个值
    for i in test:  
        print(i)
    for i in test:  # 重复遍历试试看,是否有结果
        print(i)
    
    """
    True
    True
    96
    97
    98
    99
    100
    """
  • 方法二:对可迭代对象使用内建函数iter()包装,如:test=[1,2,3],testor=iter(test), testor就是个迭代器了。

python中的for循环其实兼容了两种机制:

1.如果对象有__iter__会返回一个迭代器。

2.如果对象没有__iter__,但是实现了__getitem__,会改用下标迭代的方式,__getitem__可以帮助一个对象进行取数和切片操作。

2, DataLoader类详解,数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代。

torch.utils.data.DataLoader类实现__iter__方法(且方法return的是个迭代器),所以创建的对象是一个可迭代对象,且该对象可以被for循环遍历(因为DataLoader类的__iter__方法return的是个迭代器),可以重复for循环遍历,这点上面讲到原因了,但是DataLoader类创建的对象并不能使用next()访问,因为这个对象它并不是迭代器。若使用iter(这个对象)使之返回一个迭代器,然后就可以使用next访问了。

class DataLoader(object):
    Arguments:
        dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.三步走第一步创建的对象
        batch_size (int, optional): 每一个batch加载多少个样本,即指定batch_size,默认是 1 
        shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False
------------------------------------------------------------------------------------
        sampler (Sampler, optional): 自定义从数据集中抽取样本的策略,如果指定这个参数,那么shuffle必须为False
        batch_sampler (Sampler, optional): 此参数很少使用,与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥)
------------------------------------------------------------------------------------
        num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
        collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
        pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.默认是False
------------------------------------------------------------------------------------
        drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了,如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
------------------------------------------------------------------------------------

以上是部分参数,一般不需要自己去实现DataLoader了,只需要在构造对象时中指定相应的参数即可,其中常常使用的是dataset, batch_size, shuffe, sampler, num_workers, collate_fn, pin_memory,这几个参数。

sampler参数

sampler参数其实就是一个“采样器”,表示从样本中究竟如何取样,pytorch采样器有如下几个:

class Sampler(object):

class SequentialSampler(Sampler):# 顺序采样样本,始终按照同一个顺序。

class RandomSampler(Sampler): # 无放回地随机采样样本元素。

class SubsetRandomSampler(Sampler): # 无放回地按照给定的索引列表采样样本元素

class WeightedRandomSampler(Sampler): # 按照给定的概率来采样样本。

class BatchSampler(Sampler):  # 在一个batch中封装一个其他的采样器。

# torch.utils.data.distributed.DistributedSampler
class DistributedSampler(Sampler): # 采样器可以约束数据加载进数据集的子集。

Sampler类是所有的采样器的基类,每一个继承自Sampler的子类都必须实现它的__iter__方法和__len__方法。前者实现如何迭代样本,后者实现一共有多少个样本。Sampler这个基类是不能用来创建对象后用来采样的,是转为作为自定义采样器的父类。

pytorch默认是采用的采样器如下:

if batch_sampler is None:  # 没有手动传入batch_sampler参数时
    if sampler is None:  # 没有手动传入sampler参数时
        if shuffle:
            sampler = RandomSampler(dataset)  # 随机采样
        else:
            sampler = SequentialSampler(dataset)  # 顺序采样
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True

当默认采样器不满足我们的使用时,需要继承基类Sampler自定义采样器,如下是一种视频处理时的采样器:

from torch.utils.data.sampler import Sampler
import numpy as np

class RandomSequenceSampler(Sampler):
	# 作用与BatchSampler有点类似,每seq_len个视频shuffle
    def __init__(self, n_sample, seq_len):
        self.n_sample = n_sample  # 视频的数量
        self.seq_len = seq_len  # 视频序列长度

    def _pad_ind(self, ind):
        zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
        ind = np.concatenate((ind, zeros))
        return ind

    def __iter__(self):
        idx = np.arange(self.n_sample)
        if self.n_sample % self.seq_len != 0:
            idx = self._pad_ind(idx)
        idx = np.reshape(idx, (-1, self.seq_len))
        np.random.shuffle(idx)
        idx = np.reshape(idx, (-1))
        return iter(idx.astype(int))

    def __len__(self):
        return self.n_sample + (self.seq_len - self.n_sample % self.seq_len)

再看个例子,在训练文字识别时,一种采样方式

class randomSequentialSampler(Sampler):
    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batch):
            random_start = random.randint(0, len(self) - self.batch_size)
            batch_index = random_start + torch.range(0, self.batch_size - 1)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, len(self) - self.batch_size)
            tail_index = random_start + torch.range(0, tail - 1)
            index[(i + 1) * self.batch_size:] = tail_index

        return iter(index)

    def __len__(self):
        return self.num_samples

collate_fn参数

当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。

default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。所以在我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种你也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。

以下是文字识别时,文本行图像长度不一,需要自定义整理。

class AlignCollate(object):
    """将数据整理成batch"""
    def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, batch):# 有可能__getitem__返回的图像是None, 所以需要过滤掉
        batch = filter(lambda x: x is not None, batch)
        images, labels = zip(*batch)

        if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
            resized_max_w = self.imgW
            input_channel = 3 if images[0].mode == 'RGB' else 1
            transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

            resized_images = []
            for image in images:
                w, h = image.size
                ratio = w / float(h)
                # 图片的宽度大于设定的输入ingW
                if math.ceil(self.imgH * ratio) > self.imgW:
                    resized_w = self.imgW
                else:
                    resized_w = math.ceil(self.imgH * ratio)

                resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
                resized_images.append(transform(resized_image))
                # resized_image.save('./image_test/%d_test.jpg' % w)

            image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)

        else:
            transform = ResizeNormalize((self.imgW, self.imgH))
            image_tensors = [transform(image) for image in images]
            image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)

        return image_tensors, labels

再看个在做目标检测时自定义collate_fn函数,给每个图像添加索引

def collate_fn(self, batch):
        paths, imgs, targets = list(zip(*batch))
        # Remove empty placeholder targets  
        # 有可能__getitem__返回的图像是None, 所以需要过滤掉
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        # boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        targets = torch.cat(targets, 0)
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        # 每个图像大小不同呢,所以resize到统一大小
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

其实也可以自定义collate_fn同时,结合使用默认的default_collate

from torch.utils.data.dataloader import default_collate  # 导入这个函数


def collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch

    """

    # 这一部分是对 batch 进行重新 “校对、整理”的代码

    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

在使用pytorch时,当加载数据训练for i, batch in enumerate(train_loader):时,可能会出现TypeError: ‘NoneType’ object is not callable这个错误,若遇到更换pytorch版本即可

你可能感兴趣的:(计算机视觉)