centerloss中心损失它仅仅用来减少类内的差异,而不能有效增大类间的差异性。下图中,图(a)表示softmax loss学习到的特征描述 。图(b)表示softmax loss + center loss 学习到的特征描述,他能把同一类的样本之间的距离拉近一些,使其相似性变大,尽量的往样本中心靠拢,但可以看出他没有把不同类样本之间的样本距离拉大。
centerloss的主要思路为:让每一类特征尽可能的在输出特征空间内聚集在一起。更直白的描述就是每一类的特征在特征空间中尽可能的聚集在某一个中心点附近。正常情况下,如果我们先验的知道了所有样本的GT中心点,那这个任务就好解决了,然而事实是我们无法预先获取类中心特征空间的分布。因此我们只能从训练的过程中动态的获取类中心特征,并对整体的训练过程产生约束。需要注意的是在训练的过程中,受限于GPU的显存等问题,我们不可能直接获取所有样本的特征中心,因此整个过程是基于batch进行的,而且当网络还未收敛的情况下,网络得到的特征中心也是不正确的。基于这两点,特征中心的确定势必是一个基于batch的动态过程。
接下来就详细讲一下这个动态过程,首先提出一个问题:中心点明明是不确定的,那如何让特征去聚集在这个不确定的特征中心点呢?
这要从centerloss的更新机制说起,从下面的两组公式可以看出,center中心点的更新方向是特征值和中心点的二范数,简单来说最终通过这种更新方式会使得某一类特征值对应的中心点被更新成与所有该类样本特征值的二范数和最小的位置,而这个位置我们可以广义的理解为所以特征的中心点位置。因此整体的centerloss是在边学习边找中心点的,最终中心点的确定和整体分类任务的收敛是同步进行的。
用知乎上比较概括性的话来讲就是:
center loss的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离。
最终通过将centerloss和softmaxloss进行加权求和,实现整体的分类任务的学习。
centerloss的计算代码:
def forward(self, output_features, y_truth):
"""
损失计算
:param output_features: conv层输出的特征, [b,c,h,w]
:param y_truth: 标签值 [b,]
:return:
"""
batch_size = y_truth.size(0)
output_features = output_features.view(batch_size, -1)
assert output_features.size(-1) == self.feat_dim
factor = self.scale / batch_size
# return self.lamda * factor * self.lossfunc(output_features, y_truth, self.feature_centers))
centers_batch = self.feature_centers.index_select(0, y_truth.long()) # [b,features_dim]
diff = output_features - centers_batch
loss = self.lamda * 0.5 * factor * (diff.pow(2).sum())
#########
return loss
center的更新代码:
# 改段代码需要注意的是backward返回值需要与对应的forward的输入参数一一对应。
class CenterlossFunc(Function):
@staticmethod
def forward(ctx, feature, label, centers, batch_size):
ctx.save_for_backward(feature, label, centers, batch_size)
centers_batch = centers.index_select(0, label.long())
return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size
@staticmethod
def backward(ctx, grad_output):
feature, label, centers, batch_size = ctx.saved_tensors
centers_batch = centers.index_select(0, label.long())
diff = centers_batch - feature
# init every iteration
counts = centers.new_ones(centers.size(0))
ones = centers.new_ones(label.size(0))
grad_centers = centers.new_zeros(centers.size())
counts = counts.scatter_add_(0, label.long(), ones)
grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
grad_centers = grad_centers/counts.view(-1, 1)
return - grad_output * diff / batch_size, None, grad_centers / batch_size, None
pytorch代码
https://www.cnblogs.com/dxscode/p/12059548.html
https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py