©作者 | koncle
大家好,在这里给大家分享一下我们最近被 ICCV2023 接受的工作《DomainAdaptor: A Novel Approach to Test-time Adaptation》。该工作针对领域泛化领域中的测试阶段领域自适应问题进行了研究,旨在提高模型在测试阶段对测试数据的高效自适应。为了充分挖掘测试数据中的信息,我们提出了一个统一的方法,称为 DomainAdaptor,用于测试时的高效自适应。
该方法包括 AdaMixBN 模块(高效融合源域信息)和广义熵最小化(GEM)损失(高效挖掘高置信度样本)。大量实验证明,DomainAdaptor 在多个领域泛化基准测试上优于现有方法。此外,我们的方法在少量数据的未见领域上相较于现有方法能带来更显著地提升。
论文题目:
DomainAdaptor: A Novel Approach to Test-time Adaptation
论文链接:
https://arxiv.org/pdf/2308.10297.pdf
Code链接:
https://github.com/koncle/DomainAdaptor
▲ 测试阶段领域自适应方法流程
领域泛化旨在解决高效的泛化问题,即当存在一个或多个源域数据时,如何用其训练一个具有强泛化能力的模型,以期能够泛化到任意未见领域上。虽然当前领域泛化已经提出了很多方法,但是在应对目标领域时,依旧存在性能下降的问题。原因在于,这些方法主要针对训练策略进行调整,而忽略了测试阶段的模型自适应。传统进行领域自适应的方法主要使用无监督领域自适应的方式,对训练过的模型进行重训练,这会带来巨大的训练开销。因此,本文研究一个更加实际的问题:测试阶段自适应(fully test-time adaptation)[1],在只给定源域模型的情况下,对任意领域进行自适应。
但是当前测试阶段自适应方法自适应效率较低,在遇到较大领域差异时,需要依赖于连续的自适应才能达到较好的性能。因此,本文针对当前方法的两个问题进行研究,以提升方法对测试样本的自适应效率:
1、当前的测试阶段自适应方法依赖于对 BN 统计量的调整,但是这些方法直接将模型中的统计量替换为测试阶段统计量进行归一化,存在测试统计量估计不准确的问题,同时也忽略了源域统计量自身的丰富信息;
2、当前的测试阶段自适应方法的损失主要采用熵最小化(EM)损失,该损失在模型置信度高时,对样本利用率会显著下降,从而导致测试样本利用率低下。
针对这两个问题,我们提出了 DomainAdaptor 方法框架,其中 AdaMixBN 用于动态地对源域和目标域之间的统计量进行融合,从而保证测试阶段统计量的准确估计;GEM loss 保证对高置信度样本的学习,提升样本利用效率。
▲ 方法整体流程
针对测试阶段统计量估计不准确的问题,之前方法通常丢弃源域统计量、直接使用测试 batch 统计量进行归一化,但测试 batch 的统计量的估计不一定准确,导致性能下降。因此我们提出了混合源域和目标的统计量来进行自适应的策略,用丰富的源域信息帮助测试 batch 的统计量的估计:
其中混合系数为 。但由于测试阶段面对的未见领域的数据,因此无法像之前方法 [2] 一样,人为设定固定混合系数。此外,模型高层和低层之间的统计量是不同的,也需要不同的混合系数。因此,本文提出了一种动态的混合策略:
其中, 为源域统计量到测试数据统计之间的距离, 为测试batch统计量与其中单张图片统计量之间的距离, 为源域统计量到测试 batch 内单张图片统计量之间的距离。如果源域统计量到 batch 统计量的距离小于 batch 内图像统计量到 batch 统计量的距离,这意味着源领域和测试领域相近,可以使用较大的 来融合更多的源域统计量,如方法流程图中所示。同时我们也计算了单张图片统计量、测试 batch 统计量和源域统计量之间的距离如下:
▲ 单张图片统计量、测试batch统计量和源域统计量之间的欧式距离
在上表中,源领域统计量到测试 batch 统计量的距离大于 batch 内单张图像的统计量到 batch 统计量的距离,表明 Sketch 领域存在较大的领域差距。为了减小这个领域差距,应该融合更多的测试统计信息(即,较小的 )。
尽管以上方法可以有效提升模型性能,但是直接对它进行微调(finetune)时,会存在性能急剧下降的问题,如下表中,使用不同统计量进行微调的性能:
▲ 使用不同统计量在finetune前后之间的性能差异
该问题是由于源域统计量与微调后的参数之间不匹配所导致的,如下图所示:
▲ 使用源域统计量进行 fuentine 导致统计量与finetune后的参数不匹配
在微调之后,批归一化(BN)的权重(例如,图 3 中的 和 )会发生变化,从而导致输出的特征 分布发生变化。然而,源域统计量(例如,图 3 中的 和 )在微调过程中保持不变。因此,如果我们继续使用仅适用于原始分布的源域统计量来对具有偏移分布的特征图 进行归一化,性能不可避免地会降低。因此,我们提出通过将 AdaMixBN 的归一化方式转换成仅用测试统计量进行归一化,并隐式地融入源域统计量:
2.2 广义熵最小化(GEM)
在数据经过 AdaMixBN 和统计量变换之后,需要对模型进行微调以实现进一步的自适应。熵最小化损失是一种简单有效的手段:
其中 是温度系数。但是由于该损失对高置信度样本输出的损失过小,导致无法有效利用这些高置信度样本,如下图所示:
▲ 置信度越高的样本,loss越小
当一批数据中存在大量高置信度样本时,网络就难以通过单次微调学习到更多的信息。同时,又注意到:高置信度样本的分布通常比较尖锐,如上图所示。因此,为了能够让网络能够对高置信度样本进行高效的学习,我们提出对该损失中的温度系数进行更改:
其中, 和 表示分别用不同的温度系数 , 得到的置信度。
对以上公式中的第 个类别的 logit 进行梯度分析:
发现:1)当我们对公式(1)中的 进行 stop_grad 操作后,其可以看做是软标签,此时公式(2)中第二项会消失,又由于 和 来自同一个输出,因此公式(1)实际上是一个自蒸馏的过程,其中 为 teacher 分布, 为 student 分布。2)当 时,第一项为 0,此时它和原始的熵最小化类似:让分布更加“尖锐”。值得注意的是:对单个样本使用该损失并不能导致模型学到有效信息,因为它并不能改变原始分布的趋势。只有存在大量数据时,该损失才能学习到有效信息。
由于不同温度系数存在不同的作用,我们令 以进行温度的自适应,同时我们设计了几个不同的损失函数:
GEM-T:此时 ,该 loss 让网络继续输出更加 sharp 的分布
GEM-SKD:此时 并且 的梯度固定,让网络学习 “teacher ” 的分布
GEM-Aug:在 GEM-SKD 的基础上,使用了测试阶段增广 ,通过提升 teacher 分布的准确率来提升整体准确率。
实验
与一些 SOTA 方法的性能对比,对比过程中,我们采用单次自适应的设定,即当测试阶段到来一个 batch 时,模型只针对该 batch 进行调整并测试,当下一个 batch 到来时,模型使用原始的权重而非调整过后的权重进行测试。通过该方式可以有效的验证之前的方法是否能够有效利用单个 batch 的信息进行自适应。
沿用 DG 设定,我们在多个 DG 数据集(PACS,VLCS,OfficeHome,MiniDomainNet)上进行了实验,在三个领域上先进行预训练,得到预训练模型(DeepAll),再在其他领域上进行测试阶段自适应的测试。我们的方法对比之前的方法能够取得显著的性能提升。
▲ 与SOTA性能对比
消融实验:
AdaMixBN 的性能和手动设置 之间的性能差距:
模型权重越高层, 值也越大,说明高层的信息相比低层是更具有可迁移性:
将我们的方法用于一些预训练的 DG SOTA 方法,可以看到同样可以显著提升预训练模型的性能:
更多的实验和细节还请阅读原文。
在本工作中,我们研究了如何提升测试阶段自适应方法的样本利用率的问题,针对之前方法无法忽略源域统计量信息及熵最小化损失样本高置信度利用率低的问题,提出 DomainAdaptor 框架,通过融合源域信息(AdaMixBN)和调整熵最小化损失(GEM)来达到针对单个 batch 信息的高效利用。多个数据集上的实验表明,该方法能够有效提升测试样本的利用率,且能够应用于任意预训练好的模型上而无需重新训练。
参考文献
[1] Wang, Dequan, et al. "Tent: Fully Test-Time Adaptation by Entropy Minimization." International Conference on Learning Representations. 2020.
[2] Schneider, Steffen, et al. "Improving robustness against common corruptions by covariate shift adaptation." Advances in neural information processing systems. 2020.
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·
·