【PyTorch】Implementing Focal Loss in PyTorch

Focal loss is a loss function designed to address the severe imbalance between positive and negative samples in one-stage object detection. It down-weights the large number of easy negative samples during training, and can also be viewed as a form of hard example mining.

Focal loss is obtained by modifying the cross-entropy loss:

$$
\mathrm{FL}(p_t) = -\alpha_t\,(1-p_t)^{\gamma}\,\log(p_t),
\qquad
p_t =
\begin{cases}
p & \text{if } y = 1 \\
1 - p & \text{otherwise}
\end{cases}
$$

where y is the ground-truth label and p is the predicted probability. The balancing factor alpha offsets the inherent imbalance between positive and negative samples, and gamma controls how quickly easy samples are down-weighted: when gamma = 0 the loss reduces to the ordinary cross-entropy, and as gamma increases, the influence of the modulating factor (1 - p_t)^gamma also grows. Experiments found gamma = 2 to be optimal, with alpha = 0.25, i.e. positive samples are weighted less than negatives, because negative examples are easy to classify.
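To make the effect of the modulating factor concrete, here is a minimal sketch (the probability values are illustrative, not from the original post) comparing the focal term with plain cross-entropy for one easy and one hard sample at gamma = 2:

import math

gamma = 2.0
for p_t in (0.9, 0.1):            # easy vs. hard example
    ce = -math.log(p_t)           # plain cross-entropy term
    fl = (1 - p_t) ** gamma * ce  # focal term, alpha omitted
    print(f"p_t={p_t}: CE={ce:.3f}, FL={fl:.3f}, ratio={fl / ce:.2f}")

# p_t=0.9: CE=0.105, FL=0.001, ratio=0.01  -> easy sample scaled down 100x
# p_t=0.1: CE=2.303, FL=1.865, ratio=0.81  -> hard sample barely reduced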

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=1.5, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            # Default per-class weights: down-weight class 0 (the easy,
            # dominant class) and weight the remaining classes equally.
            self.alpha = torch.ones(class_num)
            self.alpha[0] = 0.3
        else:
            if isinstance(alpha, torch.Tensor):
                self.alpha = alpha
            else:
                self.alpha = torch.tensor(alpha, dtype=torch.float)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
    def forward(self, inputs, targets):
        # inputs: raw logits of shape (N, C); targets: class indices of shape (N,)
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # One-hot mask that selects the probability of the target class.
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        # Per-sample alpha, indexed by the target class; reshaped to (N, 1)
        # so it broadcasts against probs instead of producing an (N, N) matrix.
        alpha = self.alpha[ids.view(-1)].view(-1, 1)

        # p_t: predicted probability of the true class, shape (N, 1).
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()

        # FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
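For reference, a quick smoke test of the class above (the batch size and class count here are made up for illustration):

import torch

criterion = FocalLoss(class_num=8)
logits = torch.randn(4, 8)            # 4 samples, 8 classes, raw logits
targets = torch.randint(0, 8, (4,))   # random ground-truth class indices
loss = criterion(logits, targets)
print(loss)                           # scalar: mean focal loss over the batch

Note that forward applies softmax internally, so the model should feed raw logits into the loss, not probabilities. A numerically safer variant would compute log-probabilities with F.log_softmax instead of calling probs.log(), but the code above keeps the original post's structure.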

