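PolyLoss 是 Cross-entropy loss 和 Focal loss 的一种推广（改进）形式。据原论文（Leng et al., "PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions", ICLR 2022）报告，PolyLoss 在二维图像分类、实例分割、目标检测和三维目标检测任务上都明显优于 Cross-entropy loss 和 Focal loss。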
作者认为，常用的分类损失函数（如 Cross-entropy loss 和 Focal loss）都可以分解为一系列加权多项式基之和。
它们可以被分解为 $\sum_{j=1}^{\infty}\alpha_j(1-P_t)^j$ 的形式，其中 $\alpha_j\in\mathbb{R}^+$ 为多项式系数，$P_t$ 为目标类别标签的预测概率。每个多项式基 $(1-P_t)^j$ 由相应的多项式系数 $\alpha_j$ 进行加权，这使 PolyLoss 能够很容易地调整不同多项式基的权重。例如，交叉熵损失可展开为 $L_{CE}=-\log P_t=\sum_{j=1}^{\infty}\frac{1}{j}(1-P_t)^j$；下面代码实现的 Poly-1 形式只调整第一项的系数，即 $L_{\text{Poly-1}}=L_{CE}+\epsilon(1-P_t)$，对 Focal loss 则为 $L_{FL}+\epsilon(1-P_t)^{\gamma+1}$。
import tensorflow as tf
def poly1_cross_entropy(epsilon=1.0):
    """Build a Keras-style Poly-1 cross-entropy loss function.

    Poly-1 adds ``epsilon * (1 - pt)`` to the softmax cross-entropy, where
    ``pt`` is the softmax probability assigned to the target class.

    Args:
        epsilon: Weight of the extra first-order polynomial term ``(1 - pt)``.

    Returns:
        A function ``(y_true, y_pred) -> scalar`` where ``y_true`` holds
        one-hot (or soft) labels along the last axis and ``y_pred`` holds raw
        logits of the same shape. The per-example losses are averaged.
    """
    def _poly1_cross_entropy(y_true, y_pred):
        # Bring labels to the logits' dtype: the element-wise product below and
        # softmax_cross_entropy_with_logits both require matching dtypes, and
        # labels may be integer-typed when the loss is called directly.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # pt, CE and Poly1 have shape y_pred.shape[:-1] ([batch] for 2-D input).
        pt = tf.reduce_sum(y_true * tf.nn.softmax(y_pred, axis=-1), axis=-1)
        CE = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        Poly1 = CE + epsilon * (1 - pt)
        return tf.reduce_mean(Poly1)
    return _poly1_cross_entropy
def poly1_focal_loss(gamma=2.0, epsilon=1.0, alpha=0.25):
    """Build a Keras-style Poly-1 focal loss function (sigmoid formulation).

    The loss is ``FL + epsilon * (1 - pt) ** (gamma + 1)``, where
    ``FL = alpha_t * CE * (1 - pt) ** gamma``. Here ``CE`` is the element-wise
    sigmoid cross-entropy and ``pt`` is the sigmoid probability of the target
    value. The polynomial term is not multiplied by ``alpha_t``.

    Args:
        gamma: Focusing parameter of the focal term.
        epsilon: Weight of the extra polynomial term.
        alpha: Weight for positive targets; negatives get ``1 - alpha``.
            A negative value disables this weighting.

    Returns:
        A function ``(y_true, y_pred) -> scalar`` where ``y_true`` holds
        per-element targets in [0, 1] and ``y_pred`` holds raw logits of the
        same shape. The element-wise losses are averaged over all elements.
    """
    def _poly1_focal_loss(y_true, y_pred):
        # sigmoid_cross_entropy_with_logits and the arithmetic below need the
        # labels in the logits' dtype; labels may be integer-typed.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        p = tf.math.sigmoid(y_pred)
        ce_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        # Probability the model assigns to the target value of each element.
        pt = y_true * p + (1 - y_true) * (1 - p)
        FL = ce_loss * ((1 - pt) ** gamma)
        if alpha >= 0:
            alpha_t = alpha * y_true + (1 - alpha) * (1 - y_true)
            FL = alpha_t * FL
        Poly1 = FL + epsilon * tf.math.pow(1 - pt, gamma + 1)
        return tf.reduce_mean(Poly1)
    return _poly1_focal_loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class Poly1CrossEntropyLoss(nn.Module):
    """Poly-1 cross-entropy loss.

    Computes ``CE + epsilon * (1 - pt)`` per element. ``CE`` is the softmax
    cross-entropy (``F.cross_entropy``) and ``pt`` is the softmax probability
    assigned to the target class.
    """

    # Accepted values for ``reduction``.
    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1CrossEntropyLoss
        :param num_classes: number of classes C; must equal the size of the class dimension of the logits
        :param epsilon: weight of the first-order polynomial term (1 - pt)
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :raises ValueError: if reduction is not one of none|sum|mean
        """
        super(Poly1CrossEntropyLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            # Fail early instead of silently behaving like "none".
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: tensor of shape [N, C] or [N, C, d1, ...] (class dimension 1,
            as for F.cross_entropy); an unbatched [C] tensor is also accepted
        :param labels: integer class indices of shape [N] or [N, d1, ...]
        :return: poly cross-entropy loss; shaped like labels when reduction="none",
            otherwise a scalar
        :raises ValueError: if the class dimension of logits differs from num_classes
        """
        class_dim = 1 if logits.ndim > 1 else 0
        if logits.shape[class_dim] != self.num_classes:
            raise ValueError(
                f"expected {self.num_classes} classes in dim {class_dim} of logits, "
                f"got shape {tuple(logits.shape)}")
        # Pick the target-class probability directly instead of multiplying by a
        # one-hot tensor as large as the logits.
        probs = F.softmax(logits, dim=class_dim)
        pt = probs.gather(class_dim, labels.unsqueeze(class_dim)).squeeze(class_dim)
        CE = F.cross_entropy(input=logits, target=labels, reduction='none')
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
class Poly1FocalLoss(nn.Module):
    """Poly-1 focal loss (sigmoid formulation, one binary problem per class).

    Computes, element-wise over every class logit,
    ``FL + epsilon * (1 - pt) ** (gamma + 1)``, where
    ``FL = alpha_t * BCE * (1 - pt) ** gamma``. The polynomial term is not
    multiplied by ``alpha_t``.
    """

    # Accepted values for ``reduction``.
    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 alpha: float = 0.25,
                 gamma: float = 2.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1FocalLoss
        :param num_classes: number of classes
        :param epsilon: poly loss epsilon
        :param alpha: focal loss alpha (weight of positive targets; negative value disables weighting)
        :param gamma: focal loss gamma
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :raises ValueError: if reduction is not one of none|sum|mean
        """
        super(Poly1FocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            # Fail early instead of silently behaving like "none".
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: ground truth of shape [N] or [N, ...], NOT one-hot encoded
        :return: poly focal loss; same shape as logits when reduction="none", otherwise a scalar
        """
        # focal loss implementation taken from
        # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
        p = torch.sigmoid(logits)
        # One-hot encode the class indices and move the class axis to dim 1 so the
        # targets line up with the logits:
        #   [N] -> [N, num_classes],  [N, ...] -> [N, num_classes, ...]
        labels = F.one_hot(labels, num_classes=self.num_classes).movedim(-1, 1)
        labels = labels.to(device=logits.device, dtype=logits.dtype)
        ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        # Probability the model assigns to the target value of each element.
        pt = labels * p + (1 - labels) * (1 - p)
        FL = ce_loss * ((1 - pt) ** self.gamma)
        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            FL = alpha_t * FL
        poly1 = FL + self.epsilon * torch.pow(1 - pt, self.gamma + 1)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
以 ResNet18 为例：由于训练过程的 loss 曲线图已被删除，这里只能给出在花朵识别数据集的验证集上的结果，正确率提升了约 6%。数据集如下：
链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
提取码:bhjx
有兴趣的小伙伴可以自己尝试一下。