Understanding and Implementing the Center Loss Function
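Center loss (Wen et al., ECCV 2016) learns a center for every class in feature space and penalizes the squared distance between each deep feature and the center of its class:

$$L_C = \frac{1}{2}\sum_{i=1}^{m}\left\lVert x_i - c_{y_i}\right\rVert_2^2$$

where $m$ is the batch size, $x_i$ is the deep feature of sample $i$, and $c_{y_i}$ is the center of its class. Trained jointly with the softmax loss, it pulls features of the same class toward their center, making the learned features more discriminative. The PyTorch implementation below wraps the computation in a custom autograd Function so that the centers can be updated following the paper's rule.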

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class CenterLoss(nn.Module):
    """
    paper: http://ydwen.github.io/papers/WenECCV16.pdf
    code:  https://github.com/pangyupo/mxnet_center_loss
    pytorch code: https://blog.csdn.net/sinat_37787331/article/details/80296964
    """

    def __init__(self, features_dim, num_class=10, alpha=0.01, scale=1.0, batch_size=64):
        """
        :param features_dim: feature dimension, i.e. c*h*w of the conv output
        :param num_class: number of classes
        :param alpha: weight of the center loss term, in [0, 1]; stored here but
                      expected to be applied by the caller when combining losses
        :param scale: optional scaling factor for the loss (unused in this version)
        :param batch_size: default batch size (only used by the commented-out direct computation)
        """
        assert 0 <= alpha <= 1
        super(CenterLoss, self).__init__()
        self.alpha = alpha
        self.num_class = num_class
        self.scale = scale
        self.batch_size = batch_size
        self.feat_dim = features_dim
        # centers of each class, shape (num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))

        self.lossfunc = CenterLossFunc.apply

    def forward(self, output_features, y_truth):
        """
        :param output_features: features from the conv layer, [b, c, h, w]
        :param y_truth: ground-truth labels, [b,]
        :return: center loss averaged over the batch
        """
        batch_size = y_truth.size(0)
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim
        loss = self.lossfunc(output_features, y_truth, self.feature_centers)
        loss /= batch_size

        # equivalent direct computation, with the alpha weighting applied:
        # centers_pred = self.feature_centers.index_select(0, y_truth.long())  # [b, features_dim]
        # diff = output_features - centers_pred
        # loss = self.alpha * 0.5 * diff.pow(2).sum() / batch_size
        return loss


class CenterLossFunc(Function):
    # https://blog.csdn.net/xiewenbo/article/details/89286462
    @staticmethod
    def forward(ctx, feat, labels, centers):
        ctx.save_for_backward(feat, labels, centers)
        centers_batch = centers.index_select(0, labels.long())
        return (feat - centers_batch).pow(2).sum() / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # counts[j] = 1 + number of samples of class j in this batch,
        # the denominator of the paper's center-update rule (Eq. 4, see below)
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())

        counts.scatter_add_(0, label.long(), ones)
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)
        # d(L_C)/d(feat) = feat - centers_batch = -diff; grad_centers follows the
        # paper's update rule rather than the analytic gradient, so it is not
        # scaled by grad_output (the paper gives the centers their own update rate)
        return -grad_output * diff, None, grad_centers
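Note that backward does not return the exact analytic gradient for the centers. Following Eq. 4 of the paper, each center is moved by the average offset of the batch samples belonging to its class:

$$\Delta c_j = \frac{\sum_{i=1}^{m}\delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m}\delta(y_i = j)}$$

which is why `counts` is initialized to ones before the `scatter_add_`: the denominator is $1 + n_j$, where $n_j$ is the number of samples of class $j$ in the batch.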


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ct = CenterLoss(2, 10, 0.1).to(device)
    y = torch.tensor([0, 0, 2, 1], device=device)
    feat = torch.zeros(4, 2, device=device).requires_grad_()
    print(list(ct.parameters()))
    print(ct.feature_centers.grad)  # None: no backward pass yet
    out = ct(feat, y)
    print(out.item())
    out.backward()
    print(ct.feature_centers.grad)
    print(feat.grad)
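In a real training loop the center loss is added to the cross-entropy (softmax) loss with a weight $\lambda$, i.e. $L = L_S + \lambda L_C$. A minimal sketch of that wiring follows; `Net` is a hypothetical toy model (its names and sizes are illustrative, not part of the original code) that returns both logits and a 2-D feature embedding:

class Net(nn.Module):
    # hypothetical toy model returning (logits, features)
    def __init__(self):
        super(Net, self).__init__()
        self.embed = nn.Linear(784, 2)  # 2-D features, as in the paper's MNIST visualization
        self.fc = nn.Linear(2, 10)

    def forward(self, x):
        feat = self.embed(x)
        return self.fc(feat), feat

net = Net()
center_loss = CenterLoss(features_dim=2, num_class=10)
# the centers are an nn.Parameter, so the optimizer updates them as well
optimizer = torch.optim.SGD(list(net.parameters()) + list(center_loss.parameters()), lr=0.01)
lam = 0.01  # weight of the center loss term (lambda above)

x = torch.randn(64, 784)         # dummy batch
y = torch.randint(0, 10, (64,))
logits, feat = net(x)
loss = F.cross_entropy(logits, y) + lam * center_loss(feat, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()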
