Dataloader可以帮我们返回封装好的batch内容,采用迭代读取的方式完成训练,占用更少内存。
在构建Dataloader之前要构建好一个继承了Dataset类的数据集类,在Dataset类中完成语料的预处理(主要是tokenizer和embeeding过程),之后再用Dataloader包装Dataset,设置batch size等参数完成batch sets的构建,产生的可迭代对象可用于后续的模型训练。可以说是很优雅、规范地完成了数据预处理任务了。
import torch.utils.data as tud
class MyDataset(tud.Dataset):
def __init__(self, ...):
super(MyDataset, self).__init__()
# 此处对传入参数进行数据预处理
def __len__(self):
# 这个数据集有多少个item
def __getitem__(self, idx):
# 根据给定的index返回一个item
dataset = WordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
# 传入一些参数,返回一个可迭代的dataset
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# 直接使用Dataloader罩上
for e in range(NUM_EPOCHS):
for i, (返回的变量可以用元组的方式来接收) in enumerate(dataloader):
# 直接按照batch获取数据
...
可以看出,使用的时候可以更加方便。每次循环直接获取了batch size的数据,而不是像之前一样通过range方法的步长设置来获取一个batch,因为dataloader读取是迭代读取的,因此这时的dataloader中的数据没有全部的存在内存中,只有当使用的时候会使用。
torch.utils.data.DataLoader
:torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \
batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \
drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
- 其中包含的参数有:
dataset:定义的dataset类返回的结果。
batchsize:每个bacth要加载的样本数,默认为1。
shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False。
sample:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;batch_sample类似,表示一次返回一个batch的index。
num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程。
collate_fn:改造并合并样本列表以形成可用于不同任务的小批量的Tensor对象。如在把数据送入lstm模型前要做的序列padding操作。
pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。
drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False。
[本节参考:https://zhuanlan.zhihu.com/p/117270644]
(假设dataset类返回的是:data, label,label是个二分类,存在样本不均衡问题,类别2远多于类别1)
from torch.utils.data.sampler import WeightedRandomSampler
## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
weights = [2 if label == 1 else 1 for data, label in dataset]
sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
PyTorch中提供的这个sampler模块,用来对数据进行采样。默认采用SequentialSampler,它会按顺序一个一个进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。这里使用另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。replacement
用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。
[本节参考:https://zhuanlan.zhihu.com/p/117270644]
找了几篇17年左右的帖子,说dataloader.py中调用collate_fn源码如下:
indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])
但是在我的pytorch1.4版本里这个源码片段消失了,估计是只有在旧版本中能找到了(对应旧版本dataloader.py 180行)。所以此处我们还是直接去官网看一下吧。官网给了两个例子,来说明在collate_fn函数中发生了什么:
给了两个等同样例,看起来确实和消失的那部分代码是一样的,用了一个yield用于迭代。迭代读取数据的方式特别适用于数据量巨大的情况,所以这里可以看出使用Dataloader相比于直接使用循环做预处理的好处了。