原理:
<点击跳转至原理解释>
实现代码:可直接使用 logits 和 label 作为输入参数(无需预先计算 softmax)。
import torch
import torch.nn.functional as F
import torch.nn as nn
class Focal_Loss(nn.Module):
    """Multi-class focal loss computed from raw logits and integer labels.

    Implements FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where the
    per-class alpha terms are supplied as `weight`. The (1 - p)^gamma factor
    down-weights well-classified examples so training focuses on hard ones.
    """

    def __init__(self, weight, gamma=2):
        """
        Args:
            weight: 1-D tensor of per-class weights (the alpha terms),
                one entry per class. Over-represented classes should get
                a smaller weight, rare classes a larger one.
            gamma: focusing parameter; gamma=0 reduces this to a
                weighted cross-entropy.
        """
        super(Focal_Loss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, preds, labels):
        """Return the mean focal loss over the batch.

        Args:
            preds: raw logits of shape (batch, num_classes); softmax is
                applied internally, so do NOT pre-softmax them.
            labels: integer class indices of shape (batch,).

        Returns:
            A scalar tensor: the focal loss averaged over the batch.
        """
        probs = F.softmax(preds, dim=1)
        eps = 1e-7  # guards log(0) when a predicted probability underflows
        # Build the one-hot mask on the same device/dtype as the logits so
        # the loss also works for CUDA inputs and non-float32 dtypes.
        target = self.one_hot(probs.size(1), labels)
        target = target.to(device=probs.device, dtype=probs.dtype)
        ce = -torch.log(probs + eps) * target  # per-class cross-entropy terms
        # Focal modulation: (1 - p)^gamma shrinks the loss of easy examples.
        floss = torch.pow(1.0 - probs, self.gamma) * ce
        alpha = self.weight.to(device=probs.device, dtype=probs.dtype)
        floss = floss * alpha  # broadcasts the per-class alpha weights
        floss = torch.sum(floss, dim=1)  # one loss value per sample
        return torch.mean(floss)

    def one_hot(self, num, labels):
        """Return a (batch, num) one-hot float encoding of integer labels."""
        one = torch.zeros((labels.size(0), num))
        one[range(labels.size(0)), labels] = 1
        return one
参数说明
初始化类时,需要传入 α 权重列表,类型为tensor,表示每个类别的样本占比的反比。比如5分类中,若某一类样本占比非常大,就把它的权重设置为小于0.2(即相应权重缩小);占比很小的类,相应的权重就要大于0.2。
lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2]))
使用时,logits 是神经网络的原始输出(不用先计算 softmax);label 是数据集(例如 torchvision 的数据集类)给出的整数类别标签。
loss = lf(logits,label)
例子,这里 logits 为(16*5)的tensor,表示批大小为16,5分类;label为每个样本的真实标签类别,对应 logits 的下标,是一个16维的tensor向量
# Example batch: 16 samples x 5 classes of raw (pre-softmax) logit scores.
logits = torch.tensor([[-2.7672, 3.6104, -7.4242, -3.2486, -3.1323],
[-2.4270, 3.1833, -5.9394, -2.4592, -3.2292],
[-2.5986, 3.3626, -6.7340, -2.8639, -3.1553],
[-2.6206, 3.4201, -6.8754, -2.9308, -3.1507],
[-2.8307, 3.7070, -7.6975, -3.3924, -3.1318],
[-2.5776, 3.3316, -6.6595, -2.8187, -3.1542],
[-2.8930, 3.7982, -7.9322, -3.5327, -3.1210],
[-2.5489, 3.3580, -6.5229, -2.7590, -3.1912],
[-1.5628, 1.8362, -1.8254, -0.3083, -3.5928],
[ 0.2434, -4.9000, 1.1150, 2.7505, -1.0390],
[-2.6877, 3.5686, -7.1178, -3.0847, -3.1617],
[-2.6847, 3.5191, -6.8264, -3.0083, -3.2041],
[-2.6137, 3.4025, -6.8965, -2.9250, -3.1396],
[-2.7505, 3.5840, -7.3340, -3.2035, -3.1435],
[-2.7030, 3.5163, -7.1549, -3.1002, -3.1424],
[-2.6661, 3.4580, -7.0481, -3.0258, -3.1365]])
# Ground-truth class index for each of the 16 samples (indexes into dim 1 of logits).
label = torch.tensor([1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 1, 1, 3, 1, 1, 1])
# Equal per-class alpha weights (each 1/5) — no class re-balancing here.
lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2]))
loss = lf(logits,label)
print('loss:', loss)
输出结果
loss: tensor(0.1902)