Python (PyTorch) class implementation of the focal loss function with a balancing factor

1. Basics of the focal loss and the balancing factor

  • The focal loss can be motivated from two angles: the unequal cost of misclassification between classes (addressed by the balancing factor α) and the varying difficulty of samples (addressed by the focusing parameter γ). A fuller discussion is left for a later update; the formula below summarizes the idea.
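For reference, the α-balanced focal loss of Lin et al., written for a binary label y ∈ {0, 1} and predicted probability p, is

    FL(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t), \qquad
    p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}

A larger γ down-weights easy, well-classified samples, while α_t trades off the two classes. Note that the implementation below applies (1 − α) to the positive class and α to the negative class, so with the default α = 0.05 the positive (fault) samples receive the much larger weight.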

2. Python (PyTorch) class implementation of the focal loss with a balancing factor

import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, gamma=4.5, alpha=0.05):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # When gamma == 0 the exponent self.x disables the alpha weighting
        # (alpha ** 0 == 1), so the loss degenerates to plain cross entropy.
        if gamma == 0:
            self.x = 0
        else:
            self.x = 1

    def forward(self, preY, trueY):
        preY = torch.sigmoid(preY)
        # Per-sample focal loss: the positive class is weighted by (1 - alpha),
        # the negative class by alpha, each modulated by the focusing term.
        lossList = (-((1 - self.alpha) ** self.x) * (1 - preY) ** self.gamma * torch.log(preY) * trueY
                    - (self.alpha ** self.x) * preY ** self.gamma * torch.log(1 - preY) * (1 - trueY)) / len(trueY)
        loss = torch.sum(lossList)
        # Distribution of the loss over the two classes
        loss_class = torch.zeros(2)
        loss_class[0] = (lossList.detach() * trueY).sum()
        loss_class[1] = (lossList.detach() * (1 - trueY)).sum()
        # cv_loss_class = loss_class[0] / loss_class[1]
        cv_loss_class = torch.std(loss_class) / torch.mean(loss_class)
        # Debug output for NaNs, which can appear when log() is fed 0
        if torch.isnan(loss):
            print("The loss has become NaN!")
            print(torch.sum(self.alpha * (1 - preY) ** self.gamma * torch.log(preY) * trueY))
            print(torch.sum(preY ** self.gamma * torch.log(1 - preY) * (1 - trueY)))
            print(torch.sum(torch.log(1 - preY)))
            print(torch.sum(preY))
            print(torch.min(preY))
            print(torch.min(1 - preY))
            print(torch.min(torch.log(preY)))
        # Returns: the loss value and the coefficient of variation of the class-wise losses
        return loss, cv_loss_class
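A minimal usage sketch follows; the batch size, gamma, and alpha here are illustrative choices only, not values taken from the project:

import torch

criterion = FocalLoss(gamma=2.0, alpha=0.25)
logits = torch.randn(8, requires_grad=True)   # raw model outputs, before sigmoid
labels = torch.randint(0, 2, (8,)).float()    # binary labels in {0, 1}
loss, cv = criterion(logits, labels)          # scalar loss and class-wise CV
loss.backward()

The second return value, cv_loss_class, can be logged during training to monitor how evenly the loss is distributed over the two classes.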

3. A complete project using the focal loss for classification (fault diagnosis)

  • The data are not released for confidentiality reasons.
  • The complete Python implementation is available in the GitHub project linked below; a generic training-loop sketch follows this list:
    Classification (fault diagnosis) using the focal loss with a balancing factor
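Since the project data and code cannot be reproduced here, the sketch below shows, with placeholder data, model, and hyperparameters (none of which come from the original project), how the FocalLoss class plugs into an ordinary training loop:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data standing in for the confidential fault-diagnosis set:
# 256 samples of 1024-point signals with binary fault labels.
X = torch.randn(256, 1024)
y = torch.randint(0, 2, (256,)).float()
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Placeholder classifier producing one logit per sample.
model = nn.Sequential(nn.Linear(1024, 128), nn.ReLU(), nn.Linear(128, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = FocalLoss(gamma=2.0, alpha=0.25)

for epoch in range(10):
    for signals, labels in train_loader:
        logits = model(signals).squeeze(1)    # shape (batch,)
        loss, cv = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()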
