1. First, a quick review of the cross-entropy loss (Cross Entropy Loss)
CE(p_t) = -log(p_t), where p_t is the predicted probability of the ground-truth class.
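For a rough sense of scale: an easy pixel predicted at p_t = 0.9 contributes -log(0.9) ≈ 0.105, while a hard pixel at p_t = 0.1 contributes -log(0.1) ≈ 2.303, about 22 times more. But when the easy pixels outnumber the hard ones by thousands to one, their summed loss still dominates the total, which is the problem the tricks below address.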
2. Generally speaking, our datasets suffer from class imbalance, so many people assign a different coefficient to each class in the loss
With a per-class weight α_t, the loss then becomes CE(p_t) = -α_t · log(p_t).
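In PyTorch this per-class weighting is just the weight argument of the cross-entropy loss. A minimal sketch with made-up shapes and weights (pred are logits of shape [B, C, H, W], target holds per-pixel class indices):

import torch
import torch.nn.functional as F

pred = torch.randn(2, 4, 64, 64)           # logits for 4 classes, made-up sizes
target = torch.randint(0, 4, (2, 64, 64))  # per-pixel ground-truth class index
# down-weight the over-represented class 0 relative to the rare classes
class_weights = torch.tensor([0.25, 1.0, 1.0, 1.0])
loss = F.cross_entropy(pred, target, weight=class_weights)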
3. Focal loss simply changes the formula itself so that the imbalance factor has a much larger influence on the loss: FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t), with γ = 2 in the paper.
Here α_t still plays the same role as the manual class weight from section 2 — that is how the paper itself describes it. Now let's analyse the new part: the modulating factor (1 - p_t)^γ reduces the loss value no matter whether the "predictor" gets the sample right or wrong!
From this we can see that as the predicted probability of a class gets higher and higher (i.e. the classification is correct and confident), the loss is suppressed more and more strongly by (1 - p_t)^γ, which is exactly what we want! When the predicted probability is below 0.5, i.e. the sample has essentially been assigned to another class and is misclassified, the factor can shrink the loss by at most 4x relative to plain CE (with γ = 2, since (1 - p_t)^2 > 0.25), whereas for correctly classified samples the shrinkage reaches tens, hundreds or even thousands of times. So the relative weight of the hard, misclassified samples in the total loss goes up dramatically.
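A quick numeric check of this claim (my own snippet, not from the paper): with γ = 2 the modulating factor reduces the loss by less than 4x for misclassified pixels (p_t < 0.5), but by orders of magnitude for well-classified ones. α_t is omitted here for clarity.

import math

gamma = 2
for p_t in [0.1, 0.4, 0.6, 0.9, 0.99, 0.999]:
    ce = -math.log(p_t)               # plain cross-entropy
    fl = (1 - p_t) ** gamma * ce      # focal loss, alpha_t omitted
    print("p_t=%.3f  CE=%.4f  FL=%.6f  reduction=x%.1f" % (p_t, ce, fl, ce / fl))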
4. Code
These implementations can also be found on GitHub, so how do you actually use them? While we're at it, here is the lovasz_softmax loss code as well.
A small usage example, taking image segmentation as the scenario:
import torch
import torch.nn.functional as F

# standard CE loss (F.cross_entropy expects raw logits, so no softmax here)
loss = F.cross_entropy(pred, target)
# Focal loss (the FocalLoss2d module below applies softmax itself, so it also takes raw logits)
loss = FocalLoss2d()(pred, target)
# lovasz_softmax (expects per-class probabilities)
loss = lovasz_softmax(F.softmax(pred, dim=1), target)
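To make the shapes concrete, a dummy end-to-end call might look like this (a sketch with made-up sizes; pred is the raw network output of shape [B, C, H, W], target the per-pixel class index of shape [B, H, W]):

B, C, H, W = 2, 4, 64, 64
pred = torch.randn(B, C, H, W)                          # raw logits from the network
target = torch.randint(0, C, (B, H, W))                 # ground-truth class per pixel
print(F.cross_entropy(pred, target))                    # standard CE
print(FocalLoss2d()(pred, target))                      # focal loss (class defined below)
print(lovasz_softmax(F.softmax(pred, dim=1), target))   # Lovasz-Softmax (code further down)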
class FocalLoss2d(torch.nn.Module):
    def __init__(self, weight=None, size_average=True, ignore_index=255, gamma=2):
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        # NLLLoss handles [B, C, H, W] input directly (NLLLoss2d is deprecated);
        # pass ignore_index through so that void pixels (label 255) are skipped
        self.nll_loss = torch.nn.NLLLoss(weight, size_average, ignore_index)

    def forward(self, inputs, targets):
        # down-weight well-classified pixels by (1 - p)^gamma, then take the NLL
        return self.nll_loss((1 - F.softmax(inputs, 1)) ** self.gamma * F.log_softmax(inputs, 1), targets)
def softmax_focalloss(input, target, gamma=2, weight=None, ignore_index=255):
    # functional form of the same focal loss
    return F.nll_loss((1 - F.softmax(input, 1)) ** gamma * F.log_softmax(input, 1),
                      target, weight=weight, ignore_index=ignore_index)
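A quick sanity check of the implementation (my own snippet, not part of the original code): with gamma = 0 the modulating factor is 1, so the focal loss must reproduce plain cross-entropy.

x = torch.randn(2, 4, 8, 8)
y = torch.randint(0, 4, (2, 8, 8))
print(FocalLoss2d(gamma=0)(x, y))         # should match ...
print(F.cross_entropy(x, y))              # ... the standard CE value
print(softmax_focalloss(x, y, gamma=0))   # and so should the functional form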
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from __future__ import print_function, division
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / union
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / union)
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels
class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
    only_present: average only on classes present in ground truth
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
                    for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
    return loss
def lovasz_softmax_flat(probas, labels, only_present=False):
    """
    Multi-class Lovasz-Softmax loss
    probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    only_present: average only on classes present in ground truth
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float()  # foreground for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (Variable(fg) - probas[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)
# --------------------------- HELPER FUNCTIONS ---------------------------
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(np.isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
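Finally, a small usage sketch for the binary case (my own example, assuming the module above has been loaded): lovasz_hinge takes raw logits and 0/1 masks of shape [B, H, W].

logits = torch.randn(2, 32, 32)               # per-pixel logits, made-up sizes
masks = (torch.rand(2, 32, 32) > 0.5).long()  # binary ground-truth masks
print(lovasz_hinge(logits, masks, per_image=True))
print(binary_xloss(logits, masks))            # numerically stable BCE on the same data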