在语义分割任务中,根据数据的分布情况可选择不同的损失函数对网络输出和标签进行数值运算,以达到较优的训练效果。特别,在数据样本不均衡以及样本难易程度不同时,选择FocalLoss和DiceLoss往往能起到事半功倍的效果。本博客针对CrossEntropy、FocalLoss和DiceLoss三类损失函数进行了如下分析:
一般来讲,RGB图像数据经过网络过后最终的输出形式为一个四维张量,shape为(batch_size, Num_class, high, width),分别对应batch size,特征图层数(类别数),特征图的高和宽。而标签一般为三维张量,shape为(batch_size, high, width),分别对应batch size,标签的高和宽。由于输出特征图的层数对应类别数,而标签图每个点的值是像素点所属类别,因此可以按类别数将标签图像转化为多层,每一层对应一类,如在二分类任务中,可根据前景和对象将标签图转化为两层,分别对应前景和对象。
假设网络的输出shape为(1, 2, 2, 2),标签shape为(1, 2, 2),表示每次输入网络的图片数量为1,特征图共两层(类别为2),特征图尺寸为2×2。则可如下定义输出和标签:
import numpy as np
import random, torch, torchvision
import torch.nn.functional as F
from torch import nn, optim
def setup_seed(seed):
#固定一个随机因子
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def get_onehot(label, num_class, high, width):
label = label.view(-1)
ones = torch.sparse.torch.eye(num_class)
ones = ones.index_select(0, label)
ones = torch.transpose(ones, 0, 1).int()
ones_high = ones.shape[0]
label_list = []
for i in range(ones_high):
label_list.append(ones[i].reshape(high, width))
return label_list
def main():
setup_seed(20)
output = torch.randn((1, 2, 2, 2))
label = torch.randint(2, (1, 2, 2))
N, C, high, width = output.shape
label_list = get_onehot(label, C, high, width)
label0 = label_list[0].unsqueeze(dim=0)
label1 = label_list[1].unsqueeze(dim=0)
print('output:\n', output)
print('label:\n', label)
print('前景label:\n', label0)
print('对象label:\n', label1)
if __name__ == '__main_':
main()
#打印结果
output:
tensor([[[[-1.2061, 0.0617],
[ 1.1632, -1.5008]],
[[-1.5944, -0.0187],
[-2.1325, -0.5270]]]])
label:
tensor([[[0, 0],
[1, 1]]])
前景label:
tensor([[[1, 1],
[0, 0]]], dtype=torch.int32)
对象label:
tensor([[[0, 0],
[1, 1]]], dtype=torch.int32)
直观形象的表示如下图所示:
2.1 CrossEntropy
交叉熵衡量的是两个分布之间的距离,因此可以被用来刻画预测值和标签值的差异情况,公式如下:
该公式表示损失函数是由各个类别各自的损失函数叠加得到,而各自的损失函数是通过各自标签值和预测值运算得到的。其中,和分别表示标签值和预测值,表示预测值的概率。因此可通过以下代码段得到CrossEntropy损失函数:
pt = F.softmax(output, 1)
pt0 = pt[:, 0, :, :]
pt1 = pt[:, 1, :, :]
log0 = torch.log(pt0)
log1 = torch.log(pt1)
print('pt:\n', pt)
print('log0:\n', log0)
print('log1:\n', log1)
loss = -(label0 * log0 + label1 * log1)
print('ce: ', loss.mean())
#打印结果
pt:
tensor([[[[0.4041, 0.4799],
[0.0357, 0.7259]],
[[0.5959, 0.5201],
[0.9643, 0.2741]]]])
log0:
tensor([[[-0.9060, -0.7342],
[-3.3321, -0.3204]]])
log1:
tensor([[[-0.5177, -0.6538],
[-0.0364, -1.2942]]])
ce: tensor(0.7427)
具体内部操作如下图所示:
利用pytorch内部自带的CrossEntropy函数进行验证:
ce = nn.CrossEntropyLoss()
loss = ce(output, label)
print('ce: ', loss)
#打印结果
ce: tensor(0.7427)
通过验证结果可以发现,对CrossEntropyLoss函数分解并分步计算的结果,与直接使用CrossEntropyLoss函数计算的结果一致。
2.2 FocalLoss
Focal loss是何恺明团队针对训练样本不平衡以及样本难易程度不同提出的,是交叉熵损失函数的变种,公式如下:
的作用是给不同类别的样本loss加权重,正样本少,就加大正样本loss的权重;而的作用是,当样本预测值pt比较大时,也就是易分样本,会很小,这样易分样本的loss会显著减小,模型就会更关注难分样本loss的优化。根据实际的分割任务可对损失函数具体的运算,代码如下:
#假设α=0.25,γ=2
loss = - 0.25 * label0 * ((1 - pt0) ** 2) * torch.log(pt0) - 0.75 * label1 * ((1 - pt1) ** 2) * log1
print('focal: ', loss.mean())
#打印输出
focal: tensor(0.1604)
2.3 DiceLoss
Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。公式如下:
实现代码如下:
#假设smooth=1
output0 = output[:, 0, :, :]
output1 = output[:, 1, :, :]
intersection0 = output0 * label0
intersection1 = output1 * label1
DSC0 = (2 * torch.abs(torch.sum(intersection0)) + 1) / (torch.abs(torch.sum(output0)) + torch.sum(label0) + 1)
DSC1 = (2 * torch.abs(torch.sum(intersection1)) + 1) / (torch.abs(torch.sum(output1)) + torch.sum(label1) + 1)
loss = 1 - (DSC0 + DSC1) / 2
print('dice:', loss.mean())
#打印输出
dice: tensor(0.5226)
具体内部操作如下图所示: