ArcFace Notes

Framework: Pytorch 0.3.0

关于pytorch 反传计算图一点小心得:

  • 计算图记录了variable的所有操作。若需要此变量反向求导,则一开始参与计算时该以variable出现。

  • 直接改动variable下的tensor(即variable.data),则其改动在反传中无效。

  • 计算时最好进行矩阵整体运算,单独对矩阵中某些数值进行改动,计算图代价很大。
    如,我们的目标是将Mat(cos_\theta)j=label处替换成phi_\theta
    如下代码,如果对phi_theta[i,j]进行单独操作,由于公式1中的cos_theta[i,j]是variable。因此,在计算图中,矩阵中的每个元素都得单独反传。

      for i in range(cos_theta.shape[0]):
          j=target[i].data[0]
          if cos_theta[i,j].data[0] >= -self.cosm:
              phi_theta[i,j]=self.cosm * cos_theta[i,j] - \
                           self.sinm * torch.sqrt(1e-6+1-cos_theta[i,j]*cos_theta[i,j]) #公式1
          else:
              phi_theta[i,j] = cos_theta[i,j]
    

将代码改成

    for i in range(cos_theta.shape[0]):
        j=target[i].data[0]
        if cos_theta[i,j].data[0] >= -self.cosm:
            flagMat[i,j]= 1
        else:
            flagMat[i,j] = 0
    flagMat=Variable(flagMat)
    phi_theta=(self.cosm * cos_theta - self.sinm * torch.sqrt(1e-6+1-cos_theta*cos_theta))*flagMat\
              +(1-flagMat)*cos_theta

新构造指示矩阵flatMat公式1进行运算,且flatMat不参与反传,得到的phi_\theta.grad_fn公式1backward,计算图恢复正常!life will be better!

Reference

  • 『PyTorch』第五弹深入理解autograd上:Variable属性方法
  • 画pytorch模型图,以及参数计算

你可能感兴趣的:(ArcFace Notes)