论文地址:https://arxiv.org/pdf/1612.02295.pdf
L-softmax的主要思想是通过一个超参m对softmax+cross entropy的损失函数进行改进。一般我们把y = Wx + b, output = softmax(y), cross_entropy(output, label)这个过程统称为softmax loss.
从softmax到L-softmax的改进在论文中已经解释的非常清楚了。
损失函数可以写成下面的形式
关于角度的问题,我们需要设计一个单调递减的函数。我的理解是由于cos函数是一个周期函数,当m*theta > pi之后,cos(theta)会进入上升阶段。而很明显,在(4)这个式子中,theta(yi)越大,我们需要对这个限制的越厉害,因此也就需要一个更小的phi值。所以作者设计了如下一个phi函数。
在实现的过程中,我们利用cos(theta)的定义和多倍角公式,得到下面的式子:
L-softmax有多种框架的实现版本,作者使用的是caffe,本文不详细介绍反传过程,因此选择pytorch实现版本进行解读。pytorch版本实现地址:https://github.com/jihunchoi/lsoftmax-pytorch/blob/master/lsoftmax.py。
实现代码
import math
import torch
from torch import nn
from torch.autograd import Variable
from scipy.special import binom
class LSoftmaxLinear(nn.Module):
def __init__(self, input_dim, output_dim, margin):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.margin = margin
self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
self.divisor = math.pi / self.margin
'''
#这个是系数,对应(7)式中前面的系数。使用一个二项分布的函数,产生系数,并且每隔1个取一个数字。假设m=3,则有[1, 3, 3, 1], 隔开取之后是[1, 3].m = 4则有[1, 4, 6, 4, 1],隔开取之后是[1, 6, 1],与7中的系数是对应的关系
'''
self.coeffs = binom(margin, range(0, margin + 1, 2))
'''
这个是cos项的指数,从m到0,每次减去2
'''
self.cos_exps = range(self.margin, -1, -2)
'''
这个是sin平方项的指数,从1到n,与cos项的指数对应
'''
self.sin_sq_exps = range(len(self.cos_exps))
'''
这个是符号 1 -1 1 -1 1 -1....
'''
self.signs = [1]
for i in range(1, len(self.sin_sq_exps)):
self.signs.append(self.signs[-1] * -1)
def reset_parameters(self):
nn.init.kaiming_normal(self.weight.data.t())
def find_k(self, cos):
acos = cos.acos()
k = (acos / self.divisor).floor().detach()
return k
def forward(self, input, target=None):
'''
input: N,D 其中N是batch size
target是(N,)的label
'''
if self.training:
assert target is not None
'''
y = Wx这样得到输出逻辑, logit的维度应该是(N, C), C是输出类别数
'''
logit = input.matmul(self.weight)
batch_size = logit.size(0)# N
'''
通过这个操作把N, C矩阵中的yi全部都取出来,形成一个一维向量 (N , )
'''
logit_target = logit[range(batch_size), target]
'''
求L2范数
'''
weight_target_norm = self.weight[:, target].norm(p=2, dim=0)
input_norm = input.norm(p=2, dim=1)
# norm_target_prod: (batch_size,)
norm_target_prod = weight_target_norm * input_norm
# cos_target: (batch_size,)
'''
这里就得到了cos(theta)和sin(theta)的二次方
'''
cos_target = logit_target / (norm_target_prod + 1e-10)
sin_sq_target = 1 - cos_target**2
num_ns = self.margin//2 + 1# m = 4时为3, m = 3时为2 ,让这个数为n
# coeffs, cos_powers, sin_sq_powers, signs: (num_ns,)注意这里的shape
coeffs = Variable(input.data.new(self.coeffs))
cos_exps = Variable(input.data.new(self.cos_exps))
sin_sq_exps = Variable(input.data.new(self.sin_sq_exps))
signs = Variable(input.data.new(self.signs))
'''
(N, 1) ** (1, n)-->(N, n)这个矩阵是batch中每个example的cos
'''
cos_terms = cos_target.unsqueeze(1) ** cos_exps.unsqueeze(0)
'''
同上,这个矩阵是batch中每个example的sin平方
'''
sin_sq_terms = (sin_sq_target.unsqueeze(1)
** sin_sq_exps.unsqueeze(0))
'''
符号*系数*cos*sin平方 (1, n) * (1, n) * (N, n) * (N , n)
'''
cosm_terms = (signs.unsqueeze(0) * coeffs.unsqueeze(0)
* cos_terms * sin_sq_terms)
# 各个example各自求和,得到 cos (m * theta)
cosm = cosm_terms.sum(1)
# 寻找k值
k = self.find_k(cos_target)
# 根据k值计算||W||*||x||*cos (phi)
ls_target = norm_target_prod * (((-1)**k * cosm) - 2*k)
# 把计算出来的值代替原来的softmax中yi的位置,返回之后通过cross entropy计算就得到了Lsoftmax
logit[range(batch_size), target] = ls_target
return logit
else:
assert target is None
return input.matmul(self.weight)