【Loss】Center Loss Code Explained (PyTorch)

**Note:** The full code is at the end; I no longer know which author it originally came from.
Notation: m = batch size, n = number of classes, d = feature dimension.
$$x=\begin{pmatrix}x_0 \\ x_1 \\ \vdots \\ x_{m-1}\end{pmatrix}=\begin{pmatrix}x_{00}&x_{01}&\dots&x_{0(d-1)}\\ x_{10}&x_{11}&\dots&x_{1(d-1)}\\ \vdots&\vdots&\ddots&\vdots\\ x_{(m-1)0}&x_{(m-1)1}&\dots&x_{(m-1)(d-1)}\end{pmatrix}\in\mathbb{R}^{m\times d}$$

`torch.pow(x, 2).sum(dim=1, keepdim=True)`

$$=\begin{pmatrix}x_0^2 \\ x_1^2 \\ \vdots \\ x_{m-1}^2\end{pmatrix}=\begin{pmatrix}x_{00}^2+x_{01}^2+\dots+x_{0(d-1)}^2\\ x_{10}^2+x_{11}^2+\dots+x_{1(d-1)}^2\\ \vdots\\ x_{(m-1)0}^2+x_{(m-1)1}^2+\dots+x_{(m-1)(d-1)}^2\end{pmatrix}\in\mathbb{R}^{m\times 1}$$

(Here $x_i^2$ is shorthand for the squared norm $\lVert x_i\rVert^2$ of row $i$.)

`.expand(batch_size, self.num_classes)`

$$=\begin{pmatrix}x_0^2&x_0^2&\dots&x_0^2\\ x_1^2&x_1^2&\dots&x_1^2\\ \vdots&\vdots&\ddots&\vdots\\ x_{m-1}^2&x_{m-1}^2&\dots&x_{m-1}^2\end{pmatrix}\in\mathbb{R}^{m\times n}$$
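A minimal shape check of these two steps, with hypothetical toy sizes:

```python
import torch

m, n, d = 4, 3, 5  # toy batch size, class count, feature dim
x = torch.randn(m, d)

sq = torch.pow(x, 2).sum(dim=1, keepdim=True)  # (m, 1): squared norm of each row
print(sq.shape)                                # torch.Size([4, 1])

expanded = sq.expand(m, n)                     # (m, n): each row repeats its norm n times
print(expanded.shape)                          # torch.Size([4, 3])
print(torch.allclose(expanded[:, 0], (x ** 2).sum(dim=1)))  # True
```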

The same reasoning applies to `self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())`:

$$centers=\begin{pmatrix}c_{00}&c_{01}&\dots&c_{0(d-1)}\\ c_{10}&c_{11}&\dots&c_{1(d-1)}\\ \vdots&\vdots&\ddots&\vdots\\ c_{(n-1)0}&c_{(n-1)1}&\dots&c_{(n-1)(d-1)}\end{pmatrix}\in\mathbb{R}^{n\times d}$$

`torch.pow(self.centers, 2).sum(dim=1, keepdim=True)`

$$=\begin{pmatrix}c_0^2 \\ c_1^2 \\ \vdots \\ c_{n-1}^2\end{pmatrix}=\begin{pmatrix}c_{00}^2+c_{01}^2+\dots+c_{0(d-1)}^2\\ c_{10}^2+c_{11}^2+\dots+c_{1(d-1)}^2\\ \vdots\\ c_{(n-1)0}^2+c_{(n-1)1}^2+\dots+c_{(n-1)(d-1)}^2\end{pmatrix}\in\mathbb{R}^{n\times 1}$$

`.expand(self.num_classes, batch_size)`

(`expand` can only broadcast size-1 dimensions: the $(n,1)$ tensor can be expanded to $(n,m)$, but not directly to $(m,n)$, hence the transpose in the next step.)

$$=\begin{pmatrix}c_0^2&c_0^2&\dots&c_0^2\\ c_1^2&c_1^2&\dots&c_1^2\\ \vdots&\vdots&\ddots&\vdots\\ c_{n-1}^2&c_{n-1}^2&\dots&c_{n-1}^2\end{pmatrix}\in\mathbb{R}^{n\times m}$$

`.t()`

$$=\begin{pmatrix}c_0^2&c_1^2&\dots&c_{n-1}^2\\ c_0^2&c_1^2&\dots&c_{n-1}^2\\ \vdots&\vdots&\ddots&\vdots\\ c_0^2&c_1^2&\dots&c_{n-1}^2\end{pmatrix}\in\mathbb{R}^{m\times n}$$
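A quick sketch confirming the expand-then-transpose route (hypothetical toy sizes):

```python
import torch

n, m = 3, 4  # toy num_classes, batch_size
c_sq = torch.randn(n, 1)  # stand-in for the (n, 1) column of squared center norms

# expand can only broadcast the size-1 dimension, so (n, 1) -> (n, m), then transpose
a = c_sq.expand(n, m).t()
# equivalently, transpose first and then expand: (1, n) -> (m, n)
b = c_sq.t().expand(m, n)
print(torch.equal(a, b))  # True
```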

```python
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
          torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
```

$$distmat=\begin{pmatrix}x_0^2&x_0^2&\dots&x_0^2\\ x_1^2&x_1^2&\dots&x_1^2\\ \vdots&\vdots&\ddots&\vdots\\ x_{m-1}^2&x_{m-1}^2&\dots&x_{m-1}^2\end{pmatrix}+\begin{pmatrix}c_0^2&c_1^2&\dots&c_{n-1}^2\\ c_0^2&c_1^2&\dots&c_{n-1}^2\\ \vdots&\vdots&\ddots&\vdots\\ c_0^2&c_1^2&\dots&c_{n-1}^2\end{pmatrix}\in\mathbb{R}^{m\times n}$$

`distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)`

(The original code calls `distmat.addmm_(1, -2, x, self.centers.t())`; that positional signature is deprecated and rejected by newer PyTorch versions. Both forms compute, in place, `distmat = 1 * distmat + (-2) * x @ self.centers.t()`.)

For more on `addmm()` and `addmm_()`, see the CSDN post “Pytorch里addmm()和addmm_()的用法详解”.

$$distmat = distmat - 2\,x\,centers^T$$

$$=\begin{pmatrix}x_0^2&x_0^2&\dots&x_0^2\\ x_1^2&x_1^2&\dots&x_1^2\\ \vdots&\vdots&\ddots&\vdots\\ x_{m-1}^2&x_{m-1}^2&\dots&x_{m-1}^2\end{pmatrix}+\begin{pmatrix}c_0^2&c_1^2&\dots&c_{n-1}^2\\ c_0^2&c_1^2&\dots&c_{n-1}^2\\ \vdots&\vdots&\ddots&\vdots\\ c_0^2&c_1^2&\dots&c_{n-1}^2\end{pmatrix}-2\begin{pmatrix}x_{00}&x_{01}&\dots&x_{0(d-1)}\\ x_{10}&x_{11}&\dots&x_{1(d-1)}\\ \vdots&\vdots&\ddots&\vdots\\ x_{(m-1)0}&x_{(m-1)1}&\dots&x_{(m-1)(d-1)}\end{pmatrix}\begin{pmatrix}c_{00}&c_{01}&\dots&c_{0(d-1)}\\ c_{10}&c_{11}&\dots&c_{1(d-1)}\\ \vdots&\vdots&\ddots&\vdots\\ c_{(n-1)0}&c_{(n-1)1}&\dots&c_{(n-1)(d-1)}\end{pmatrix}^T\in\mathbb{R}^{m\times n}$$
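To see what `addmm_` does in isolation, a small sketch with toy tensors (`beta`/`alpha` as in the call above):

```python
import torch

out = torch.ones(2, 2)
a = torch.randn(2, 3)
b = torch.randn(3, 2)

# In place: out = 1 * out + (-2) * (a @ b)
out.addmm_(a, b, beta=1, alpha=-2)
print(torch.allclose(out, torch.ones(2, 2) - 2 * (a @ b)))  # True
```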

Writing $distmat=[d_{ij}]$ with $0\le i\le m-1$, $0\le j\le n-1$, each entry $d_{ij}=\lVert x_i-c_j\rVert^2$ is the squared Euclidean distance between sample $x_i$ and class center $c_j$; this is then combined with the mask below.
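As a sanity check (a sketch on random toy tensors), the assembled `distmat` should agree with the squared pairwise Euclidean distances from `torch.cdist`:

```python
import torch

m, n, d = 4, 3, 5
x = torch.randn(m, d)
centers = torch.randn(n, d)

distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat = distmat - 2 * x @ centers.t()  # same as distmat.addmm_(x, centers.t(), beta=1, alpha=-2)

print(torch.allclose(distmat, torch.cdist(x, centers) ** 2, atol=1e-5))  # True
```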

`classes = torch.arange(self.num_classes).long()`

`classes` = [0, 1, …, n-1]

The input $labels\in \mathbb{R}^{m}$ holds the class label of each sample.

`labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)`

`unsqueeze(1)`: $labels\in \mathbb{R}^{m\times 1}$; `expand(batch_size, self.num_classes)`: $labels\in \mathbb{R}^{m\times n}$, with every element in a row identical.

`mask = labels.eq(classes.expand(batch_size, self.num_classes))` converts `labels` into a one-hot style boolean mask.

Example: batch size = 3, num classes = 4, labels = (2, 1, 0):

$$labels=\begin{pmatrix}2&2&2&2\\ 1&1&1&1\\ 0&0&0&0\end{pmatrix}$$

$$classes.expand=\begin{pmatrix}0&1&2&3\\ 0&1&2&3\\ 0&1&2&3\end{pmatrix}$$

$$mask=\begin{pmatrix}\text{False}&\text{False}&\text{True}&\text{False}\\ \text{False}&\text{True}&\text{False}&\text{False}\\ \text{True}&\text{False}&\text{False}&\text{False}\end{pmatrix}$$
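Reproducing this example in code:

```python
import torch

batch_size, num_classes = 3, 4
labels = torch.tensor([2, 1, 0])

classes = torch.arange(num_classes).long()
labels_exp = labels.unsqueeze(1).expand(batch_size, num_classes)
mask = labels_exp.eq(classes.expand(batch_size, num_classes))
print(mask)
# tensor([[False, False,  True, False],
#         [False,  True, False, False],
#         [ True, False, False, False]])
```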

My understanding of the final part: the mask picks out, for each sample, the distance to the center of its own class, so minimizing the loss keeps shrinking the distance between each sample and its class center.

```python
dist = []
for i in range(batch_size):
    value = distmat[i][mask[i]]  # squared distance to this sample's own class center
    value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
    dist.append(value)
dist = torch.cat(dist)
loss = dist.mean()
return loss
```
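The Python loop can also be vectorized. A minimal sketch of an equivalent drop-in replacement inside `forward` (boolean indexing selects exactly one entry per row, in row order, matching `torch.cat(dist)` above):

```python
# Vectorized form of the loop: select each sample's own-class distance,
# clamp for numerical stability, then average over the batch.
loss = distmat[mask].clamp(min=1e-12, max=1e+12).mean()
return loss
```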

**Full code**

```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        # squared norms ||x_i||^2 + ||c_j||^2, broadcast to shape (batch_size, num_classes)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # in place: distmat = 1 * distmat + (-2) * x @ centers.t(), completing ||x_i - c_j||^2
        # (the old positional form addmm_(1, -2, x, self.centers.t()) is deprecated)
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = []
        for i in range(batch_size):
            value = distmat[i][mask[i]]
            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()
        return loss
```
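A minimal usage sketch. The feature extractor, weighting factor, and learning rates below are hypothetical stand-ins; the usual recipe pairs cross-entropy loss with the center loss term, and since `self.centers` is an `nn.Parameter`, it is updated here through autograd (often with its own optimizer) rather than by the paper's manual update rule:

```python
import torch
import torch.nn as nn

num_classes, feat_dim = 10, 2048
model = nn.Linear(512, feat_dim)               # hypothetical stand-in feature extractor
classifier = nn.Linear(feat_dim, num_classes)  # softmax head
criterion_ce = nn.CrossEntropyLoss()
criterion_cent = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=False)

optimizer = torch.optim.SGD(list(model.parameters()) + list(classifier.parameters()), lr=0.01)
optimizer_cent = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)  # hypothetical lr

lambda_cent = 0.003  # hypothetical weight on the center loss term
inputs = torch.randn(8, 512)
labels = torch.randint(0, num_classes, (8,))

features = model(inputs)
logits = classifier(features)
loss = criterion_ce(logits, labels) + lambda_cent * criterion_cent(features, labels)

optimizer.zero_grad()
optimizer_cent.zero_grad()
loss.backward()
optimizer.step()
optimizer_cent.step()
```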
