pytorch 语义分割loss_Focal Loss理论及PyTorch实现

一、基本理论

采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升。

alpha 与每个类别在训练数据中的频率有关。

F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同。

F.nll_loss中实现了对于target的one-hot encoding,将其编码成与input shape相同的tensor,然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production。

基于alpha=1采用不同的gamma值进行实验的结果

focal loss解决了什么问题?

(1)不同类别不均衡

(2)难易样本不均衡

在retinanet中,除了使用呢focal loss外,还对初始化做了特殊处理,具体是怎么做的?

在retinanet中,对 classification subnet 的最后一层conv设置它的偏置b为:

b=−log((1−π)/π)

π代表先验概率,就是类别不平衡中个数少的那个类别占总数的百分比,在检测中就是代表object的anchor占所有anchor的比重,论文中设置的为0.01。

二、公式

标准的Cross Entropy 为:[图片上传失败...(image-286df1-1571884440851)]

Focal Loss 为:[图片上传失败...(image-460db1-1571884440851)]

其中,[图片上传失败...(image-d6c655-1571884440851)]

三、代码实现

一、来自Kaggle的实现(基于二分类交叉熵实现)

class FocalLoss(nn.Module):

def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):

super(FocalLoss, self).__init__()

self.alpha = alpha

self.gamma = gamma

self.logits = logits

self.reduce = reduce

def forward(self, inputs, targets):

if self.logits:

BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)

else:

BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)

pt = torch.exp(-BCE_loss)

F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

if self.reduce:

return torch.mean(F_loss)

else:

return F_loss

二、来自知乎大佬的实现:

import torch

import torch.nn as nn

import torch.nn.functional as F

from torch.autograd import Variable

class FocalLoss(nn.Module):

r"""

This criterion is a implemenation of Focal Loss, which is proposed in

Focal Loss for Dense Object Detection.

Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

The losses are averaged across observations for each minibatch.

Args:

alpha(1D Tensor, Variable) : the scalar factor for this criterion

gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),

putting more focus on hard, misclassified examples

size_average(bool): By default, the losses are averaged over observations for each minibatch.

However, if the field size_average is set to False, the losses are

instead summed for each minibatch.

"""

def __init__(self, class_num, alpha=None, gamma=2, size_average=True):

super(FocalLoss, self).__init__()

if alpha is None:

self.alpha = Variable(torch.ones(class_num, 1))

else:

if isinstance(alpha, Variable):

self.alpha = alpha

else:

self.alpha = Variable(alpha)

self.gamma = gamma

self.class_num = class_num

self.size_average = size_average

def forward(self, inputs, targets):

N = inputs.size(0)

C = inputs.size(1)

P = F.softmax(inputs)

class_mask = inputs.data.new(N, C).fill_(0)

class_mask = Variable(class_mask)

ids = targets.view(-1, 1)

class_mask.scatter_(1, ids.data, 1.)

#print(class_mask)

if inputs.is_cuda and not self.alpha.is_cuda:

self.alpha = self.alpha.cuda()

alpha = self.alpha[ids.data.view(-1)]

probs = (P*class_mask).sum(1).view(-1,1)

log_p = probs.log()

#print('probs size= {}'.format(probs.size()))

#print(probs)

batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p

#print('-----bacth_loss------')

#print(batch_loss)

if self.size_average:

loss = batch_loss.mean()

else:

loss = batch_loss.sum()

return loss

参考

你可能感兴趣的:(pytorch,语义分割loss)