一文搞懂F.cross_entropy中的weight参数

交叉熵是在分类任务中常用的损失函数,对于样本均衡的分类任务我们可以直接使用。但当我们面对样本类别失衡的情况时,导致训练过程中的损失被数据量最多的类别的主导,从而导致模型不能被有效的训练。我们需要通过为不同的样本损失赋予不同的权重以平衡不同类别间数据量的差异。这时了解一下F.cross_entrpy中的weight参数的底层是如何实现的,是非常有必要的!
关于Pytorch中F.cross_entropy详细的实现过程请看https://blog.csdn.net/code_plus/article/details/115420223这篇博客。
下面分析weight是如何作用的。
1、数据准备

input = torch.tensor([[[[0.5546, 0.1304, 0.9288],
                        [0.6879, 0.3553, 0.9984],
                        [0.1474, 0.6745,0.8948]],
		               [[0.8524, 0.2278, 0.6476],
                        [0.6203, 0.6977, 0.3352],
                        [0.4946, 0.4613, 0.6882]]]])
target = torch.tensor([[[0, 0, 0],
                        [0, 0, 0],
                        [0, 0, 1]]])
weight = torch.tensor([1.0, 9.0])

2、Pytorch中的实现结果
首先我们看一下普通的交叉熵和加权的交叉熵的结果,为了方便结果直接注释在了代码中。

loss = F.cross_entropy(input, target)
print(loss)  # tensor(0.7099)
weight = torch.tensor([1.0, 9.0])
loss = F.cross_entropy(input, target, weight)
print(loss)  # tensor(0.7531)

从交叉熵的结果熵我们发现,加权后的交叉熵在数值上变大了,然后可能就会有人问,那结果是不是一定会变大呢?(当然,我们理想的结果就是让损失值变大)
这个要分析数据的类别失衡问题,我本人目前在做变化检测,对于变化检测来说,变化的像素是极少数的,而未变化的像素是多数的。以上面input的数据为例,假设未变化的类别标签是0,变化的类别标签是1,可以发现0占了大多数,变化和未变化的样本数是失衡的,然后我们给的权重是[1.0,9.0]也就是说我们加大了变化样本(类别标签是1)的权重,然后最终得到的损失是放大的。
总结:当我们增加少数样本的权重时,计算出的损失值应该是放大的。

3、weight参数是如何作用的?
交叉熵的实现过程请看上篇文章,下面加权的交叉熵代码和未加权的交叉熵在实现上只是torch.log前面多加了一个weight[target[b][i][j]]参数,就这么简单。

input = F.softmax(input, dim=1)
loss = 0.0
for b in range(target.shape[0]):
    for i in range(target.shape[1]):
        for j in range(target.shape[2]):
            # loss -= torch.log(input[b][target[b][i][j]][i][j])
            loss -= weight[target[b][i][j]] * torch.log(input[b][target[b][i][j]][i][j])
print(loss/(1*8+9*1))

当我们在计算损失前,当遇到真实标签是0的样本时,我们就乘上我们为它附上的权重,也就是1;当我们遇到真实标签是1的样本时,我们就乘上给该类样本附的权重,也就是9。在对所用样本计算完损失后,再求平均,也就是loss/(1*8+9*1):这个含义是input数据中标签为0样本的个数是8,权重是1,即1个类别为0的样本我们仍然将其视为1个;标签为1的样本的个数是1,权重是9,即1个类别为1的样本我们将其视为9个;这样总的样本数就是1*8+9*1=17个,在对loss求平均,也就是最终的损失。该值也是0.7531,确认我们的分析是没有问题的。

4、总结
weight的作用就是扩大不同类别样本的个数(同时,总的样本个数也跟随着扩大了)。

注:如有错误还请指出!

你可能感兴趣的:(深度学习,深度学习,python)