论文地址:https://openreview.net/forum?id=rJlnOhVYPS
代码地址:https://github.com/yxgeee/MMT
文章的主要核心思想顾名思义就是解决伪标签噪声问题。
看到网上有很好的解析文章:
https://zhuanlan.zhihu.com/p/116074945/
1.最先进的无监督领域自适应人再识别方法通过在目标域上使用聚类算法生成的伪标签进行优化,将学习到的知识从源域转移到目标域,但忽略了不可避免的标签噪声的聚类过程。
2.我们提出了一个无监督的框架,Mutual Mean-Teaching (MMT)来柔化目标领域的伪标签(就是降噪),以替代训练方式,通过离线提炼hard伪标签和在线提炼soft伪标签来学习目标领域更好的特征。
(在这里,"硬"标签指代置信度为100%的标签,如常用的one-hot标签[0,1,0,0],而"软"标签指代置信度<100%的标签,如[0.1,0.6,0.2,0.1]。)
3.此外,常用的做法是同时采用分类损失和三元组损失,以实现行人重id模型的最优性能。然而,传统的三元组损失不能与软精炼的标签一起工作。为了解决这一问题,提出了一种新的支持软三元组伪标签学习的名为softmax-triplet损失算法,以实现最优的域自适应性能。
1.最先进的UDA方法(Song et al., 2018;Zhang et al., 2019b;Yang et al., 2019)对行人重id组未标注的图像进行聚类算法,并用聚类生成的伪标签训练网络。虽然伪标签生成和伪标签特征学习交替进行,在一定程度上细化了伪标签,但不可避免的标签噪声仍然在很大程度上阻碍了神经网络的训练。噪声的产生是由于源域特征的可转移性有限、目标域恒等式的个数未知、目标域恒等式的个数未知等因素造成的。噪声伪标签的提取对最终性能有至关重要的影响,但基于聚类的UDA算法大都忽略了伪标签的提取。
2.我们提出了一个无监督的MMT(共同均值学习)的框架网络,
在离线提炼硬伪标签和在线提炼软伪标签的联合监督下,对神经网络进行优化,有效地实现伪标签的提炼,提供了更健壮的软伪标签。
3.为了避免训练误差的放大,提出了每个网络的时间平均模型(the temporally average model of each network)
,以产生可靠的软标签,在协同训练策略中监督其他网络。
4.通过在目标域上使用这种在线软伪标签训练的对等网络,可以迭代改进学习到的特征表示,提供更准确的软伪标签,从而进一步提高学习到的特征表示的可辩别性。
5.需要注意的是,这两个网络上的协同训练策略仅在训练过程中采用。只有一个网络保持在推断阶段,而不需要任何额外的计算或内存成本。
近年来无监督学习的发展与不足(我会单独写一篇文章进行概括总结)
说了一堆其实就是为了引出3.2MMT
的方法使用。
讲的是预训练模型的使用过程:
利用神经网络预先训练模型,同时在训练阶段使用常用的id损失和三元组损失。
总的损失函数:
我们的MMT框架通过协作训练两个具有不同初始化的相同网络来生成软伪标签。
总体框架如图所示:
1.伪类仍然是通过现有的基于聚类的UDA方法生成的,其中每个聚类代表一个类。
2.除了硬伪标签和噪声伪标签外,我们的两个协作网络还通过网络预测生成在线软伪标签来互相训练
3.通常来说,即使使用硬伪标签对网络进行训练后,它们也可以大致捕获训练数据的分布,因此它们的类预测可以作为训练的软类标签。然而,由于训练误差和有噪声的硬伪标签,这种软标签通常并不完美。
4.为了避免两个网络之间的协作偏差,使用每个网络过去的时间平均模型而不是当前模型来生成另一个网络的软伪标签。
5.联合使用离线硬伪标签和在线软伪标签对两个协作网络进行训练。在训练之后,只有一个经过验证性能更好的过去的平均模型
来用于推理(如下图c)。
我们将两个相互合作的网络分别表示为特征变化函数F(·|θ1) and F(·|θ2),将它们的伪标签分类器分别表示为 Ct1 and Ct2。
对于这两个网络,每个目标域图片都可以用xti,xti1表示,
并且它们的伪标签的可信度被分别预测为 Ct1(F(xti|θ1)) and Ct2(F(xti1|θ2)). (Ct1与Ct2就是伪标签置信向量)
训练协作网络的一种方法是直接利用上述伪标签置信向量
作为软伪标签来训练另一个网络。然而,在这种情况下,两个网络的预测可能收敛到相等,两个网络失去了它们的输出独立性,同时分类误差和伪标签误差可能在训练过程中被放大。
为了避免错误的不断扩大,我们不选择直接利用生成的伪标签置信向量作为软标签,而是提出了一个利用每个网络的时间平均模型来生成可靠的软伪标签来监督另一个网络。
在当前迭代T处,两个网络的时间平均模型的参数分别表示为 E(T)[θ1] and E(T)[θ2] ,E(T −1)[θ1],与E(T −1)[θ2]分别表示前T-1代中的平均时间模型的参数。
the initial temporal average parameters are E(0)[θ1] = θ1, E(0)[θ2] = θ2,and α is the ensembling momentum to be within the range [0, 1).
The robust soft pseudo label supervisions are then generated by the two temporal average models as Ct1(F(xti|E(T)[θ1])) and Ct2(F(xti1|E(T)[θ2])) respectively
更加健壮的软伪标签监督由两个平均时间模型产生,并分别表示为
Ct1(F(xti|E(T)[θ1])) 与 Ct2(F(xti1|E(T)[θ2]))
注:这个Ct1与Ct2 与 上述第二条中的Ct1与Ct2形成对比,此处引入了E作为相互监督
因此,用其他网络生成的软伪标签来优化θ1和θ2时的软分类损失可表示为上述公式。
两个网络的伪标签预测通过使用其他网络的过去平均模型产生监督,具有更好的非相关性,因此可以更好地避免误差放大。
我们提出了专门针对软标签的三元组损失,其具体描述如下:
公式参数解释:
考虑到两个网络的联合性,我们利用一个网络的过去平均时间模型为另一个网络生成软三元组标签并提出软三元组损失
Ti(E(T)[θ1]) and Ti(E(T)[θ2]) 是由两个网络的过去平均时间模型产生的三元组软标签。这两个软三元组标签随着训练监督不断优化,稳固。通过采用软三元组损失,我们的MMT克服了使用传统三元组损失硬监督的限制。它成功的训练软三元组标签,它在实验过程中体现出了性能改善。