pytorch中nn.CrossEntropyLoss使用注意事项

Loss的数学表达公式:

使用代码样例:

# 这样展开就相当于每个词正确的类别和预测的整个词表概率分布进行对应
# ignore_index是指忽略真实标签中的类别
criterion = nn.CrossEntropyLoss(ignore_index=2).to(device) 
vocab_size = pre.shape[-1]
trg = trg[:,1:]
trg_tag = trg.reshape(-1).to(device) # view函数要求在同一个连续地址里,而reshape不用
pre_tag = pre[1:].view(-1,vocab_size).to(device)
loss = criterion(pre_tag,trg_tag)

注意事项:

  1. CrossEntropyLoss实例化之后,其两个输入分别是预测标签和真实标签,顺序不要搞错。预测标签的大小为[N,classnum],真实样本的大小为[N],因为该函数会把真实标签进行one-hot表示。N不一定是batchsize大小,可以对向量进行展开,从而可以逐个样本进行计算loss。

  2. 从公式可以看出,pytorch中的交叉熵loss其本身已使用的一个softmax约束了预测标签输入控制在了0-1之间,所以loss的输入即用模型的输出即可不需要通过softmax后再输入loss中,否则两个softmax可能会导致模型在训练的过程中loss保持不变。

 

你可能感兴趣的:(代码试错)