In NLP work, the collate_fn argument of PyTorch's DataLoader is commonly used to quickly reorganize a batch so that it matches the input format BERT expects.

When the DataLoader is created with the default collate_fn, the resulting batch has input_ids, token_type_ids, and attention_mask each as a list of length sequence_length, where every element of the list is a tensor of size [batch_size].
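For context, this behaviour appears when each dataset item stores its encodings as plain Python lists of token ids. Below is a minimal sketch of such a dataset, assuming a question-answering setup; the checkpoint name, the QADataset class, and the toy examples are illustrative assumptions, not details taken from the original post.

from torch.utils.data import Dataset
from transformers import BertTokenizerFast

# Illustrative assumptions: checkpoint name, class name, and toy data
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

class QADataset(Dataset):
    """Question/context pairs encoded up front; every field is kept as a plain Python list."""
    def __init__(self, questions, contexts, starts, ends, max_length=64):
        # padding='max_length' gives every example the same sequence_length (64 here)
        self.enc = tokenizer(questions, contexts,
                             padding='max_length', truncation=True, max_length=max_length)
        self.starts, self.ends = starts, ends

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        # Because these values are lists of ints (not tensors), default_collate transposes
        # them into a sequence_length-long list of [batch_size] tensors, as shown below.
        item = {key: values[idx] for key, values in self.enc.items()}
        item['start_positions'] = self.starts[idx]
        item['end_positions'] = self.ends[idx]
        return item

# Hypothetical toy data just to make the snippet runnable; the spans are not meaningful
dataset = QADataset(
    questions=['问题一?', '问题二?', '问题三?'],
    contexts=['上下文一。', '上下文二。', '上下文三。'],
    starts=[7, 18, 4],
    ends=[11, 24, 7],
)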

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
for batch in train_loader:
    print(batch)
    break

{'input_ids': [tensor([101, 101, 101]), tensor([100, 100, 100]), tensor([102, 100, 100]), tensor([100, 102, 102]), tensor([1074,  976,  100]), tensor([ 100, 1035,  100]), tensor([100, 100, 100]), tensor([1099,  100,  100]), tensor([ 100, 1099,  100]), tensor([100, 100, 100]), tensor([100, 100, 100]), tensor([ 100,  100, 1099]), tensor([100, 100, 100]), tensor([100, 100, 100]), tensor([102, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 976, 100]), tensor([  0, 100, 100]), tensor([   0,  100, 1099]), tensor([  0, 100, 100]), tensor([   0, 1099,  100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 976]), tensor([   0, 1099,  100]), tensor([   0,  100, 1099]), tensor([  0, 100, 979]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([   0, 1074,  100]), tensor([  0, 100, 100]), tensor([   0, 1074, 1099]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 100, 100]), tensor([  0, 886, 100]), tensor([  0, 102, 102]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'token_type_ids': [tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([1, 0, 0]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'attention_mask': [tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([1, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 
1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 1, 1]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0]), tensor([0, 0, 0])], 'start_positions': tensor([ 7, 18,  4]), 'end_positions': tensor([11, 24,  7])}

A batch in this form cannot be fed to BERT directly: input_ids, token_type_ids, and attention_mask must first be converted into tensors of size [batch_size, sequence_length].
We therefore need to define our own collate_fn to reorganize the batch:

# Custom collate_fn: convert the input_ids, token_type_ids and attention_mask
# lists into [batch_size, sequence_length] tensors
def collate_fn(examples):
    batch = default_collate(examples)
    # default_collate returns a list of sequence_length tensors of size [batch_size];
    # stacking them along dim=1 gives one [batch_size, sequence_length] tensor per field
    batch['input_ids'] = torch.stack(batch['input_ids'], dim=1)
    batch['token_type_ids'] = torch.stack(batch['token_type_ids'], dim=1)
    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=1)
    return batch
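As a side note, an equivalent collate_fn can also build the [batch_size, sequence_length] tensors directly from the examples, skipping default_collate entirely. This is only a sketch and assumes every example is already padded to the same length; the rest of the post continues with the collate_fn defined above.

import torch

def collate_fn_direct(examples):
    # Each example is a dict whose sequence fields are equal-length lists of ints
    batch = {}
    for key in ('input_ids', 'token_type_ids', 'attention_mask'):
        batch[key] = torch.tensor([ex[key] for ex in examples])  # [batch_size, sequence_length]
    for key in ('start_positions', 'end_positions'):
        batch[key] = torch.tensor([ex[key] for ex in examples])  # [batch_size]
    return batch

The transformers library's default_data_collator performs a very similar conversion and can often be used instead of writing this by hand.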

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
# Print the first batch
for batch in train_loader:
    print(batch)
    break

The result is shown below and can now be fed directly into the BERT model:

output = model(**batch)

The reorganized batch looks like this:
{'input_ids': tensor([[ 101,  100,  100,  100,  100,  102,  100, 1070,  100,  100,  100, 1099,
          100,  100,  100,  100, 1099,  100,  100,  100,  100,  100,  100, 1099,
          100,  100,  100,  100,  100,  100, 1045,  886,  100,  100,  100,  100,
         1086,  886,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [ 101,  100,  100,  102,  100,  100,  100,  100,  100, 1099,  100,  100,
          100,  100,  100,  100,  100, 1009,  100,  100,  100,  100,  100,  102,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [ 101,  984,  100,  100,  100,  102,  984,  100,  100,  100,  976,  100,
          100, 1033,  100, 1009,  102,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'start_positions': tensor([15, 12,  4]), 'end_positions': tensor([17, 19,  6])}
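To close the loop, here is a hedged sketch of actually running such a batch through a BERT question-answering head. The BertForQuestionAnswering class and the checkpoint name are assumptions suggested by the start_positions/end_positions fields; the original post only shows output = model(**batch).

from transformers import BertForQuestionAnswering

# Assumed QA head and checkpoint, for illustration only
model = BertForQuestionAnswering.from_pretrained('bert-base-chinese')
model.train()

for batch in train_loader:
    output = model(**batch)            # start/end positions make the model return a loss
    print(output.loss)                 # scalar training loss
    print(output.start_logits.shape)   # torch.Size([3, 64]) -> [batch_size, sequence_length]
    break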
