Pytorch损失函数cross_entropy、binary_cross_entropy和binary_cross_entropy_with_logits的区别

在做分类问题时我们经常会遇到这几个交叉熵函数:
cross_entropy、binary_cross_entropy和binary_cross_entropy_with_logits。
那么他们有什么区别呢?下面我们就来探讨一下:

1.torch.nn.functional.cross_entropy

def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
                  reduce=None, reduction='mean'):
    # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

看上面代码也能知道input和target是必选项,并且是Tensor类型的。
最后一行说明functional.cross_entropy实际计算过程就是先计算Tensor的log_softmax,然后再计算nll_loss。

Tensor的log_softmax函数和functional的函数作用一样,都是先对数据进行softmax,然后进行log函数,这里的log以e为底,即ln。log_softmax和softmax中的数字表示按照什么维度计算。0代表按列计算,softmax函数计算后的数据按列加起来为1;1代表按行计算,softmax函数计算后的数据按行加起来为1。

2.binary_cross_entropy和binary_cross_entropy_with_logits

binary_cross_entropy和binary_cross_entropy_with_logits都是来自torch.nn.functional的函数,首先对比官方文档对它们的区别:
Pytorch损失函数cross_entropy、binary_cross_entropy和binary_cross_entropy_with_logits的区别_第1张图片区别只在于这个logits,那么这个logits是什么意思呢?以下是从网络上找到的一个答案:

有一个(类)损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间

再看看官方给的示例代码:
binary_cross_entropy:

input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(input), target)
loss.backward()
# input is  tensor([[-0.5474,  0.2197],
#         [-0.1033, -1.3856],
#         [-0.2582, -0.1918]], requires_grad=True)
# target is  tensor([[0.7867, 0.5643],
#         [0.2240, 0.8263],
#         [0.3244, 0.2778]])
# loss is  tensor(0.8196, grad_fn=)

binary_cross_entropy_with_logits:

input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
loss = F.binary_cross_entropy_with_logits(input, target)
loss.backward()
# input is  tensor([ 1.3210, -0.0636,  0.8165], requires_grad=True)
# target is  tensor([0., 1., 1.])
# loss is  tensor(0.8830, grad_fn=)

的确binary_cross_entropy_with_logits不需要sigmoid函数了。

你可能感兴趣的:(Pytorch,Python,pytorch,深度学习)