When a DataLoader is created with the default collate_fn, the input_ids, token_type_ids and attention_mask fields of each output batch are lists of length sequence_length, where every element is a tensor of size [batch_size].
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
for batch in train_loader:
    print(batch)
    break
{'input_ids': [tensor([101, 101, 101]), tensor([100, 100, 100]), tensor([102, 100, 100]), tensor([100, 102, 102]), tensor([1074, 976, 100]), tensor([ 100, 1035, 100]), tensor([100, 100, 100]), tensor([1099, 100, 100]), tensor([ 100, 1099, 100]), tensor([100, 100, 100]), tensor([100, 100, 100]), tensor([ 100, 100, 1099]), tensor([100, 100, 100]), tensor([100, 100, 100]), tensor([102, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 976, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 1099]), tensor([ 0, 100, 100]), tensor([ 0, 1099, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 976]), tensor([ 0, 1099, 100]), tensor([ 0, 100, 1099]), tensor([ 0, 100, 979]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 1074, 100]), tensor([ 0, 100, 100]), tensor([ 0, 1074, 1099]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 100, 100]), tensor([ 0, 886, 100]), tensor([ 0, 102, 102]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'token_type_ids': [tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([1, 0, 0]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'attention_mask': [tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), 
tensor([0, 1, 1]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'start_positions': tensor([ 7, 18, 4]), 'end_positions': tensor([11, 24, 7])}
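This is simply how default_collate handles list-valued fields: a dict of per-example lists is collated position by position, so a list of sequence_length ints per example becomes sequence_length tensors of shape [batch_size]. A minimal sketch that reproduces the behavior (the examples below are made up for illustration, not taken from the dataset above):
from torch.utils.data.dataloader import default_collate
# three fake examples whose values are plain Python lists of ints
examples = [
    {'input_ids': [101, 100, 102], 'attention_mask': [1, 1, 1]},
    {'input_ids': [101, 976, 102], 'attention_mask': [1, 1, 1]},
    {'input_ids': [101, 886, 102], 'attention_mask': [1, 1, 1]},
]
batch = default_collate(examples)
# batch['input_ids'] is a list of 3 tensors (one per position), each of shape [3],
# i.e. [batch_size] -- not a single [batch_size, sequence_length] tensor
print(type(batch['input_ids']), len(batch['input_ids']), batch['input_ids'][0].shape)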
A batch in this form cannot be fed to BERT directly: input_ids, token_type_ids and attention_mask must first be converted into tensors of size [batch_size, sequence_length].
To do this, we define our own collate_fn to reorganize the batch:
# Custom collate_fn: convert the input_ids, token_type_ids and attention_mask lists into tensors
def collate_fn(examples):
    batch = default_collate(examples)
    # each field is a list of sequence_length tensors of shape [batch_size];
    # stacking along dim=1 yields a [batch_size, sequence_length] tensor
    batch['input_ids'] = torch.stack(batch['input_ids'], dim=1)
    batch['token_type_ids'] = torch.stack(batch['token_type_ids'], dim=1)
    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=1)
    return batch
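As an alternative sketch, the same [batch_size, sequence_length] tensors can be built directly from the per-example lists without going through default_collate at all. This assumes every example stores equal-length Python lists, which appears to be the case here since the printed sequences are padded to the same length; collate_fn_direct is a hypothetical name, not used elsewhere in this walkthrough:
def collate_fn_direct(examples):
    return {
        'input_ids': torch.tensor([e['input_ids'] for e in examples]),
        'token_type_ids': torch.tensor([e['token_type_ids'] for e in examples]),
        'attention_mask': torch.tensor([e['attention_mask'] for e in examples]),
        'start_positions': torch.tensor([e['start_positions'] for e in examples]),
        'end_positions': torch.tensor([e['end_positions'] for e in examples]),
    }
The rest of this walkthrough keeps the default_collate-based collate_fn defined above.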
train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
# print the first batch
for batch in train_loader:
    print(batch)
    break
The resulting batch, shown below, can be passed to the BERT model directly with output = model(**batch):
{'input_ids': tensor([[ 101, 100, 100, 100, 100, 102, 100, 1070, 100, 100, 100, 1099,
100, 100, 100, 100, 1099, 100, 100, 100, 100, 100, 100, 1099,
100, 100, 100, 100, 100, 100, 1045, 886, 100, 100, 100, 100,
1086, 886, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 100, 100, 102, 100, 100, 100, 100, 100, 1099, 100, 100,
100, 100, 100, 100, 100, 1009, 100, 100, 100, 100, 100, 102,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 984, 100, 100, 100, 102, 984, 100, 100, 100, 976, 100,
100, 1033, 100, 1009, 102, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'start_positions': tensor([15, 12, 4]), 'end_positions': tensor([17, 19, 6])}
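Because the batch also carries start_positions and end_positions, a span-extraction model such as transformers' BertForQuestionAnswering (an assumption here; the text does not say which model class is used) returns both a loss and per-position logits, so a training step looks roughly like the sketch below (the optimizer is assumed to have been created elsewhere):
output = model(**batch)
loss = output.loss                  # cross-entropy over the start/end positions
start_logits = output.start_logits  # shape [batch_size, sequence_length]
end_logits = output.end_logits      # shape [batch_size, sequence_length]
loss.backward()
optimizer.step()                    # assumed optimizer, e.g. AdamW over model.parameters()
optimizer.zero_grad()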