DataLoader有一个参数collate_fn,这个参数接收自定义collate
函数,该函数在数据加载(即通过Dataloader取一个batch数据)之前,定义对每个batch数据的处理行为。
看下面的示例:
import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset
def collate(data_):
"""
data_是一个列表,长度和DataLoader中定义的batch_size相等,
每一个列表元素为从Dataset采样一次得到的数据,
比如batch_size为2,从Dataset一次采样的数据为x,y,
那么data_表示为[(x1,y1),(x2,y2)]。而从DataLoader出来的
数据是 X=[x1,x2]^T和Y=[y1,y2]^T,
下面的代码就是将data_变成X和Y的形式。
"""
x, y = zip(*data_) # zip 可以将多个列表(或元组)的对应元素拼在一起,这样x1和x2就在一个列表里,y1和y2在一个列表里
x = torch.stack(x) # 把列表变成张量形式,stack默认在维度0拼接,维度大小等于batch_size大小
y = torch.stack(y)
return x, y
data = torch.rand(100,128) # 生成x数据
label = torch.randint(0,2, (100,)).float() # 生成y标签数据
dataset = TensorDataset(data,label) # 构建数据集
loader = DataLoader(dataset, batch_size=32, collate_fn=collate) # 构建加载器,collate_fn就是在这一步进行处理
X,Y = next(iter(loader)) # 得到一个批次的数据
print(X.shape, Y.shape) # 和预期一致
# Output: (32,128),(32)
除了上面的简单情形,在写训练代码的时候,碰到了下面的情况:
当时想使用 LSTM模型进行深度学习训练,输入数据源为时空数据,提取之后的形状为(node_num, seq_len, n_feat)
, node_num
表示节点数,seq_len
表示序列长度,n_feat
为节点特征维度。
如果直接进行数据加载(DataLoader不通过collate_fn处理),输出的shape将是(batch_size, node_num, seq_len, n_feat)
。
而LSTM模型需要的输入形状是(seq_len, batch_size, n_feat)
, 显然不匹配。刚开始想到通过以下方式解决:
# [batch_size, node_num, seq_len, n_feat]
# -> [batch_size*node_num, seq_len, n_feat]
# -> [seq_len, batch_size*node_num, n_feat]
x.reshape(-1,seq_len,n_feat).permute(1,0,2)
但这样其实不行,因为每次采样的数据 node_num
不相等, 想想之前的第一段代码,如果x1
的shape是[5,12,100]
, x2
是[7,12,100]
, 通过torch.stack
无法拼接, DataLoader默认的行为和上面代码是一样的。如果强行那么做会报错,类似于这种:
invalid argument 0: Sizes of tensors must match except in dimension 0. Got 7 and 5 in dimension 1
。
在这种情况下,需要通过collate_fn
自定义处理行为,看下面的代码:
import torch
from torch.utils.data import Dataset, DataLoader,\
TensorDataset
import numpy as np
class TestDataset(Dataset):
def __init__(self, n=100):
super(TestDataset,self).__init__()
# 定义node_num不同的两组数据
self.data1 = torch.rand(n,5,12,128)
self.data2 = torch.rand(n,7,12,128)
def __len__(self):
return len(self.data1)
def __getitem__(self,indx):
# 随机选择一组数据,得到的x可能出现node_num不同的情况
if torch.rand(1) > 0.5:
return self.data2[indx],torch.tensor(1)
else:
return self.data1[indx],torch.tensor(0)
def collate(data):
x, y = zip(*data)
# 统计每个x的node_num
sz = [len(x) for x in x]
# 计算偏移量
cum_offsets = [0] + np.cumsum(sz).tolist()
start_end = [[start, end] for start, end in zip(cum_offsets, cum_offsets[1:])]
# 在node_num维度拼接,变成 [seq_len, node_num_total, n_feat]
x = torch.cat(x, dim=0).permute(1,0,2)
y = torch.stack(y)
# 为了能从x恢复原来的数据,需要返回偏移量
start_end = torch.LongTensor(start_end)
return x, start_end, y
tdataset = TestDataset()
loader = DataLoader(tdataset, batch_size=32, collate_fn=collate)
X, offsets, Y = next(iter(loader))
print(X.shape, offsets.shape, Y.shape) # 和预期一致
# Output: (12,164,128), (32,2), (32)
通过以上代码的处理,直接就得到了目标shape seq_len, batch_size, n_feat
,不过这个batch_size等于batch_size个node_num
相加,在node_num相等的情况下,这种方式和reshape
+permute
的方式是等效的。