OHEM, Focal loss, and GHM loss for binary classification in PyTorch (mitigating hard/easy sample imbalance)

https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
A survey: tackling the sample imbalance problem in object detection

The survey mainly covers OHEM, Focal loss, and GHM loss. Since my binary-classification dataset has no positive/negative imbalance, I focused on handling the hard/easy sample imbalance (normally, easy samples are plentiful and hard samples are scarce). Because mine is just a classification problem, I wrote classification versions of each loss. My network's last layer is a softmax, so its output pred is the softmax of the final logits, and the plain cross-entropy loss is sum(-gt*log(pred)). However, torch.nn.CrossEntropyLoss() applies a log-softmax to its input again, so I use torch.nn.NLLLoss on torch.log(pred) instead. That said, in my tests, keeping the softmax layer and feeding its output into torch.nn.CrossEntropyLoss() anyway performed about the same as using torch.nn.NLLLoss.
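As a quick sanity check, here is a minimal standalone sketch (not from the post above) showing that CrossEntropyLoss on raw logits matches NLLLoss on the log of the softmax output:

import torch

logits = torch.randn(4, 2)                         # raw scores for 4 samples, 2 classes
target = torch.tensor([0, 1, 1, 0])

ce = torch.nn.CrossEntropyLoss()(logits, target)   # applies log-softmax internally
pred = torch.softmax(logits, dim=1)                # what the network outputs here
nll = torch.nn.NLLLoss()(torch.log(pred), target)  # manual log + NLL

print(torch.allclose(ce, nll))                     # True (up to floating-point error)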

OHEM:
Code reference: https://www.codeleading.com/article/7442852142/

def ohem_loss(pred, target, keep_num):
    # per-sample loss; pred is already a softmax output, so take the NLL of its log
    loss = torch.nn.NLLLoss(reduction='none')(torch.log(pred), target)
    # keep only the keep_num hardest samples (largest losses)
    loss_sorted, idx = torch.sort(loss, descending=True)
    loss_keep = loss_sorted[:keep_num]
    return loss_keep.sum() / keep_num
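A hypothetical usage sketch (batch size and keep_num are illustrative): only the 4 hardest of 8 samples contribute to the gradient.

logits = torch.randn(8, 2, requires_grad=True)
pred = torch.softmax(logits, dim=1)
target = torch.randint(0, 2, (8,))
loss = ohem_loss(pred, target, keep_num=4)  # mean over the 4 largest per-sample losses
loss.backward()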

Focal loss:
Detailed explanation: the original paper, Focal Loss for Dense Object Detection
Code reference: https://zhuanlan.zhihu.com/p/80594704

def focal_loss(pred, target, gamma=0.5):
    # pt: predicted probability of the true class for each sample
    pt = pred.gather(1, target.unsqueeze(1)).squeeze(1).detach()
    # down-weight easy (high-confidence) samples; the weight carries no gradient
    focal_weight = (1 - pt).pow(gamma)  # the original paper's default is gamma=2
    ce = torch.nn.NLLLoss(reduction='none')(torch.log(pred), target)
    return torch.mean(ce.mul(focal_weight))
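A hypothetical sketch (the probabilities are made up) showing the reweighting: the easy sample gets weight (1-0.95)^0.5 ≈ 0.22 while the hard one gets (1-0.30)^0.5 ≈ 0.84, so the hard sample dominates the mean.

pred = torch.tensor([[0.95, 0.05],   # easy: confident and correct
                     [0.30, 0.70]])  # hard: true class gets low probability
target = torch.tensor([0, 0])
print(focal_loss(pred, target, gamma=0.5))  # dominated by the hard sample
print(focal_loss(pred, target, gamma=0))    # gamma=0 reduces to plain cross-entropy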

GHM loss:
Detailed explanation: https://zhuanlan.zhihu.com/p/80594704
Code reference: https://github.com/DHPO/GHM_Loss.pytorch/blob/master/GHM_loss.py

import torch
import torch.nn as nn

class GHM_Loss(nn.Module):
    def __init__(self, bins, alpha):
        super(GHM_Loss, self).__init__()
        self._bins = bins    # number of bins partitioning the gradient-norm range [0, 1]
        self._alpha = alpha  # momentum for the moving average of bin counts
        self._last_bin_count = None

    def _g2bin(self, g):
        # map a gradient norm g in [0, 1] to a bin index in [0, bins - 1]
        return torch.floor(g * (self._bins - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        raise NotImplementedError

    def forward(self, x, target):
        # per-sample gradient norm: small for easy samples, large for hard ones
        g = torch.abs(self._custom_loss_grad(x, target)).cpu()
        bin_idx = self._g2bin(g)
        bin_count = torch.zeros((self._bins,))
        for i in range(self._bins):
            bin_count[i] = (bin_idx == i).sum().item()

        # smooth bin counts across batches with an exponential moving average,
        # as in the referenced repo (this is what alpha and _last_bin_count are for)
        if self._last_bin_count is None:
            self._last_bin_count = bin_count
        else:
            bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
            self._last_bin_count = bin_count

        N = x.size(0)
        nonempty_bins = (bin_count > 0).sum().item()
        gd = bin_count * nonempty_bins       # gradient density estimate
        gd = torch.clamp(gd, min=0.0001)
        beta = N / gd                        # samples in crowded bins get small weights
        return self._custom_loss(x, target, beta[bin_idx])
        
class GHMC_Loss(GHM_Loss):
    # GHM applied to a classification loss (GHM-C)
    def __init__(self, bins, alpha):
        super(GHMC_Loss, self).__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        # weighted NLL on log-probabilities, normalized by the total weight;
        # the weights are detached so they carry no gradient
        weight = weight.to(x.device).detach()
        ce = torch.nn.NLLLoss(reduction='none')(torch.log(x), target)
        return torch.sum(ce.mul(weight)) / torch.sum(weight)

    def _custom_loss_grad(self, x, target):
        # for cross-entropy, the gradient w.r.t. the true-class logit is p_t - 1,
        # so its absolute value 1 - p_t serves as the gradient norm
        pt = x.gather(1, target.unsqueeze(1)).squeeze(1)
        return pt.detach() - 1
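A hypothetical usage sketch (the bins and alpha values are illustrative): samples whose gradient norms fall into densely populated bins are down-weighted.

criterion = GHMC_Loss(bins=10, alpha=0.75)
logits = torch.randn(8, 2, requires_grad=True)
pred = torch.softmax(logits, dim=1)
target = torch.randint(0, 2, (8,))
loss = criterion(pred, target)
loss.backward()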
