pytorch 绘制多个算法loss_Pytorch - FocalLoss的几种实现

FocalLoss 用于 one-stage 目标检测算法(Retinanet),提升检测效果.

也可以被用于分类任务中,解决数据不平衡问题.

1. Github - DeepLabV3Plus-Pytorchimport torch

import torch.nn as nn

import torch.nn.functional as F

class FocalLoss(nn.Module):

def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):

super(FocalLoss, self).__init__()

self.alpha = alpha

self.gamma = gamma

self.ignore_index = ignore_index

self.size_average = size_average

def forward(self, inputs, targets):

ce_loss = F.cross_entropy(inputs, targets,

reduction='none', ignore_index=self.ignore_index)

pt = torch.exp(-ce_loss)

focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss

if self.size_average:

return focal_loss.mean()

else:

return focal_loss.sum()

2. Github - Focal-Loss-Pytorch

基于 Pytorch 的实现如:#!-*- coding: utf-8 -*-

# @Author : LG

from torch import nn

import torch

from torch.nn import functional as F

class focal_loss(nn.Module):

def __init__(self, alpha=0.25, gamma=2, num_classes = 3, size_average=True):

"""

focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)

步骤详细的实现了 focal_loss损失函数.

:param alpha: 阿尔法α,类别权重.

当α是列表时,为各类别权重;

当α为常数时,类别权重为[α, 1-α, 1-α, ....],

常用于目标检测算法中抑制背景类,

retainnet中设置为0.25

:param gamma: 伽马γ,难易样本调节参数. retainnet中设置为2

:param num_classes: 类别数量

:param size_average: 损失计算方式,默认取均值

"""

super(focal_loss,self).__init__()

self.size_average = size_average

if isinstance(alpha,list):

assert len(alpha)==num_classes

# α可以以list方式输入,

# size:[num_classes] 用于对不同类别精细地赋予权重

print("Focal_loss alpha = {}, 将对每一类权重进行精细化赋值".format(alpha))

self.alpha = torch.Tensor(alpha)

else:

assert alpha<1 #如果α为一个常数,则降低第一类的影响,在目标检测中为第一类

print("Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用.".format(alpha))

self.alpha = torch.zeros(num_classes)

self.alpha[0] += alpha

self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]

self.gamma = gamma

def forward(self, preds, labels):

"""

focal_loss损失计算

:param preds: 预测类别. size:[B,N,C] or [B,C] 分

别对应与检测与分类任务, B 批次, N检测框数, C类别数

:param labels: 实际类别. size:[B,N] or [B]

:return:

"""

# assert preds.dim()==2 and labels.dim()==1

preds = preds.view(-1,preds.size(-1))

self.alpha = self.alpha.to(preds.device)

# 这里并没有直接使用log_softmax, 因为后面会用到softmax的结果(当然也可以使用log_softmax,然后进行exp操作)

preds_softmax = F.softmax(preds, dim=1)

preds_logsoft = torch.log(preds_softmax)

# 这部分实现nll_loss ( crossempty = log_softmax + nll )

preds_softmax = preds_softmax.gather(1,labels.view(-1,1))

preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))

self.alpha = self.alpha.gather(0,labels.view(-1))

loss = -torch.mul(torch.pow((1-preds_softmax), self.gamma), preds_logsoft)

# torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ

loss = torch.mul(self.alpha, loss.t())

if self.size_average:

loss = loss.mean()

else:

loss = loss.sum()

return loss

3. 相关

你可能感兴趣的:(pytorch,绘制多个算法loss)