【MindSpore】Implementing class-weighted loss computation for CV semantic segmentation

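When reproducing the High-Resolution Network (HRNet) for semantic segmentation on the Cityscapes dataset, objects of different classes are assigned different weights when computing the loss: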

# 19 values, one per Cityscapes training class
weights_list = [0.8373, 0.918, 0.866, 1.0345,
                1.0166, 0.9969, 0.9754, 1.0489,
                0.8786, 1.0023, 0.9539, 0.9843,
                1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507]
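The list index is the class index of the label map, so each pixel's weight is looked up by its label. The sketch below makes that mapping explicit. It assumes the list follows the standard Cityscapes trainId order (0 = road ... 18 = bicycle) and that 255 is the ignore label; CITYSCAPES_CLASSES and pixel_weight are hypothetical names used only for illustration.

```python
# Hypothetical helper: pair each weight with a Cityscapes training class.
# ASSUMPTION: weights_list is ordered by trainId 0..18.
CITYSCAPES_CLASSES = [
    "road", "sidewalk", "building", "wall", "fence", "pole",
    "traffic light", "traffic sign", "vegetation", "terrain", "sky",
    "person", "rider", "car", "truck", "bus", "train",
    "motorcycle", "bicycle",
]

weights_list = [0.8373, 0.918, 0.866, 1.0345,
                1.0166, 0.9969, 0.9754, 1.0489,
                0.8786, 1.0023, 0.9539, 0.9843,
                1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507]

IGNORE_LABEL = 255  # ASSUMPTION: Cityscapes "ignore" value


def pixel_weight(label, weights=weights_list, ignore_label=IGNORE_LABEL):
    """Return the loss weight of one pixel label; ignored/out-of-range labels get 0."""
    if label == ignore_label or not 0 <= label < len(weights):
        return 0.0
    return weights[label]


for name, w in zip(CITYSCAPES_CLASSES, weights_list):
    print(f"{name:>14s}: {w}")

print([pixel_weight(y) for y in [0, 13, 255]])  # [0.8373, 0.9037, 0.0]
```

Giving ignored pixels a weight of 0 is the same trick the MindSpore implementation below relies on.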

PyTorch provides torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label), which computes a class-weighted cross-entropy and skips pixels labeled ignore_label.
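For reference, this is the quantity PyTorch computes with the default reduction='mean': the weighted sum is normalized by the total weight of the non-ignored pixels, not by the pixel count. Here $x_n$ denotes the logits of pixel $n$, $y_n$ its label, $w_c$ the weight of class $c$, and $C$ the number of classes.

```latex
\ell_n = -\log \frac{\exp(x_{n,y_n})}{\sum_{c=0}^{C-1} \exp(x_{n,c})},
\qquad
\mathcal{L} = \frac{\sum_{n:\, y_n \neq \text{ignore}} w_{y_n}\, \ell_n}
                   {\sum_{n:\, y_n \neq \text{ignore}} w_{y_n}}
```

When all $w_c = 1$ this reduces to the plain mean over non-ignored pixels.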

MindSpore r1.1 and r1.2 do not provide a loss function with equivalent functionality, but one can be implemented as follows:

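import mindspore.nn as nn
import mindspore.ops as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P


class CrossEntropyLossWithWeights(_Loss):
    """Pixel-wise softmax cross-entropy with per-class weights.

    Returns sum(w[y] * ce) / sum(w[y]) over all pixels. Labels outside
    0..num_classes-1 (e.g. ignore_label=255) keep weight 0 and are skipped.
    """

    def __init__(self, weights, image_size, num_classes=19, ignore_label=255):
        super(CrossEntropyLossWithWeights, self).__init__()
        self.weights = weights  # Python list: weights[c] is the weight of class c
        # Upsample logits to the label resolution; the original code read this
        # size from a project config (cfg.train.image_size) not shown here.
        self.resize = F.ResizeBilinear(image_size)
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        # reduction='none' (the default) keeps one loss value per pixel
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.zeros = F.Zeros()
        self.fill = F.Fill()
        self.equal = F.Equal()
        self.select = F.Select()
        self.num_classes = num_classes
        self.ignore_label = ignore_label  # handled implicitly through weight 0
        self.mul = P.Mul()
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        logits = self.resize(logits)
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))  # (N*H*W,)
        # (N, C, H, W) -> (N, H, W, C), e.g. (12, 1024, 2048, 19)
        logits_ = self.transpose(logits, (0, 2, 3, 1))
        logits_ = self.reshape(logits_, (-1, self.num_classes))  # (N*H*W, C)
        # Per-pixel weight map: weights[c] where label == c, otherwise 0
        labels_float = self.cast(labels_int, mstype.float32)
        weights = self.zeros(labels_float.shape, mstype.float32)
        for i in range(self.num_classes):
            fill_weight = self.fill(mstype.float32, labels_float.shape, self.weights[i])
            equal_ = self.equal(labels_float, i)
            weights = self.select(equal_, fill_weight, weights)
        one_hot_labels = self.one_hot(labels_int, self.num_classes, self.on_value, self.off_value)
        loss = self.ce(logits_, one_hot_labels)  # per-pixel loss, (N*H*W,)
        loss = self.mul(weights, loss)
        # Weighted mean: normalize by the total weight, not by the pixel count
        loss = self.div(self.sum(loss), self.sum(weights))

        return loss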
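To check the cell against hand calculations without MindSpore, the same math can be written in plain Python. This is a sketch: weighted_pixel_ce is a hypothetical helper, and it assumes the logits are already resized and flattened to one row of class scores per pixel. Like the cell above, it gives labels outside 0..num_classes-1 zero weight and divides by the summed weights, which is the normalization PyTorch uses with reduction='mean'.

```python
import math


def weighted_pixel_ce(logits, labels, weights, ignore_label=255):
    """Class-weighted mean cross-entropy over flattened pixels (hypothetical helper)."""
    num = 0.0  # sum of w[y] * ce
    den = 0.0  # sum of w[y]
    num_classes = len(weights)
    for row, y in zip(logits, labels):
        if y == ignore_label or not 0 <= y < num_classes:
            continue  # weight 0: contributes to neither sum
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        ce = log_z - row[y]  # -log softmax(row)[y]
        num += weights[y] * ce
        den += weights[y]
    return num / den


# Usage: 3 pixels, 3 classes; the last pixel is ignored.
logits = [[2.0, 0.0, 0.0],   # pixel 0, label 0
          [0.0, 0.0, 0.0],   # pixel 1, label 1
          [9.0, -9.0, 0.0]]  # pixel 2, label 255 (ignored)
labels = [0, 1, 255]
print(weighted_pixel_ce(logits, labels, weights=[1.0, 2.0, 0.5]))
```

Dividing by the summed weights keeps the loss scale comparable to the unweighted mean, so the learning rate does not need retuning when the weights change.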
