交叉熵损失函数公式_保证数值稳定的二元交叉熵损失函数

交叉熵损失函数公式_保证数值稳定的二元交叉熵损失函数_第1张图片

1. 先来回顾一下普通的交叉熵损失

其中,

2. 看看BCEWithLogitsLoss

就是在

外边复合一层sigmoid函数,即
,损失函数变为:

3. 示例

import torch
target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5)  # A prediction (logit)
pos_weight = torch.ones([64])  # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(1.5))

输出:

tensor(0.2014)

验证一下:

import numpy as np
def sigmoid(x):
    return 1/(1+np.exp(-x))
loss = -target*np.log(sigmoid(output)) - (1-target)*np.log(1-sigmoid(output))
torch.mean(loss)

输出:

tensor(0.2014)

你可能感兴趣的:(交叉熵损失函数公式,交叉熵损失函数的理解)