在复现用于语义分割的 High-Resolution Network（HRNet）时，Cityscapes 数据集中不同类别的物体在计算损失时被赋予不同的权重，19 个训练类别各自的权重如下：
# Per-class loss weights for the 19 Cityscapes training classes.
# Entry i is presumably the weight of class id i — verify against the
# dataset's label mapping before reusing these values.
weights_list = [
    0.8373, 0.918, 0.866, 1.0345,
    1.0166, 0.9969, 0.9754, 1.0489,
    0.8786, 1.0023, 0.9539, 0.9843,
    1.1116, 0.9037, 1.0865, 1.0955,
    1.0865, 1.1529, 1.0507,
]
PyTorch 提供的 torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label) 可以直接实现按类别加权的交叉熵损失，并忽略标签为 ignore_label 的像素。
而 MindSpore 的 r1.1 和 r1.2 版本并未提供功能类似的损失函数，可以用以下代码自行实现：
class CrossEntropyLossWithWeights(_Loss):
    """Class-weighted softmax cross-entropy for semantic segmentation.

    Intended as a MindSpore counterpart of
    ``torch.nn.CrossEntropyLoss(weight=weights, ignore_index=ignore_label)``
    with mean reduction: each pixel's cross-entropy is scaled by the weight of
    its ground-truth class, and the result is ``sum(w * ce) / sum(w)``.

    Args:
        weights: Sequence of per-class weights; entry ``i`` is applied to
            pixels labelled ``i``. Must hold at least ``num_classes`` values.
        num_classes: Number of classes, i.e. the number of logit channels.
        ignore_label: Label whose pixels contribute neither to the weighted
            loss nor to the normalizing weight sum.
        image_size: Target size for the bilinear resize applied to the logits
            before the loss. Defaults to ``cfg.train.image_size`` (a config
            object defined elsewhere in the project).

    Raises:
        ValueError: If ``weights`` has fewer than ``num_classes`` entries.
    """

    def __init__(self, weights, num_classes=19, ignore_label=255, image_size=None):
        super(CrossEntropyLossWithWeights, self).__init__()
        if len(weights) < num_classes:
            raise ValueError(
                "weights must provide one value per class: got %d weights for "
                "%d classes" % (len(weights), num_classes))
        # Kept for backward compatibility; construct() reads class_weights.
        self.weights = weights
        # Weights actually applied per class id. If ignore_label is itself a
        # valid class id, its weight is zeroed so those pixels are masked out.
        # Labels outside [0, num_classes) never match any class in construct()
        # and therefore keep the initial weight of 0.
        self.class_weights = [
            0.0 if c == ignore_label else weights[c]
            for c in range(num_classes)
        ]
        if image_size is None:
            image_size = cfg.train.image_size
        self.resize = F.ResizeBilinear(image_size)
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.zeros = F.Zeros()
        self.fill = F.Fill()
        self.equal = F.Equal()
        self.select = F.Select()
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.mul = P.Mul()
        self.argmax = P.Argmax(output_type=mstype.int32)
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        """Return the class-weighted mean cross-entropy of logits vs. labels.

        Args:
            logits: Network output. Assumed to be NCHW with
                C == num_classes (TODO confirm). It is resized, then moved to
                channels-last and flattened to one row per pixel.
            labels: Per-pixel class ids. After flattening, they must line up
                one-to-one with the rows of the resized logits.

        Returns:
            Scalar ``sum(w * ce) / sum(w)``. If every pixel is ignored, the
            weight sum is 0 and the division is by zero.
        """
        logits = self.resize(logits)
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))
        # Channels-last, then flatten so each row holds one pixel's class scores.
        logits_ = self.transpose(logits, (0, 2, 3, 1))  # (N, H, W, C)
        logits_ = self.reshape(logits_, (-1, self.num_classes))

        # Per-pixel weight map: start at 0 (ignored / out-of-range labels stay
        # masked) and overwrite with the class weight wherever label == i.
        labels_float = self.cast(labels_int, mstype.float32)
        weights = self.zeros(labels_float.shape, mstype.float32)
        for i in range(self.num_classes):
            fill_weight = self.fill(mstype.float32, labels_float.shape, self.class_weights[i])
            equal_ = self.equal(labels_float, i)
            weights = self.select(equal_, fill_weight, weights)

        one_hot_labels = self.one_hot(labels_int, self.num_classes, self.on_value, self.off_value)
        # Per-pixel (unreduced) cross-entropy, then scaled by each pixel's
        # weight; zero-weight pixels drop out of both numerator and denominator.
        loss = self.ce(logits_, one_hot_labels)
        loss = self.mul(weights, loss)
        loss = self.div(self.sum(loss), self.sum(weights))
        return loss