(Original) CosFace/AM-Softmax and their mxnet code

Please credit the source when reposting:

http://www.cnblogs.com/darkknightzh/p/8525241.html

Papers:

CosFace: Large Margin Cosine Loss for Deep Face Recognition

https://arxiv.org/abs/1801.09414

Additive Margin Softmax for Face Verification

https://arxiv.org/abs/1801.05599

At the time of writing, no code has been released for the first paper.

Official code for the second paper:

https://github.com/happynear/AMSoftmax

Third-party mxnet code for both papers:

https://github.com/deepinsight/insightface

 

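Note: I have not used mxnet myself. The comments in the code below analyze it purely from reading the code, so there may be mistakes; corrections are welcome.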

 

For an introduction to $\psi (\theta )$, see the earlier post on SphereFace first: http://www.cnblogs.com/darkknightzh/p/8524937.html

The AM-Softmax paper defines $\psi (\theta )$ as:

$\psi (\theta )=\cos (\theta )-m$
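As a quick numeric check of this definition, the short Python sketch below evaluates $\psi(\theta)$ for one angle. The margin m = 0.35 is an illustrative value chosen here (an assumption, not a value given in this post). Unlike SphereFace's multiplicative margin $\cos(m\theta)$, the additive margin subtracts the same constant m from the target cosine regardless of θ.

```python
import math


def psi(theta: float, m: float) -> float:
    """AM-Softmax target logit before scaling: psi(theta) = cos(theta) - m."""
    return math.cos(theta) - m


m = 0.35             # illustrative margin (assumption)
theta = math.pi / 3  # 60 degrees, so cos(theta) = 0.5
print(f"cos(theta) = {math.cos(theta):.2f}, psi(theta) = {psi(theta, m):.2f}")
# prints: cos(theta) = 0.50, psi(theta) = 0.15
```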

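SphereFace normalizes only W, whereas AM-Softmax normalizes both W and the feature x (written $f_i$ in the formula), so that $W_j^T f_i=\cos\theta_j$. To make training converge, a scale parameter s is added (s=30). The final AM-Softmax loss is: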

$L_{AMS}=-\frac{1}{n}\sum\limits_{i=1}^{n}\log\frac{e^{s\cdot(\cos\theta_{y_i}-m)}}{e^{s\cdot(\cos\theta_{y_i}-m)}+\sum\nolimits_{j=1,j\ne y_i}^{c}e^{s\cdot\cos\theta_j}}=-\frac{1}{n}\sum\limits_{i=1}^{n}\log\frac{e^{s\cdot(W_{y_i}^{T}f_i-m)}}{e^{s\cdot(W_{y_i}^{T}f_i-m)}+\sum\nolimits_{j=1,j\ne y_i}^{c}e^{s\,W_j^{T}f_i}}$
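The loss can be sketched in plain Python to make the role of the normalization explicit: once both $W_j$ and $f_i$ are L2-normalized, the dot product $W_j^T f_i$ is exactly $\cos\theta_j$, which is why the two forms of $L_{AMS}$ are equal. This is a minimal, unoptimized sketch; the function names and the default m = 0.35 are assumptions for illustration, not code from either paper (s = 30 follows the text).

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def am_softmax_loss(features, weights, labels, s: float = 30.0, m: float = 0.35) -> float:
    """Mean AM-Softmax loss L_AMS over a batch (direct transcription of the formula).

    features: n feature vectors f_i, each of length F
    weights:  c class weight vectors W_j, each of length F
    labels:   n ground-truth class indices y_i
    """
    W = [l2_normalize(w) for w in weights]
    total = 0.0
    for f, y in zip(features, labels):
        f = l2_normalize(f)
        # With unit-length W_j and f_i, W_j^T f_i = cos(theta_j)
        cos = [sum(a * b for a, b in zip(w, f)) for w in W]
        target = math.exp(s * (cos[y] - m))
        others = sum(math.exp(s * cos[j]) for j in range(len(W)) if j != y)
        total += -math.log(target / (target + others))
    return total / len(labels)


# Two orthogonal class weights; the feature is closer to class 1 (cos 0.8 vs 0.6)
print(am_softmax_loss([[3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], [1]))
```

In the usage line the feature is already closer to its own class, yet the margin lowers the target logit to s(0.8 − 0.35) = 13.5, below the competing s·0.6 = 18, so the loss stays large. This is how the margin keeps pushing features toward their class weight.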

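In the code, the target term is expanded as $s\cdot(\cos\theta_{y_i}-m)=s\cdot\cos\theta_{y_i}-sm$, and the two parts are computed separately: $s\cdot\cos\theta_j$ for every class j, and the constant $sm$. Then $sm$ is subtracted at position $y_i$ only, the result is passed through log softmax to obtain the (log-)probabilities, and the loss is computed from those.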
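The equivalence between subtracting $sm$ at position $y_i$ and evaluating $e^{s\cdot(\cos\theta_{y_i}-m)}$ directly can be checked numerically. The plain-Python sketch below computes the per-sample loss both ways; the function names are hypothetical and the cosine values are made up for the check. The log softmax uses the usual max-subtraction for numerical stability.

```python
import math


def log_softmax(logits):
    # subtract the maximum before exponentiating to avoid overflow
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return [z - log_sum_exp for z in logits]


def loss_by_subtraction(cos_row, y, s=30.0, m=0.35):
    """What the code does: s*cos for all classes, minus s*m at y, then log softmax."""
    logits = [s * c for c in cos_row]  # s*cos(theta_j) for every class j
    logits[y] -= s * m                 # subtract s*m only at the ground-truth position y_i
    return -log_softmax(logits)[y]     # negative log-probability of the true class


def loss_by_formula(cos_row, y, s=30.0, m=0.35):
    """Direct transcription of one term of L_AMS."""
    target = math.exp(s * (cos_row[y] - m))
    others = sum(math.exp(s * c) for j, c in enumerate(cos_row) if j != y)
    return -math.log(target / (target + others))


cos_row = [0.2, 0.7, -0.1]  # made-up cosines for 3 classes
for y in range(len(cos_row)):
    print(y, loss_by_subtraction(cos_row, y), loss_by_formula(cos_row, y))
```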

The relevant code is as follows (for the complete code, see the mxnet code at the reference URL above):

    import mxnet as mx

    # args, embedding (B,F) and gt_label (B,) are defined elsewhere in the full insightface script
    s = args.margin_s  # scale parameter s
    m = args.margin_m  # margin parameter m
    _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)  # (C,F)
    _weight = mx.symbol.L2Normalization(_weight, mode='instance')  # normalize each row of W, i.e. each class weight W_j

    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s  # normalize x, then scale it to s*x, (B,F)
    fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes, name='fc7')  # Y=XW' (no bias): (B,F)*(C,F)'=(B,C), entries s*cos(theta_j); ' is transpose

    s_m = s * m  # compute s*m
    gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)  # (B,C) one-hot matrix: row i holds s_m at column y_i, 0 elsewhere
    fc7 = fc7 - gt_one_hot  # subtract s_m only at the ground-truth position y_i
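For readers without mxnet, the snippet above can be mirrored in plain Python to check shapes and values. This is a hedged sketch, not insightface code: `am_softmax_logits` and `softmax_cross_entropy` are hypothetical names, `l2_normalize_rows` stands in for `L2Normalization(..., mode='instance')` on 2-D input, and the final softmax cross-entropy is the step described in the text rather than something shown in the snippet.

```python
import math


def l2_normalize_rows(mat):
    # stand-in for mx.symbol.L2Normalization(..., mode='instance') on a 2-D input
    out = []
    for row in mat:
        norm = math.sqrt(sum(x * x for x in row))
        out.append([x / norm for x in row])
    return out


def am_softmax_logits(embedding, weight, gt_label, s=30.0, m=0.35):
    """Plain-Python mirror of the mxnet snippet; returns fc7 with shape (B, C)."""
    w = l2_normalize_rows(weight)                                              # (C, F)
    nembedding = [[s * x for x in row] for row in l2_normalize_rows(embedding)]  # (B, F)
    # FullyConnected with no_bias=True: fc7 = nembedding @ w^T, shape (B, C)
    fc7 = [[sum(a * b for a, b in zip(x, wc)) for wc in w] for x in nembedding]
    s_m = s * m
    num_classes = len(w)
    # one_hot(gt_label, depth=num_classes, on_value=s_m, off_value=0.0)
    gt_one_hot = [[s_m if j == y else 0.0 for j in range(num_classes)] for y in gt_label]
    return [[f - g for f, g in zip(frow, grow)] for frow, grow in zip(fc7, gt_one_hot)]


def softmax_cross_entropy(logits, labels):
    # mean negative log-probability of the ground-truth class
    total = 0.0
    for row, y in zip(logits, labels):
        peak = max(row)
        lse = peak + math.log(sum(math.exp(z - peak) for z in row))
        total += lse - row[y]
    return total / len(labels)


logits = am_softmax_logits([[3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], [1])
print(logits)
print(softmax_cross_entropy(logits, [1]))
```

Subtracting a one-hot matrix keeps the whole computation a dense (B,C) tensor operation, which is presumably why the mxnet code uses `one_hot` rather than indexing into `fc7`.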

 
