【pytorch】交叉熵损失

前言

交叉熵损失本质是衡量模型预测的概率分布与实际概率分布的差异程度,其值越小,表明模型的预测结果与实际结果越接近,模型效果越好。熵的概念来自与信息论,参考资料1对交叉熵的概念做了简明的介绍,很好理解。需要注意:
【pytorch】交叉熵损失_第1张图片
Pytorch中的CrossEntropyLoss是LogSoftMax与NLLLoss的结合,下面以实例逐步拆解CrossEntropyLoss的计算过程。

LogSoftMax

当网络最后不加softmax层是,模型输出的是“logits”,对类别的logits进行softmax操作,可以将其每个元素都统一到“(0,1)”,其本质就是转换为“似然值”——因为参数不确定,此种情况下,不是称为概率,而是似然。对似然取对数后,在保持单调性上,便于计算。参考下列过程,可以发现LogSoftMax就是对logits先softmax然后log“注意y的数据类型需要设置为int64,否则会报错。”
【pytorch】交叉熵损失_第2张图片

NLLLoss

NLLLoss是Negtivate Log Likelihood Loss的缩写,pytroch中计算负的对数似然损失,是根据sigma(- y * log(y^hat)),y是样本的实际概率分布,假设有C类,转换成0-1向量后,只有一个位置是1,其它位置全为0。 -log(y^hat)就是似然值取对数,然后在取反——损失不能为负。将两部分乘积然后加和,就是取特定位置的值进行加和,如下所示:
【pytorch】交叉熵损失_第3张图片

特殊参数

  • reduction: 指定是否进行归并计算,默认计算均值,可取值为“none” “mean” “sum”,如果设置为“none”,会返回一组值,表示每一个样本的损失,一般这组数的数量就是batch_size的数量,因为网络训练是以batch为单位的
  • weight:指定类别权重,该参数的取值是一个1维的张量,数据类型必须为浮点数,默认情况每个类别权重值为1.0,元素个数为C(类别的总数量),如下所示进行类别加权后的损失:
    【pytorch】交叉熵损失_第4张图片
  • ignore_index(重要):指定忽略的索引号,默认为-100, 该选项是自定义计算损失的关键,通过各种方法,将需要忽略计算的样本的真实标签索引号更新为-100,则计算损失时就不会计算该条样本。
    例如在序列标注任务中,当对输入文本进行补齐到max_length后,相应也应该补齐便签tag,但这些补齐的tag不应该计算其损失,这时候可以通过:
activate_attention = attention_mask.view(-1) == 1 # 只计算真实标签
# tags为真实标签,where更加attention的值,更新真实标签的索引值
activate_tags = torch.where(activate_attention, tags.view(-), -100) 

这样attention位置是0对应的标签就会被更新为“-100”,不参与损失的计算。
 

参考资料

Pytorch常用的交叉熵损失函数详解

nn.CrossEntropyLoss

你可能感兴趣的:(pytorch,pytorch)