import torch import torch.nn as nn import torch.utils.data as data import torchvision.transforms as TF import torchvision.utils as vutils import torch.nn.functional as F from torch.autograd import Function class CenterLoss(nn.Module): """ paper: http://ydwen.github.io/papers/WenECCV16.pdf code: https://github.com/pangyupo/mxnet_center_loss pytorch code: https://blog.csdn.net/sinat_37787331/article/details/80296964 """ def __init__(self, features_dim, num_class=10, alpha=0.01, scale=1.0, batch_size=64): """ 初始化 :param features_dim: 特征维度 = c*h*w :param num_class: 类别数量 :param alpha: centerloss的权重系数 [0,1] """ assert 0 <= alpha <= 1 super(CenterLoss, self).__init__() self.alpha = alpha self.num_class = num_class self.scale = scale self.batch_size = batch_size self.feat_dim = features_dim # store the center of each class , should be ( num_class, features_dim) self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim])) self.lossfunc = CenterLossFunc.apply init_weight(self, 'normal') def forward(self, output_features, y_truth): """ 损失计算 :param output_features: conv层输出的特征, [b,c,h,w] :param y_truth: 标签值 [b,] :return: """ batch_size = y_truth.size(0) output_features = output_features.view(batch_size, -1) assert output_features.size(-1) == self.feat_dim loss = self.lossfunc(output_features, y_truth, self.feature_centers) loss /= batch_size # centers_pred = self.feature_centers.index_select(0, y_truth.long()) # [b,features_dim] # diff = output_features - centers_pred # loss = self.alpha * 1 / 2.0 * (diff.pow(2).sum()) / self.batch_size return loss class CenterLossFunc(Function): # https://blog.csdn.net/xiewenbo/article/details/89286462 @staticmethod def forward(ctx, feat, labels, centers): ctx.save_for_backward(feat, labels, centers) centers_batch = centers.index_select(0, labels.long()) return (feat - centers_batch).pow(2).sum() / 2.0 @staticmethod def backward(ctx, grad_output): feature, label, centers = ctx.saved_tensors centers_batch = centers.index_select(0, label.long()) diff = centers_batch - feature # init every iteration counts = centers.new(centers.size(0)).fill_(1) ones = centers.new(label.size(0)).fill_(1) grad_centers = centers.new(centers.size()).fill_(0) counts = counts.scatter_add_(0, label.long(), ones) grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff) grad_centers = grad_centers / counts.view(-1, 1) return - grad_output * diff, None, grad_centers if __name__ == '__main__': ct = CenterLoss(2, 10, 0.1).cuda() y = torch.Tensor([0, 0, 2, 1]).cuda() feat = torch.zeros(4, 2).cuda().requires_grad_() print(list(ct.parameters())) print(ct.feature_centers.grad) out = ct(feat, y) print(out.item()) out.backward() print(ct.feature_centers.grad) print(feat.grad)