When building a loss in PyTorch, the commonly used built-in losses include MSE, cross entropy (LogSoftmax + NLLLoss), KL divergence loss, BCE, hinge loss, and so on; see: https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#loss-functions
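As a quick check of the "cross entropy = LogSoftmax + NLLLoss" decomposition, the sketch below uses random logits and arbitrary targets. It shows that `nn.CrossEntropyLoss` gives the same value as `F.log_softmax` followed by `nn.NLLLoss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)           # 4 samples, 5 classes
target = torch.tensor([1, 0, 4, 2])  # class indices

# Built-in cross entropy
ce = nn.CrossEntropyLoss()(logits, target)
# The same value computed as LogSoftmax followed by NLLLoss
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
print(ce.item(), nll.item())
```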
This post focuses on Center Loss, which additionally constrains intra-class distances:
Center loss comes from an ECCV 2016 paper: A Discriminative Feature Learning Approach for Deep Face Recognition.
In most of the available CNNs, the softmax loss function is used as the supervision signal to train the deep model. In order to enhance the discriminative power of the deeply learned features, this paper proposes a new supervision signal, called center loss.
The center loss simultaneously learns a center for deep features of each class and penalizes the distances between the deep features and their corresponding class centers.
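Simply put, when doing classification (whether at the image, instance, or pixel level), we want the learned features to be not only separable but also discriminative, which means the loss has to impose additional constraints.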
Specifically, we learn a center (a vector with the same dimension as a feature) for deep features of each class.
The CNNs are trained under the joint supervision of the softmax loss and center loss, with a hyper parameter to balance the two supervision signals.
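Combining softmax loss and center loss:
Softmax loss (keeps the features of different classes apart) and center loss (minimizes intra-class feature distances, pulling each feature toward its class center).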
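For reference, here are the two losses and their combination, following the paper's notation. Here x_i ∈ R^d is the deep feature of the i-th sample, and W_j and b_j are the weight column and bias of the last fully connected layer for class j.

```latex
\mathcal{L}_S = -\sum_{i=1}^{m} \log \frac{e^{W_{y_i}^{T} x_i + b_{y_i}}}{\sum_{j=1}^{n} e^{W_j^{T} x_i + b_j}}

\mathcal{L}_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2,
\qquad
\frac{\partial \mathcal{L}_C}{\partial x_i} = x_i - c_{y_i}

\mathcal{L} = \mathcal{L}_S + \lambda \mathcal{L}_C
```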
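Here m is the mini-batch size and n is the number of classes. In the formula for L_C, c_{y_i} ∈ R^d is the center of class y_i (the class of sample i), where d is the feature dimension. The formula has a practical problem. Ideally, c_{y_i} should be updated as the deep features change, which would mean averaging the features of every class over the entire training set at each iteration, and that is impractical. The paper therefore makes two modifications:
1. The centers are updated on each mini-batch instead of on the entire training set: in each iteration, a center is computed by averaging the features of its class in the batch (so some centers may not be updated in a given iteration).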
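2. To avoid large perturbations caused by a few mislabeled samples, a scalar α controls the learning rate of the centers.
That is, whenever y_i = j (a sample in the mini-batch belongs to the class whose center c_j is being updated), its difference c_j − x_i is accumulated. The sum is divided by the number of samples of that class plus 1 to give Δc_j, and the center is then updated as c_j ← c_j − αΔc_j.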
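The center update described above can be written with plain tensor ops. This is a minimal sketch: the helper name `update_centers` and the value `alpha=0.5` are hypothetical choices for illustration, not taken from the post.

```python
import torch


@torch.no_grad()  # centers are updated by this rule, not by autograd
def update_centers(centers, feats, labels, alpha=0.5):
    """Hypothetical helper: one center update following the paper's rule.

    delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j)
    c_j <- c_j - alpha * delta_c_j
    """
    num_classes = centers.size(0)
    # n_j: number of samples of class j in this mini-batch
    counts = torch.bincount(labels, minlength=num_classes).to(centers.dtype)
    # (c_{y_i} - x_i) for every sample, accumulated into the row of class y_i
    diff = centers[labels] - feats
    delta = torch.zeros_like(centers).index_add_(0, labels, diff)
    delta = delta / (1.0 + counts).unsqueeze(1)
    # Classes absent from the batch get delta = 0, so their centers do not move
    return centers - alpha * delta


centers = torch.zeros(3, 2)
feats = torch.tensor([[2.0, 0.0], [4.0, 0.0], [0.0, 3.0]])
labels = torch.tensor([0, 0, 2])
new_centers = update_centers(centers, feats, labels, alpha=0.5)
print(new_centers)
```

In this example, class 0 has two samples, so its accumulated difference (−6, 0) is divided by 3. Class 1 has no samples and stays at the origin.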
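The final combined loss, shown in the figure above, is L = L_S + λL_C, where λ balances the softmax loss and the center loss. The larger λ is, the more discriminative (more compact within each class) the learned features become, as the figure below shows:
With the principle now clear (keep the intra-class loss as small as possible while preserving correct classification), the following shows how to implement it in code and in the network structure.
Center loss is applied to the output of the feature layer (the last layer before the classifier):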
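A sketch of the combined objective L = L_S + λL_C, assuming the model returns both the deep features and the logits. The name `joint_loss` and the default `lam=0.01` are illustrative only. Both terms here are averaged over the mini-batch (the default of `F.cross_entropy`), whereas the paper writes them as sums; this only rescales λ. The centers are treated as a given tensor, and how they are updated is a separate step.

```python
import torch
import torch.nn.functional as F


def joint_loss(logits, feats, labels, centers, lam=0.01):
    # Softmax loss L_S, averaged over the mini-batch
    loss_s = F.cross_entropy(logits, labels)
    # Center loss L_C = 1/2 * ||x_i - c_{y_i}||^2, averaged over the mini-batch
    loss_c = 0.5 * (feats - centers[labels]).pow(2).sum(dim=1).mean()
    # A larger lam pulls features more tightly around their class centers
    return loss_s + lam * loss_c, loss_s, loss_c


logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
feats = torch.tensor([[1.0, 1.0], [3.0, 0.0]])
centers = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
labels = torch.tensor([0, 1])
total, ls, lc = joint_loss(logits, feats, labels, centers, lam=0.1)
print(total.item(), ls.item(), lc.item())
```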
import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        # One learnable center per class, each of dimension feat_dim
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, feat):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # To check the dim of centers and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim, feat.size(1)))
        loss = self.centerlossfunc(feat, label, self.centers)
        # Average over the mini-batch if requested (out-of-place division)
        loss = loss / (batch_size if self.size_average else 1)
        return loss


class CenterlossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centers):
        ctx.save_for_backward(feature, label, centers)
        # c_{y_i} for every sample in the batch
        centers_batch = centers.index_select(0, label.long())
        # L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2
        return (feature - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature  # c_{y_i} - x_i
        # init every iteration; counts starts at 1 to give the "1 +" in the denominator of Δc_j
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())
        counts = counts.scatter_add_(0, label.long(), ones)
        # Sum (c_j - x_i) over the samples of class j
        grad_centers.scatter_add_(0, label.long().unsqueeze(1).expand(feature.size()), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        # dL_C/dx_i = x_i - c_{y_i}; the center "gradient" is the paper's Δc_j and
        # deliberately ignores grad_output, so λ does not scale the center update
        return -grad_output * diff, None, grad_centers


def main(test_cuda=False):
    device = torch.device("cuda" if test_cuda else "cpu")
    ct = CenterLoss(10, 2).to(device)
    y = torch.tensor([0, 0, 2, 1], dtype=torch.long, device=device)
    feat = torch.zeros(4, 2, device=device).requires_grad_()
    print(list(ct.parameters()))
    print(ct.centers.grad)  # None before backward
    out = ct(y, feat)
    out.backward()
    print(ct.centers.grad)
    print(feat.grad)


if __name__ == "__main__":
    torch.manual_seed(999)
    main(test_cuda=False)
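A minimal training-step sketch of where the loss attaches. It assumes a hypothetical toy MLP (`TinyNet`) whose 2-D feature layer feeds the classifier. To stay self-contained, it uses a plain autograd center loss with its own SGD optimizer (lr = α) instead of the custom `CenterlossFunc` above. The center step therefore follows the ordinary gradient of L_C rather than the paper's exact Δc_j rule. The sizes, learning rates, and `lam` value are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical toy model: feature layer (center loss) -> classifier (softmax loss)."""

    def __init__(self, in_dim=20, feat_dim=2, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, feat_dim))
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        feat = self.features(x)  # last layer before classification
        return feat, self.classifier(feat)


def train_step(model, centers, opt_model, opt_centers, x, y, lam):
    feat, logits = model(x)
    loss_s = F.cross_entropy(logits, y)
    loss_c = 0.5 * (feat - centers[y]).pow(2).sum(dim=1).mean()
    loss = loss_s + lam * loss_c
    opt_model.zero_grad()
    opt_centers.zero_grad()
    loss.backward()
    # Undo the lam scaling on the centers so their step size is set by alpha alone
    centers.grad.mul_(1.0 / lam)
    opt_model.step()
    opt_centers.step()
    return loss.item()


torch.manual_seed(0)
model = TinyNet()
centers = nn.Parameter(torch.randn(10, 2))
alpha, lam = 0.5, 0.1
opt_model = torch.optim.SGD(model.parameters(), lr=0.05)
opt_centers = torch.optim.SGD([centers], lr=alpha)
x = torch.randn(32, 20)
y = torch.randint(0, 10, (32,))
losses = [train_step(model, centers, opt_model, opt_centers, x, y, lam) for _ in range(50)]
print(losses[0], losses[-1])
```

Keeping the centers in a separate optimizer lets α (the center learning rate) be tuned independently of the network's learning rate.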
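The paper also discusses how center loss differs from contrastive loss and triplet loss. Compared with those two, the clear advantage of center loss is that it avoids the complex and ambiguous process of constructing sample pairs (or triplets). Triplet loss will be reviewed next.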