分类器ArcFace、ArcLoss在MNIST数据集上的实现和效果

分类器ArcFace、ArcLoss在MNIST数据集上的实现和效果

写在前面:
  前一篇文章(电梯直达)给大家介绍了CenterLoss,本文将带领大家认识一下ArcFace(ArcLoss、Insightface),并在MNIST数据集上实战一下看一下效果。

一、原理

在这里插入图片描述
  CenterLoss是将每个类别的特征缩减到他的中心位置,从而间接使不同特征之间界限分明,而ArcLoss则是在原本两个特征之间的夹角上再加上一个角m,然后优化这个m,使夹角慢慢变大,两个特征慢慢远离,从而使不同特征之间界限分明。他的理想状态是将所有特征之间的夹角都最大化,也就是每个特征都被压迫成一条细线。
  如下图中的α和β,同时加上m,会使得α和β同时增大,从而中间紫色区域将会被压缩,当m优化到最佳,那么紫色区域将会成为一条线。当然,这只是一种理想状态,是一种期望,实际情况下不太可能甚至根本不可能达到这样的效果。
分类器ArcFace、ArcLoss在MNIST数据集上的实现和效果_第1张图片

二、效果

  原理就是那么简简单单,看下效果
分类器ArcFace、ArcLoss在MNIST数据集上的实现和效果_第2张图片

三、训练说明

  如果只使用ArcLoss,loss只会下降一点点,但也能训练出来,要是再结合一个其他的损失函数,模型训练出来的效果将会非常棒,以上效果图就是ArcLoss + CrossEntropyLoss训练的结果。
  用ArcLoss代替原来的Softmax作为输出函数,初始化ArcLoss时指定输入特征维度和分类数,用ArcLoss输出的数据作为分类依据,与target计算一次损失,再用模型的输出与target计算一次损失。

四、实现

""" 网络输出层,不加激活 """
fc_feature = nn.Linear(1024, 2, bias=False)
out = nn.Linear(2, 10, bias=False)
feature_out = fc_feature(conv_out)
out = out(feature_out)
return feature_out, out


""" 初始化ArcLoss,这里是2维特征,10分类 """
arcface = ArcLoss(2, 10)		# 输入特征(N, 2), 输出(N, 10),输出可直接拿来分类


""" 损失函数 """
loss_f = torch.nn.CrossEntropyLoss()

""" 优化器:听说SGD效果更好,感兴趣的自己去捣鼓捣鼓 """
self.optimer = torch.optim.Adam([
    {'params': self.net.parameters()},
     {'params': self.arcface.parameters()}
])


""" 训练器
为了方便阅读理解,部分代码被简化,
如以下第一句标准语法应为:for i, (data, target) in enumerate(dataloader)
"""
for data, target in dataloader:
	feature, out = net(data)
	output = arcface(feature)
	arc_loss = loss_f(output, target)			# 计算ArcFace输出的分类损失
    cls_loss = loss_f(out, target)				# 计算网络直接输出的分类损失
    loss = 0.9 * arc_loss + 0.1 * cls_loss		# 如果只计算arcloss,那么网络的分类能力会很差
    """acc_arc:arcface输出的分类正确率;			acc_cls:网络输出的out的分类正确率"""
    acc_arc = torch.sum(torch.argmax(output, dim=1) == target) / batch_size
    acc_cls = torch.sum(torch.argmax(out, dim=1) == target) / batch_size


"""ArcLoss函数实现"""
class ArcLoss4(nn.Module):
    def __init__(self, feature_num, class_num, s=10, m=0.1):
    	"""
        :param feature_num:     特征数
        :param class_num:       类别数
        :param s: 
        :param m:               加上去的夹角,初始为0.1
        """
        super().__init__()
        self.class_num = class_num
        self.feature_num = feature_num
        self.s = s
        self.m = torch.tensor(m)
        self.w = nn.Parameter(torch.rand(feature_num, class_num), requires_grad=True)  # 2*10
    def forward(self, feature):
        feature = nn.functional.normalize(feature, dim=1)
        w = nn.functional.normalize(self.w, dim=0)
        cos_theat = torch.matmul(feature, w) / 10
        sin_theat = torch.sqrt(1.0 - torch.pow(cos_theat, 2))
        cos_theat_m = cos_theat * torch.cos(self.m) - sin_theat * torch.sin(self.m)
        cos_theat_ = torch.exp(cos_theat * self.s)
        sum_cos_theat = torch.sum(torch.exp(cos_theat * self.s), dim=1, keepdim=True) - cos_theat_
        top = torch.exp(cos_theat_m * self.s)
        div = top / (top + sum_cos_theat)
        return div

# 实现方式2
class ArcLoss2(nn.Module):
    def __init__(self, feature_dim=2, cls_dim=10):
        super().__init__()
        self.W = nn.Parameter(torch.randn(feature_dim, cls_dim), requires_grad=True)

    def forward(self, feature, m=1, s=10):
        x = nn.functional.normalize(feature, dim=1)
        w = nn.functional.normalize(self.W, dim=0)
        cos = torch.matmul(x, w)/10             # 求两个向量夹角的余弦值
        a = torch.acos(cos)                     # 反三角函数求得 α
        top = torch.exp(s*torch.cos(a+m))       # e^(s * cos(a + m))
        down2 = torch.sum(torch.exp(s*torch.cos(a)), dim=1, keepdim=True)-torch.exp(s*torch.cos(a))
        out = torch.log(top/(top+down2))
        return out

五、测试

  测试也很简单,直接把提取器提取的特征放进arcface去得到输出,拿输出做分类,也可以直接拿网络输出的out做分类,通过arc的输出会压缩在0~1之间,而直接的输出没有范围,但满足最大值最可靠,用softmax对arc的输出和net的输出进行比较可以发现net的out比arc的output更精确,这也解释了为什么要加两个loss才能训练得更好。并且我们通常也不会拿他俩的输出来做分类用

feature, out = net(img_data)		# 将处理好的图像数据扔进网络
output = arc(feature)[0]
res1 = torch.argmax(output)			# 通过arc的输出的分类结果
res2 = torch.argmax(out)			# 模型直接的输出的分类结果
print('Arc_out:', torch.nn.Softmax(dim=0)(output))
print('Net_out:', torch.nn.Softmax(dim=0)(out))

六、使用

使用不用多讲,多讲都是废话,就一句话:
  将图片处理成数字信号,丢进网络,拿到特征feature去做余弦相似度对比。
  其实不管CenterLoss还是ArcLoss,在做目标相似度对比识别的时候都用不上他,用到的都只是在他前面输出的那个特征向量。而如果你是做分类,也用不上他,用到的是网络最后输出的分类结果。(当然,如果你非要搞特殊,就要拿arcloss的输出做分类,那也是ok的)
  总结一下:CenterLoss和ArcLoss都只是在训练时提高提取器的特征提取能力,在使用时用不上。

写在最后

print('Thanks! The end!')		
# 有错误之处,欢迎批评指正

你可能感兴趣的:(人工智能,AI)