原文:https://zhuanlan.zhihu.com/p/28527749
参考:https://github.com/dinrker/Pytorch-TGS-Salt-Identification-Challenge/blob/87dbce3fdffa5c717a918994da3645b43bf281ea/net/loss.py
import torch
# Per-element focusing exponent instead of a single scalar gamma:
# entries with focal_weight > 0.5 get gamma = 0.4, entries < 0.5 get 2.2,
# and entries exactly equal to 0.5 keep the initial value 1.0.
# NOTE(review): focal_weight and alpha_factor come from the surrounding loss
# code, which is not part of this snippet -- presumably a RetinaNet-style
# modulating term and alpha weight; confirm before reuse.
# torch.ones_like allocates on focal_weight's own device, so no .cuda() call
# is needed (the previous .cuda() crashed on CPU-only machines and caused a
# device mismatch in torch.pow for CPU inputs).
gamma = torch.ones_like(focal_weight)
gamma[focal_weight > 0.5] = 0.4
gamma[focal_weight < 0.5] = 2.2
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
目标检测中不适用(以下为二分类版本,输入为概率):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
    """Binary focal loss evaluated directly on probabilities.

    Args:
        gamma: focusing exponent applied to the modulating factors
            ``(1 - x)`` (positive term) and ``x`` (negative term).
        alpha: weight of the positive term; the negative term is weighted
            by ``1 - alpha``.
    """

    def __init__(self, gamma=2, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, x, target):
        """Return the mean element-wise focal loss.

        ``x`` is fed straight into ``log``, so it is presumably a probability
        in [0, 1] (e.g. a sigmoid output) -- confirm at call sites.
        ``target`` scales the positive term and ``1 - target`` the negative
        term, so 0/1 (or soft) labels are expected.
        """
        eps = 1e-8  # keeps log() finite when x is exactly 0 or 1
        positive = self.alpha * (1 - x) ** self.gamma * torch.log(x + eps) * target
        negative = (1 - self.alpha) * x ** self.gamma * torch.log(1 - x + eps) * (1 - target)
        return (-positive - negative).mean()
https://github.com/CVBox/PyTorchCV/blob/14374e57e3d2579fc4c1a7d9871d6157320e3c10/loss/modules/det_modules.py
https://github.com/chicm/ship/blob/d71443646d9756fe756a560c9b0d0ad31c3ee584/dice_losses.py
https://github.com/Simon717/TGS_29th_solution/blob/54ac598ad9af45412136bf79497bc91565dae0f9/code/loss.py
https://github.com/arvention/STDN/blob/57ba7818bdc419617e8a42bafa3c8d7ab346db8b/loss/focal_loss.py
https://github.com/BloodAxe/Kaggle-Salt/blob/b38f73dbf889bf27c20bcc8a7478cb6f82fa9cba/lib/loss.py
https://github.com/artyompal/kaggle_salt/blob/15024489e94bb5ff6c9a1aad60c199b00f73b781/code_gazay/lenin/lenin/metrics/focal.py
PyTorch Focal Loss(多分类版本,输入为 logits)
import torch
from torch.autograd import Variable
# class FocalLoss(torch.nn.Module):
# def __init__(self, gamma=2):
# super().__init__()
# self.gamma = gamma
#
# def forward(self, log_pred_prob_onehot, target):
# pred_prob_oh = torch.exp(log_pred_prob_onehot)
# pt = Variable(pred_prob_oh.data.gather(1, target.data.view(-1, 1)), requires_grad=True)
# modulator = (1 - pt) ** self.gamma
# mce = modulator * (-torch.log(pt))
#
# return mce.mean()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
    r"""Multi-class focal loss from "Focal Loss for Dense Object Detection".

    .. math::
        \text{Loss}(x, c) = -\alpha_c \, (1 - \mathrm{softmax}(x)_c)^{\gamma}
                            \, \log(\mathrm{softmax}(x)_c)

    Args:
        class_num (int): number of classes ``C``; used to build the default
            all-ones ``alpha``.
        alpha (Tensor or sequence, optional): per-class weights with ``C``
            entries, shape ``(C,)`` or ``(C, 1)``. Defaults to all ones.
        gamma (float): focusing parameter, ``gamma >= 0``. Larger values
            reduce the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples. ``gamma = 0``
            with unit ``alpha`` reduces to ordinary cross-entropy.
        size_average (bool): if True (default) the per-sample losses are
            averaged over the minibatch; otherwise they are summed.

    Shape:
        - inputs: raw logits ``(N, C)`` or ``(N, C, d1, ..., dk)`` (class
          dimension second, as for ``nn.CrossEntropyLoss``).
        - targets: integer class indices ``(N,)`` or ``(N, d1, ..., dk)``.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super().__init__()
        if alpha is None:
            alpha = torch.ones(class_num, 1)
        else:
            alpha = torch.as_tensor(alpha)
            if not alpha.is_floating_point():
                alpha = alpha.float()
        # Keep alpha as a column (C, 1) so alpha[targets] is (M, 1) and lines
        # up with the (M, 1) per-sample loss. A flat (C,) alpha would
        # otherwise broadcast against it into an (M, M) matrix.
        self.alpha = alpha.reshape(-1, 1)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: logits, ``(N, C)`` or ``(N, C, d1, ..., dk)``.
            targets: integer class indices, ``(N,)`` or ``(N, d1, ..., dk)``.

        Returns:
            Scalar tensor: mean (``size_average=True``) or sum of the
            per-sample losses.
        """
        num_classes = inputs.size(1)
        if inputs.dim() > 2:
            # (N, C, d1..dk) -> (N*d1*..*dk, C): move the class dim last so the
            # flattened row order matches targets.reshape(-1).
            dims = (0,) + tuple(range(2, inputs.dim())) + (1,)
            inputs = inputs.permute(*dims).reshape(-1, num_classes)
        ids = targets.reshape(-1, 1).long()

        # log_softmax is numerically stable; log(softmax(x)) underflows to
        # log(0) = -inf for confidently wrong samples and yields an inf loss.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)  # (M, 1)
        probs = log_p.exp()

        # Follow the inputs to whatever device they live on (not only CUDA).
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]  # (M, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
if __name__ == '__main__':
    # Smoke test for the multi-class FocalLoss: 5 samples, 2 classes.
    # The original demo called FocalLoss() without the required class_num and
    # passed 1-D float tensors as logits/targets, which cannot run. The same
    # numbers are kept here: the scores become class-1 logits (class 0 fixed
    # at 0) and the 0/1 mask becomes the integer labels.
    labels = torch.tensor([0, 1, 0, 1, 1])
    scores = torch.tensor([-0.1, -0.9, 0.0, -0.2, -0.2])
    logits = torch.stack([torch.zeros_like(scores), scores], dim=1)  # (5, 2)
    loss = FocalLoss(class_num=2)
    print(loss(logits, labels))