交叉熵损失的代码理解

本文参考了:添加链接描述
1:首先看一下交叉熵的公式:
交叉熵损失的代码理解_第1张图片
2:这里我们考虑图片交叉熵损失:
2.1:以语义分割为例:模型的output的维度为(B,C,H,W),假如有四十个类别,那么(1,40,256,256)。图片的标签即label即真实值即groundtruth维度为(1,1,256,256)。
2.2:在代码中:输入和标签直接输入到nn.CrossEntropyLoss中。那么两个维度都不一样是如何计算的呢?

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, device, weight):
        super(CrossEntropyLoss2d, self).__init__()
        self.weight = torch.tensor(weight).to(device)
        self.num_classes = len(self.weight) + 1  # +1 for void
        if self.num_classes < 2**8:
            self.dtype = torch.uint8
        else:
            self.dtype = torch.int16
        self.ce_loss = nn.CrossEntropyLoss(
            torch.from_numpy(np.array(weight)).float(),
            reduction='none',
            ignore_index=-1
        )
        self.ce_loss.to(device)

    def forward(self, inputs_scales, targets_scales):
        losses = []
        for inputs, targets in zip(inputs_scales, targets_scales):
            # mask = targets > 0
            targets_m_1 = targets.clone() #深拷贝
            targets_m = targets_m_1-1 #减去
            # targets_to_one_hot = torch.nn.functional.one_hot(targets_m.to(torch.int64)) #值为-1的样本不参与计算
            loss_all = self.ce_loss(inputs, targets_m.long())

            number_of_pixels_per_class = \
                torch.bincount(targets.flatten().type(self.dtype),
                               minlength=self.num_classes)
            divisor_weighted_pixel_sum = \
                torch.sum(number_of_pixels_per_class[1:] * self.weight)   # without void
            losses.append(torch.sum(loss_all) / divisor_weighted_pixel_sum)
            # losses.append(torch.sum(loss_all) / torch.sum(mask.float()))

        return losses

经过查看发现:交叉熵会将输入的通道维度进行压缩。且会对label进行one-hot编码。添加链接描述
3:那么是如何进行one-hot编码的呢?
标签:(1,5,5)
交叉熵损失的代码理解_第2张图片
假如我们的pred:有四个通道即四个类别。(1,4,5,5)
交叉熵损失的代码理解_第3张图片
那我们标签进行编码也有四个通道:(1,5,5,4)。他就是将每一行的类别进行one-hot编码。
交叉熵损失的代码理解_第4张图片
我们进行reshape到与pred一样大小。(1,4,5,5)
交叉熵损失的代码理解_第5张图片
经过one-hot编码之后的标签:(1,4,5,5)
交叉熵损失的代码理解_第6张图片
那么我们的损失就可以计算为:添加链接描述
交叉熵损失的代码理解_第7张图片

你可能感兴趣的:(pytorch函数,深度学习,人工智能,python,pytorch,机器学习)