[Loss Series] Combining Dice Loss and BCE Loss
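For reference, the two terms computed by `bce_dice` below can be written out as follows. The notation is:

- $x$: the logits
- $p = \sigma(x)$: the predicted probabilities
- $g$: the binary mask
- $B$: the batch size
- $N$: the total number of pixels in the batch

The final line is an unweighted sum. That sum is one common way to combine the terms and is an assumption here: `bce_dice` itself returns the two terms separately.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{BCE}} &= -\frac{1}{N}\sum_{i=1}^{N}\Big[g_i \log p_i + (1-g_i)\log(1-p_i)\Big] \\
\mathcal{L}_{\mathrm{Dice}} &= 1 - \frac{1}{B}\sum_{b=1}^{B}\frac{2\sum_{h,w} p_{b,h,w}\, g_{b,h,w}}{\sum_{h,w} p_{b,h,w} + \sum_{h,w} g_{b,h,w} + 1} \\
\mathcal{L} &= \mathcal{L}_{\mathrm{BCE}} + \mathcal{L}_{\mathrm{Dice}}
\end{aligned}
```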

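import torch
import torch.nn.functional as F

def bce_dice(pred, mask):
    # pred: raw logits, mask: binary ground truth (float), assumed shape (B, H, W)
    ce_loss   = F.binary_cross_entropy_with_logits(pred, mask)  # BCE, averaged over all pixels
    pred      = torch.sigmoid(pred)                             # logits -> probabilities
    inter     = (pred*mask).sum(dim=(1,2))                      # per-image soft intersection
    union     = pred.sum(dim=(1,2))+mask.sum(dim=(1,2))         # per-image |P| + |G|
    dice_loss = 1-(2*inter/(union+1)).mean()                    # +1 avoids division by zero
    return ce_loss, dice_loss                                   # returned separately; caller combines them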
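`bce_dice` returns the two terms separately. It also sums over `dim=(1,2)`, which only yields per-image values for `(B, H, W)` inputs.

Below is a minimal sketch for the case where you want a single scalar to call `.backward()` on. It flattens every non-batch dimension, so a `(B, 1, H, W)` mask also works, and it keeps the same `+1` smoothing in the Dice denominator. The names `combined_bce_dice_loss`, `dice_weight` and `smooth` are hypothetical and are not part of SANet.

```python
import torch
import torch.nn.functional as F


def combined_bce_dice_loss(logits: torch.Tensor, mask: torch.Tensor,
                           dice_weight: float = 1.0, smooth: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: BCE + weighted soft Dice as one scalar loss.

    Accepts any shape (B, ...), e.g. (B, H, W) or (B, 1, H, W), by
    flattening everything except the batch dimension.
    """
    mask = mask.float()
    bce = F.binary_cross_entropy_with_logits(logits, mask)  # mean over all pixels
    probs = torch.sigmoid(logits).flatten(1)                # (B, N) probabilities
    target = mask.flatten(1)                                # (B, N) ground truth
    inter = (probs * target).sum(dim=1)                     # per-image soft intersection
    union = probs.sum(dim=1) + target.sum(dim=1)            # per-image |P| + |G|
    dice = 1 - (2 * inter / (union + smooth)).mean()        # same smoothing as bce_dice
    return bce + dice_weight * dice


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 1, 8, 8, requires_grad=True)   # (B, 1, H, W) logits
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float()          # random binary mask
    loss = combined_bce_dice_loss(logits, mask)
    loss.backward()                                         # gradients flow through both terms
    print(loss.item(), logits.grad.shape)
```

With `dice_weight=1.0` and `(B, H, W)` inputs, this equals `ce_loss + dice_loss` from `bce_dice`. Setting `dice_weight=0.0` falls back to plain BCE.

The BCE term supplies dense per-pixel gradients. The Dice term is computed per image and is commonly used to reduce sensitivity to foreground/background imbalance, which is the usual motivation for pairing the two.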

Source (SANet training script): https://github.com/weijun88/SANet/blob/main/src/train.py
