BERT中是怎么做到只计算[MASK]token的CrossEntropyLoss的?及torch.nn.CrossEntropyLoss()参数

nn.CrossEntropyLoss()的参数

torch.nn.CrossEntropyLoss(weight=None, size_average=None,
ignore_index=-100, reduce=None, reduction=‘mean’)

  • weight:不必多说,这就是各class的权重。所以它的值必须满足两点:
    1. type = torch.Tensor
    2. weight.shape = tensor(1, class_num)
  • size_averagereduce :都要被弃用了,直接看 reduction就行
  • reduction:结果的规约方式,取值空间为{'mean', 'none', 'sum}。由于你传入 nn.CrossEntropyLoss()的输入是一个batch,那么按理说得到的交叉熵损失应该是 batch个loss。当前默认的处理方式是,对 batch 个损失取平均;也可以选择不做规约;或者将batch个损失取加和;
  • ignore_index :做交叉熵计算时,若输入为ignore_index指定的数值,则该数值会被忽略,不参与交叉熵计算。

BERT中是怎么做到只计算[MASK]token的CrossEntropyLoss的?

nn.CrossEntropyLoss()ignore_index参数在BERT的mask中用到了。由于BERT中其中一个预训练任务是MLM,只有15%的token被[MASK],所以说只有这15%的词会参与交叉熵loss的计算,其他85%不参与loss计算的槽位,就使用-1填充;而参与loss计算的槽位,会使用在 vocab.txt 里提前定义好的原始token对应的index表示,这些index都是大于101([CLS])的,所以计算时不会被ignore

你可能感兴趣的:(pytorch,NLP,bert)