最近在看sentences-transformers
的源码,在有一个模块发现了dataloader.collate_fn
,当时没搞懂是什么意思,后来查了一下,感觉还是很有意思的,因此来分享一下。
dataloader肯定都是知道的,就是为数据提供一个迭代器。
在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表,然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据,最后, 对每个index对应的数据进行堆叠,就形成了一个batch的数据。
DataLoader完整的参数表如下:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据。在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的。此外, 某些优化方法是要对一个batch的数据进行操作。
collate_fn函数就是手动将抽取出的样本堆叠起来的函数。
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
test = np.arange(11)
input = torch.tensor(np.array([test[i:(i + 3)] for i in range(10 - 1)]))
target = torch.tensor(np.array([test[i:(i + 1)] for i in range(10 - 1)]))
torch_dataset = TensorDataset(input, target)
batch = 3
#> input data shape: torch.Size([9, 3])
#> target data shape: torch.Size([9, 1])
需要注意的是上面的input数据shape为(9, 3);target数据shape为(9,1)。我们设置每一次的batch为3
my_dataloader = DataLoader(
dataset=torch_dataset,
batch_size=batch
)
for (i, j) in my_dataloader:
print('*' * 30)
print(i)
print(j)
查看上面的结果就可以看到每一批都返回两个结果,一个是input的样本,一个是target的样本。
input样本、target样本的维度和原始保持一致,但是大小尺寸全部为batch。
my_dataloader = DataLoader(
dataset=torch_dataset,
batch_size=4,
collate_fn=lambda x: x
)
for i in my_dataloader:
print('*' * 30)
print(i)
这个时候每一批都是返回了一个列表,这个列表的大小为3,列表里面的每一个对象就是一个成对的input和target。
如果我们继续想把上面的列表解析成第一个的情况,我们可以这么做:
a = i
list((torch.cat([a[i][j].unsqueeze(0) for i in range(len(a))]).unsqueeze(0) for j in range(len(a[0]))))
上面其实是很哇塞的,他其实是什么意思,就是把输出的长度为batch的列表转换为一个矩阵了。看着是挺复杂的,其实就是对list做了数据抽取和合并。非常简单。大概可以有这么个拆解路线:
my_dataloader = DataLoader(
dataset=torch_dataset,
batch_size=4,
collate_fn=lambda x: x,
drop_last=True
)
for i in my_dataloader:
print('*' * 30)
print(i)
a = i
a
然后,查看视频:
视频演示
现在结合上面的步骤,我们自定义自己的参数,然后实现默认的效果。大概代码如下:
my_dataloader = DataLoader(
dataset=torch_dataset,
batch_size=batch,
collate_fn=lambda x:(
torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))],dim=0) for j in range(len(x[0]))
)
)
for i,j in my_dataloader:
print('*' * 30)
print(i)
print(j)