今天阅读ICLR 2023 ——Towards Stable Test-time Adaptation in Dynamic Wild World
Keywords:Test-time adaptation (TTA);
TTA通过在测试样本上在线调整模型,在下面Wild test setting下稳定性差:
本文对 TTA 在Wild test setting下失败原因进行分析,如何在这种Wild test setting下解决TTA稳定性问题?
原因发现BN层是关键影响因素
本文发现:
本文解决思路:
为了解决GN/LN模型测试时熵最小化的模型崩溃问题,主要思想概括如下:
该问题由于具有较大梯度的噪声样本导致,所以我们根据他们的熵来过滤具有较大和噪声梯度的部分样本;
对于剩下样本,我们提出了sharpness-aware and reliable entropy minimization method(SAR),使可靠测试样本的熵和熵的锐度最小化,确保模型权重优化到平坦的最小值,从而对较大和嘈杂梯度/更新具有鲁棒性
现实世界中,由于光线,拍摄等问题造成测试域上distribution shift
Prior TTA通过在测试阶段在线更新模型来解决上述问题并已经取得不错的表现,这些优异性能在特定测试下获得,如在一段时间内的样本均来自于同一种分布偏移类型、the ground-truth test label distribution是均匀且随机的,以及每次需要有一个 mini-batch 的样本后才可以进行适应。然而实际情况的测试数据是任意的。
wild TTA是测试场景可能为以下三种情况并且导致TTA不稳定:
i) mixture of multiple distribution shifts
ii) small test batch sizes (even single sample)
iii) the ground-truth test label distribution may be imbalanced at each time-step
之前的工作有:
MEMO 、episodic Tent 、DDA以及the diffusion model in DDA
how to stabilize online TTA under wild test settings is still an open question.
解决distribution shift问题,之前提出了许多方法:DG、增加训练及大小,数据增强,这些方法目的是在训练时提前预测测试集的偏移,使训练分布能够覆盖测试数据可能的偏移。然而这是不现实的并且计算大,因此我们打算在测试集上直接学习来解决测试集偏移问题
传统的UDA联合优化带标签源域和无标签目标域来解决分布漂移,例如设计一个domain discriminator 在特征级对齐源域和目标域或者通过对比学习方式对齐源域和目标域
无源域**(source-free)UDA**方法用来解决源域不存在时的自适应问题,例如基于生成的方法,从模型生成原图像。这些方法在整个测试集上离线适应模型,通常需要多个训练阶段,因此很难部署在在线测试场景中。
TTA主要分为两类——Test-time Training (TTT)联合优化具有监督损失和自监督损失的源域模型,并在测试时进行自监督学习。自监督损失可以是TTT中的旋转预测(Gidaris等人,2018),也可以是TTT++和MT3中基于对比的目标
Fully Test-time Adaptation (TTA) 不改变训练过程,可以应用于任何预训练的模型,包括适应BN层,无监督熵最小化,预测一致性最大化
SAM 同时优化监督目标(例如交叉熵)和损失曲面的锐度,旨在找到具有良好泛化能力的平坦最小值
在本工作中,我们分析了在test-time熵最小化的失败原因,我们发现一些产生gradients with large norms 的噪声样本会损害模型自适应,从而导致模型崩溃。为了缓解这个问题,我们提出了sharpness-aware minimization使在线模型更新对那些有噪声/大梯度具有鲁棒性,算法如下:
文中写了BN层的具体计算流程:
BN是计算a batch of test samples 的 mean E[x]和variance Var[X]
在Wild test setting中:
mixed distribution shifts
BN数据代表了一个分布,理想情况下每个分布都对应自己的数据
从 mini-batc test samples 估计多个分布的共享BN数据会限制性能
small batch sizes
统计数据的质量依赖于batch size ,因此small batch size 很难精准估计
online imbalanced label distribution shifts
不平衡的标签偏移也会导致BN统计数据偏向数据集中的某些特定类
因此我们提出BN无关更适合wild TTA,如group norm(GN) and layer norm(LN)
在线熵最小化倾向于导致崩溃的平凡解,即预测所有样本到同一个类
我们发现:当distribution shift严重时,在GN模型上的熵最小化趋于崩溃
为什么会发生模型崩溃?
以下实验都是在ResNet50 (GN)高斯噪声的随机ImageNet-C上进行的,(严重性)level越大表示分布偏移越严重。根据图 (a) 和 (b) 显示,在分布偏移程度严重(level 5)时,在线自适应过程中突然出现了模型退化崩溃现象,即所有样本(真实类别不同)被预测到同一类;同时,GN在模型崩溃前后快速增大而后降至几乎为 0,见图 6(c)
说明可能是某些大尺度 / 噪声梯度破坏了模型参数,进而导致模型崩溃。
基于以上分析,避免模型崩溃的两种最直接的解决方案是根据样本梯度过滤测试样本或进行梯度裁剪。然而,由于不同模型和分布移位类型的梯度范数尺度不同,这些方法并不十分可行
对于模型崩溃,我们的解决方案:本文提出了锐度敏感且可靠的测试时熵最小化方法 (Sharpness-aware and Reliable Entropy Minimization Method, SAR)。其从两个方面缓解这一问题:
1)可靠熵最小化从模型自适应更新中移除部分产生较大 / 噪声梯度的样本;
2)模型锐度优化使得模型对剩余样本中所产生的某些噪声梯度不敏感
具体细节阐述如下:
目的:从模型自适应更新中移除部分产生较大 / 噪声梯度的样本
选择损失值较小的样本,可以剔除部分梯度较大的样本(Area1),由于Area2的置信度低不可靠因此也移除。
目的:使初筛后梯度较大样本对模型更新不敏感,模型最终优化至flat minimum
通过Reliable Entropy Minimization methods过滤后的样本Area1 2,理想情况我们只要Area3的样本优化模型。Area4的样本仍然有很大梯度可能损害自适应无法避免。我们希望模型往优化至一个 flat minimum,使其能够对噪声梯度带来的模型更新不敏感,即不影响其原始模型性能,优化目标为:
上述目标的最终梯度更新形式如下:
至此,本文的总体优化目标为:
(1) Norm Layer Effects in TTA Under Small Test Batch Sizes
batch sizes (BS) selected from {1, 2, 4, 8, 16, 32, 64}
不同方法和模型(不同归一化层)在不同 batch size 下性能表现。图中阴影区域表示该模型性能的标准差,ResNet50-BN 和 ResNet50-GN 的标准差过小导致在图中不显著
(2) Norm Layer Effects in TTA Under Mixed Distribution Shifts
不同方法和模型(不同归一化层)在混合分布偏移下性能表现
(3) Norm Layer Effects in TTA Under Online Imbalanced Label Shifts
不同方法和模型(不同归一化层)在在线不平衡标签分布偏移下性能表现,图中横轴 Imbalance Ratio 越大代表的标签不平衡程度越严重
SAR 基于上述三种动态开放场景,即 a)混合分布偏移、b)单样本适应和 c)在线不平衡类别分布偏移,在 ImageNet-C 数据集上进行实验验证,结果如表 1, 2, 3 所示。SAR 在三种场景中均取得显著效果,特别是在场景 b)和 c)中,SAR 以 VitBase 作为基础模型,准确率超过当前 SOTA 方法 EATA 接近 10%。
首先在介绍这些问题之前我们要知道一个点
深度学习的成功主要归功于假设大量的标注数据和训练集与测试集独立且来自同一概率分布, 然后设计相应的模型和判别准则对待测试的样例的输出进行预测。然而实际场景中训练和测试样本的概率分布是不一样的
源域和目标域:
可以简单理解为源域就是训练集,目标域就是测试集
专业术语:源域(Source Domain)是已有的知识领域;目标域(Target Domain)是要进行学习的领域
源域与目标域区别主要体现在数据分布上,这个问题又分为三大类:
Domain shift:
如果训练资料和测试资料是来自于不同的分布,这样就会让模型在测试集上的效果很差,这种问题称为Domain shift
DA研究问题:
当源域和目标域并不是独立同分布时,经典机器学习会出现过拟合问题,DA需要解决源域与目标域概率分布不一致,但是任务相同的问题
DA目标:
如何减少source和target不同分布之间的差异
举例:比如训练集是各种英短蓝猫(源域),而想训练得到可以区分田园猫的模型(目标域),该模型相比于英短蓝猫识别情况性能会下降。当训练数据集和测试数据集分布不一致的情况下,通过在训练数据集上按经验误差最小准则训练得到的模型在测试数据集上性能不佳,因此,我们引入了DA来解决训练集与测试集概率分布不一致但都是同一任务的问题。
DA主要思想:
将源域与目标域(如两个不同的数据集)的数据特征映射到同一个特征空间,这样可利用其它领域数据来增强目标领域训练。
举例:比如下图源域是黑白手写数字,目标域是彩色数字,两个分布明显不同,我们需要训练一个特征提取器,然后对这些样本的关键特征进行提取来缩小不同分布之间的差异(下图就是去除颜色的影响提取数字作为最关键的特征)
DA三种方法:
样本自适应Instance adaptation:将源域中样本重采样,使其分布趋近于目标域分布;从源域中找出那些长的最像目标域的样本,让他们带着高权重加入目标域的数据学习。
特征自适应 Feature adaptation:将源域和目标域投影到公共特征子空间,这样两者的分布相匹配,通过学习公共的特征表示,这样在公共特征空间,源域和目标域的分布就会相同。
模型自适应 Model adaptation:考虑目标域的误差,对源域误差函数进行修改。假设利用上千万的数据来训练好一个模型,当我们遇到一个新的数据领域问题的时候,就不用再重新去找几千万个数据来训练,只需把原来训练好的模型迁移到新的领域,在新的领域往往只需相对较少的数据就同样可以得到很高的精度。实现的原理则是利用模型之间存在的相似性。
DA中又分别可以根据目标域数据的打标签情况分为监督的、半监督的、无监督的DA。学术界研究最多的是无监督的DA,这个比较困难而且价值比较高。
如果目标域数据没有标签,就没法用Fine-Tune把目标域数据扔进去训练,这时候无监督的自适应方法就是基于特征的自适应。因为有很多能衡量源域和目标域数据的距离的数学公式,那么就能把距离计算出来嵌入到网络中作为Loss来训练,这样就能优化让这个距离逐渐变小,最终训练出来的模型就将源域和目标域就被放在一个足够近的特征空间里了。
具体用于无监督DA的DDC,MADA,RevGrad等算法后期需要再进行阅读
DG是DA的进一步推广,DG与DA的区别:
DA在训练时可以拿到少量目标域数据,这些目标域数据可能是有标签的(有监督DA),也可能是无标签的(无监督DA),但是DG在训练时看不到目标域数据
DG研究问题:
通过带标签的源域学习一个通用的特征表示,并希望该表示也能应用于未见过的目标域
DG目标:
学习域无关的特征表示
DA和DG优点:
简单说DA由于要使用目标域中的数据,因此DA性能高,而DG去学习一个通用特征表示,因此DG泛化性更强
毫无疑问,DG是比DA更具有挑战性和实用性的场景:毕竟我们都喜欢“一次训练、到处应用”的足够泛化的机器学习模型。
DG分类:
DG主要分为单源域DG和多源域DG
TTA研究问题:
在测试样本上在线对模型进行调整,在拿到样本后模型需要立刻给出决策并更新。
TTA目标:
最终使得调整后的模型可以拟合目标域数据分布或者将目标域特征映射到源域特征分布。
TTA、DA、DG区别:
DG需要对目标域进行预先假设,在源域 finetune 预训练模型,然后部署时不经过任何调整。
DA在源域上训练,根据无标签的目标域在训练时调整模型
TTA不需要像DG一样对目标域进行预先假设,也不需要像DA一样依赖源域,而需要在测试时进行 adaptation
TTA与DG不同的是,TTA在于在线调整模型需要及时做出判断,DG在于离线学习一种通用的特征表示,DA在训练时调整模型