引用https://blog.csdn.net/qq_44875293/article/details/124060623博客,pytorch中的tensor有如下图的“花式索引”机制。
prediction_scores_right = prediction_scores[torch.arange(batch_size).unsqueeze(1),
torch.arange(seq_length).unsqueeze(0), labels]
prediction_scores是三维tensor,shape为(b, l, v),b是batch_size,l是seq_len,v是单词表大小,其实就是得到每个单词的预测分数。
labels是每个单词对应的标签,shape是(b, l)。
目标是:取prediction_scores中每个单词对应label的分数,即得到shape为(b, l)。
所以取的第一维(b,1)和第二维(1,l)先广播成labels的维度(b,l),然后索引labels上每一个值。