Pytorch的torch.nn.functional.cross_entropy的ignore_index参数作用及验证

官方文档

https://pytorch.org/docs/master/nn.functional.html#torch.nn.functional.cross_entropy

作用

ignore_index用于忽略ground-truth中某些不需要参与计算的类。假设有两类{0:背景,1:前景},若想在计算交叉熵时忽略背景(0)类,则可令ignore_index=0(同理忽略前景计算可设ignore_index=1)。

代码示例

import torch
import torch.nn.functional as F
pred = torch.Tensor(
    [
        [0.9, 0.1],
        [0.8, 0.2],
        [0.7, 0.3]
    ]
)  # shape=(N,C)=(3,2),N为样本数,C为类数
label = torch.LongTensor([1, 0, 1])  # shape=(N)=(3),3个样本的label分别为1,0,1
out = F.cross_entropy(pred, label, ignore_index=0)  # 忽略0类
print(out)

输出

tensor(1.0421)

验证

pytorch的CrossEntropy使用公式:
在这里插入图片描述
计算:
l o s s = 1 2 × { [ − 0.1 + l n ( e 0.9 + e 0.1 ) ] + [ − 0.3 + l n ( e 0.7 + e 0.3 ) ] } = 1 2 × ( 1.1711 + 0.9130 ) = 1.0421 \begin{aligned} loss&=\frac{1}{2}\times\{[-0.1+ln(e^{0.9}+e^{0.1})]+[-0.3+ln(e^{0.7}+e^{0.3})]\}\\ &=\frac{1}{2}\times(1.1711+0.9130)\\ &=1.0421 \end{aligned} loss=21×{[0.1+ln(e0.9+e0.1)]+[0.3+ln(e0.7+e0.3)]}=21×(1.1711+0.9130)=1.0421

torch.nn.CrossEntropyLoss 同理。

你可能感兴趣的:(python,人工智能,深度学习,算法,pytorch)