CoTTA:连续的测试时域自适应方法

523b76c7ae917a375cfefe9170d9ff0c.png

文章信息

b7280ccc15363333efade052a3f17836.png

论文题目为《Continual Test-Time Domain Adaptation》,该文于2022年发表于Conference on Computer Vision and Pattern Recognition (CVPR)会议上。文章提出了一种持续的测试时域自适应方法(CoTTA),旨在应对非稳态和不断变化的目标领域环境,通过减少错误累积和防止灾难性遗忘,以实现源模型的长期适应。

fee98a68ac8b47f979224643660625a3.png

摘要

5543dd6ea7966288eb0a230858797a55.png

测试时域自适应旨在将源预训练模型适应到目标领域,而无需使用任何源数据。现有研究主要考虑了目标领域静态的情况。然而,在现实世界中,机器感知系统在非稳态和不断变化的环境中运行,目标领域的分布随时间可能会发生变化。现有方法主要基于自我训练和熵正则化,可能会受到这些非稳态环境的影响。由于目标领域随时间的分布变化,伪标签变得不可靠。噪声伪标签可能进一步导致错误累积和灾难性遗忘。为了解决这些问题,文章提出了一种连续的测试时自适应方法(CoTTA),它包括两个部分。首先,文章提出通过使用加权平均和数据增强平均预测来减少错误累积,这些预测通常更准确。另一方面,为了避免灾难性遗忘,在每次迭代中随机恢复一小部分神经元到源预训练权重,以帮助长期保留源知识。所提出的方法使网络中的所有参数能够进行长期自适应。CoTTA易于实施,并可轻松集成到现成的预训练模型中。作者在四个分类任务和一个分割任务上展示了CoTTA的有效性,用于持续测试时自适应,在这些任务中胜过了现有方法。

1a91f0f428e553b4e6a8a47921039b44.png

引言

223403aebfa473941b92c2b3461773b4.png

测试时域自适应旨在通过在测试时从未标记的目标数据中学习,以使源预训练模型适应目标领域数据,源训练数据与目标测试数据之间存在领域偏移,因此需要适应才能实现良好的性能。例如,一个在晴天条件下训练的语义分割模型在雪夜条件下进行测试时可能会出现显著的性能下降。同样,一个预训练的图像分类模型在测试受传感器降级影响的损坏图像时也可能出现这种现象。在许多情况下,自适应需要以在线方式进行。因此,测试时域自适应对于在领域偏移情况下成功应用于实际机器感知应用至关重要。

现有的测试时域自适应方法通常通过使用伪标签或熵正则化来更新模型参数来处理源领域和固定目标领域之间的分布偏移。这些自训练方法对于来自同一个稳态领域的测试数据是有效的。然而,当目标测试数据源源不断地来自时刻变化的环境时,它们可能会变得不稳定。主要有两个原因:首先,在不断变化的环境下,由于分布偏移,伪标签变得更加嘈杂和不准确。因此,早期的预测错误更有可能导致错误积累。其次,由于模型长时间的适应新的数据分布,对于源领域的知识更难保持,从而导致灾难性遗忘。

为了应对这些问题,文章提出了一种连续的测试时自适应方法(CoTTA),以适应不断变化的环境。如图1所示,目标是从一个现成的源预训练模型开始,不断地使其适应当前的测试数据。假设目标测试数据是从一个不断变化的环境中流式传输的。预测和更新是在线执行的,这意味着模型只能访问当前数据流,而不能访问完整的测试数据或任何源数据。这个设定非常贴近实际的机器感知系统。比如在自动驾驶系统中,周围环境不断变化,从晴天到多云再到雨天等天气变化,车辆驶出隧道时摄像头突然曝光过度。在这种非稳态的环境下,感知模型需要能够实时适应并做出决策。

CoTTA:连续的测试时域自适应方法_第1张图片

图 1 各种适用场景

该方法主要有两个作用,首先能够减少误差积累。在自我训练框架下,通过两种不同的方式来提高伪标签质量。一方面,由于平均教师预测通常比标准模型具有更高的质量,使用加权平均教师模型来提供更准确的预测。另一方面,对于具有较大域间隙的测试数据,使用增强平均预测来进一步提高伪标签的质量。其次,该方法能够保存源知识并避免遗忘。作者将网络中的一小部分神经元随机恢复到预训练的源模型。通过减少误差积累和保留知识,CoTTA能够在不断变化的环境中进行长期适应,并能够训练网络的所有参数,而以前的方法只能训练batchnorm参数。CoTTA可以轻松地集成到任何现成的预训练模型中,而无需它们的源数据。为了验证其有效性,作者在四个分类任务和一个分割任务上进行实验,实验结果表明,在这些任务中使用了持续的测试时自适应方法(CoTTA)的预训练模型,性能显著提高并超过了现有的方法。

文章的贡献总结为:作者提出了一种持续测试时自适应方法,可以有效地将现成的源预训练模型适应到不断变化的目标数据中。通过使用更准确的加权平均和数据增强平均伪标签来减少错误积累。通过明确保留源模型的知识来缓解了长期遗忘效应。所提出的方法显著提高了在分类和分割基准测试中的持续测试时自适应性能。

c22c34a81214642f95ebce4d2cb7a34a.png

问题定义

4b97c26b56763db84bf4ba7fd2ee798d.png

我们的目标是在推理时,针对一个持续变化的目标领域,以在线方式不断提高现有预训练模型4c54d1fb5e351bdcfe26c165a04a4186.png(其中θ是参数,已经在源数据9d3ef8fda6a99476bbb9cb11fc2e4a6f.png上训练过)的性能,而不需要访问任何源数据。目标领域的未标记数据XT是按顺序提供的,模型只能访问当前时间步的数据。在时间步t,目标数据45bc8de1b814ff4321686d182e92af5e.png被提供为输入,模型8cd53880f28213eeee8a5ae0b969d033.png需要进行预测4649c7a2f0df50ed9b3b22e4b17527d7.png,并相应地进行自适应以适应未来的输入a270489fee56c66afbe1215680d60dda.png1090cfcba4e79157ff48099caeb10acb.png的数据分布不断变化,模型的性能是基于在线预测进行评估的。

这种设置很大程度上是由不断变化的环境中机器感知应用的需求所驱动的。例如,由于位置、天气和时间等因素,自动驾驶汽车的周围环境会不断变化。感知决策需要在线做出,模式需要调整。在线连续测试时间适应设置与现有适应设置之间的主要区别如下表所示,与以往专注于固定目标域的设置相比,作者考虑了对不断变化的目标环境的长期适应。

CoTTA:连续的测试时域自适应方法_第2张图片

618afedc4644d8df2b3f14264e4ea1cc.png

方法

80d144cb07a321b3eb4a141ee5c7388d.png

在线的连续测试时自适应方法采用现成的源预训练模型,并以在线方式使其适应不断变化的目标数据。通过使用加权平均和权重增强伪标签来减少错误积累。此外,为了帮助减少持续适应中的遗忘,该方法还保留了源模型中的信息,如图2所示。

CoTTA:连续的测试时域自适应方法_第3张图片

图 2 连续的测试时域自适应方法流程

方法1:加权平均伪标签

给定目标数据和模型bcec00f4ce268a3e2195997937153268.png,目标是通过模型的预测结果7b31bac657133bc4c0ac27ffc530f1dc.png与伪标签之间的交叉熵一致性来进行优化。这里的伪标签是一种用于训练的标签,通常由模型的预测结果生成。作者提到,如果直接使用模型的预测结果作为伪标签,这将导致在目标领域保持不变时有效,对于不断变化的目标数据,由于分布的变化,伪标签的质量可能会显著下降。

在深度学习训练中,通过对训练过程中的多个时间步的模型进行权重平均,通常可以得到比最终模型更准确的模型。这是因为权重平均可以减轻训练过程中的噪声和波动,从而提高了模型的鲁棒性和泛化性能。于是引入了一个称为教师模型的概念,在时间步t=0时,教师模型初始化为与源预训练模型相同。在时间步长t处,伪标签首先由教师0cc15a7a0e5458195b5e671099e5c4cf.png生成。然后,学生2a2914ced4caaa099cccae9c45215752.png通过学生和教师预测之间的交叉熵损失来更新:

68c1910211015559220baafb15347a7e.png

其中,715e4b1fddbab0478a720154707dd72e.png为教师模型软伪标签预测中c类的概率,0b955a7ab321c4251c3f8d63ca0c0537.png为主模型(学生模型)的预测。这种损失加强了教师和学生预测之间的一致性。

在学生模型权重273f4853e2b2e3ae494d1f29af5b9c0e.png通过上述公式更新后,便使用学生模型的权重来更新教师模型的权重,采用指数移动平均的方法进行更新:

5da07d01e14e06dff1ec998ede17a753.png

其中α是平滑因子,控制了新权重与旧权重的混合程度。对输入数据db56a8ccf0c7b122162dfe04530f102b.png的最终预测值是9c2ca3d2f471f831bf83c03868550d2f.png中概率最高的类。

加权平均一致性有两个好处。一是通过使用更准确的加权平均预测作为伪标签目标,模型通过对高质量伪标签的训练在持续自适应过程中受到的误差积累较少。第二是平均教师预测编码了过去迭代中模型的信息,因此,在长期持续适应中不太可能遭受灾难性遗忘,并提高了对新的未知领域的泛化能力。

方法2:权重增强伪标签

在训练阶段数据增强已被广泛应用于提高模型性能。不同的数据增强策略可以手动设计或者通过搜索算法(如自动增强搜索)来确定,以适应不同的数据集和任务。在测试模型性能时,有时也会应用测试时增强(test-time augmentation)。这是一种在测试样本上应用数据增强变换的方法,它已被证明可以提高模型的鲁棒性,然而,通常情况下,测试时增强策略会在训练期间固定下来,而不考虑推理时领域(数据分布)发生变化的情况。在不断变化的环境下,测试分布可能发生巨大变化,这可能使增强策略无效。因此,作者考虑测试时的领域变化,并通过预测置信度来近似领域之间的差异。当领域之间的差异较大时,才会应用数据增强,以减少误差积累。

CoTTA:连续的测试时域自适应方法_第4张图片

CoTTA:连续的测试时域自适应方法_第5张图片

其中e63b52f7fdb607706b259a96943a50f4.png表示教师模型的增强平均预测,ac8b2578de35d876f33552d8242ca1e7.png为教师模型的直接预测,fcfa8a8bc24052273345e5029aa52a9d.png为源预训练模型对当前输入9cf551ca6d376bca41f275b7d712610c.png的预测置信度,f2396060ae030e82f6a3aa8152fc2b5e.png为置信度阈值。通过计算源预训练模型的预测置信度来估计源域和当前域之间的差异。低置信度可能表示领域差异较大,而相对较高的置信度可能表示领域差异较小。因此,当预测置信度高于设定的阈值时,就会直接使用教师模型的原始预测作为伪标签,而不进行额外的数据增强。但是,当置信度较低时,会额外应用 N 次随机数据增强来提高伪标签的质量。这种方法的目的是根据模型对于当前输入的预测置信度来决定何时以及如何应用数据增强,以提高模型性能并适应领域差异。学生模型的预测与改进的伪标签之间的损失函数为:

c50dbec08fcd45021195d61bd08ec08d.png

760a1baa134f88f53e4f2b36a6d7d96b.png表示学生模型对输入样本的预测中的类别 c 的概率,是从教师模型获得的改进的伪标签中的类别 c 的概率。通过最小化这个损失函数,模型试图使学生模型的预测尽可能接近改进的伪标签,从而更好地适应目标任务或领域。

方法3:随机恢复

虽然更准确的伪标签可以减少错误的积累,但长期自训练的持续适应会不可避免地引入错误并导致遗忘。特别是在数据序列中遇到强烈的域移位,因为强烈的分布移位会导致校准错误甚至错误的预测。在这种情况下,自训练可能只会强化错误的预测。并且在遇到困难的例子后,即使新数据没有严重偏移,模型也可能因为不断的适应而无法恢复。

为了进一步解决灾难性遗忘问题,作者提出了一种随机恢复方法,该方法明确地从源预训练模型中恢复知识。考虑在时间步长为t时,基于方程1的梯度更新后的学生模型0dc925260e42ee2dd67669bbcc1f3a9c.png内的卷积层为:

3371546e875e31a1f889befdc262acbf.png

其中*表示卷积运算,5e242332fbff5198eee7c77c9666e41a.pngc8ea8a1e6a32eeb94d94deed2aba4dad.png表示该层的输入和输出,2c46addaebbaaef03e64539b91589fc0.png表示卷积核。本文提出的随机恢复方法通过以下方式对权值进行更新:

79f910ac5b81034b400aaedf374bb695.png

30156705ede847f3ffa86816f5d530f4.png

其中表示d82a392cf8cc995afb854672b552513a.png元素的乘法。p是一个小的恢复概率,M是与d39f9c1264848e36e59c326090139c87.png形状相同的掩模张量,M 按照 Bernoulli 分布进行随机采样,M 的元素取值为0或1,取0的概率为 p,取1的概率为 1-p。掩码张量c23a86575e8d4f984a9edb21f228f0bb.png决定中的哪个元素要恢复到源权重c55f5b06e5f333e9c628395c8aa961d7.png

随机恢复也可以看作是Dropout的一种特殊形式。通过随机地将可训练权值中的少量张量元素恢复到初始权值,避免了网络偏离初始源模型太远,从而避免了灾难性遗忘。

将精细伪标签与随机恢复相结合,形成了的在线连续测试时间适应(CoTTA)方法,如下所示。

CoTTA:连续的测试时域自适应方法_第6张图片

有一个预训练模型(学生模型)13b963aa07f6a7009779e068c582f6ca.png和一个教师模型6ac2f6401cae3c98b6f1b9450b60abf5.png,在初始化阶段时间步t=0时,教师模型初始化为与源预训练模型相同。输入时间步长t时的数据流e86dd2d8705b45cc133a993f48dd6186.png。首先,从教师模型中生成加权和增强平均伪标签(pseudo-labels),作为学生模型的训练目标。接着,将学生模型的预测结果和该伪标签,根据交叉熵损失更新学生模型。然后,使用学生权重通过指数移动平均更新教师模型的权重。其次,随机恢复一部分学生模型参数。最后,得到预测结果14345e798f2bad9107d042d8254bcc08.png,更新的学生模型ea082b12c60546391ccdce9bb17a6ed4.png和更新的教师模型5f06f627c91e284b995ac14ecdd567e9.png。重复以上步骤,直到模型在目标领域上达到满意的性能。这个过程可以在不断变化的目标数据上持续进行,以适应领域分布的变化。

a5375d264019348c07d97b12b677e6ad.png

实验

40c16e05d7cafa89b58fdf3a6e204686.png

作者在五个不同的连续测试时间适应基准任务上评估了他们提出的方法,这些任务包括:CIFAR10-to-CIFAR10C(标准和渐变),CIFAR100-to-CIFAR100C,ImageNet-to-ImageNet-C,Cityscapses-to-ACDC。这些任务代表了不同类型的应用场景,包括图像分类和语义分割。CIFAR10C、CIFAR100C和ImageNet-C是包含15种严重程度为5级的损坏类型。

CIFAR10 to CIFAR10C的实验:

对于在线连续测试时间自适应任务,使用在CIFAR10或CIFAR100数据集上训练好的预训练网络。在测试期间,损坏的图像以在线方式提供给网络。在最大损坏严重等级5下评估各种基准模型。评估是基于遇到数据后立即的在线预测结果。CIFAR10和CIFAR100实验均采用在线连续测试时间自适应方案。

我们首先评估了所提出的模型在CIFAR10到CIFAR10C任务上的有效性。将我们的方法与纯源基准和四种流行的方法进行比较。结果如下表所示。CoTTA:连续的测试时域自适应方法_第7张图片

CoTTA利用加权和增强平均的一致性,可以优于上述所有方法。错误率显著降低到16.2%。并且通过随机恢复方法,模型在长期适应的过程中性能不会下降。

 消融实验:此外作者还做了消融实验,如上表所示,通过使用教师模型的加权平均伪标签,错误率从20.7%降低到18.3%。这表明加权预测确实比直接预测更准确。通过使用多个增强来进一步细化权重平均预测,我们能够进一步将性能提高到17.4%。最后,通过随机恢复保留源知识,可以大大提高长期的预测。将错误率降低到16.2%。

  鲁棒性实验:通过逐渐改变15种不同程度的损坏类型图片设计10种不同的随机打乱的序列,使用这10个不同序列的平均错误率来评估这些方法,结果如下表所示。CoTTA优于其他方法,错误率只有10.4%。

CoTTA:连续的测试时域自适应方法_第8张图片

CIFAR100 to CIFAR100C实验:

在难度更高的cifar100 - cifar100c任务上对其进行了评估。实验结果如下表所示。

CoTTA:连续的测试时域自适应方法_第9张图片

ImageNet to ImageNet-C的实验:

在实验中作者使用了标准的预训练resnet50模型。在十种不同的损坏顺序下对ImageNet-C实验进行了评估。在严重等级为5的10种不同的损坏类型序列上进行了ImageNet-to-ImageNet-C实验。如下表所示,CoTTA优于其他方法。

CoTTA:连续的测试时域自适应方法_第10张图片

Cityscapes to ACDC的实验:

Cityscapes to ACDC是一个连续的语义分割任务,用它来模拟现实世界中的连续分布变化。为了尽可能重新访问类似环境的场景,并评估该方法的遗忘效果,作者将相同的序列组(四种条件)重复10次(即总共40次:雾- !夜- !雨- !雪- !雾…)。这也为长期适应性能的评估提供了依据。

在更复杂的连续测试时间语义分割Cityscapes to ACDC任务上评估了CoTTA。实验结果如下表所示。结果表明,该方法对语义分割任务也很有效,并且对不同的结构选择具有鲁棒性。与基线相比,我们提出的方法绝对提高了1.9%的mIoU,达到了58.6%的mIoU。

CoTTA:连续的测试时域自适应方法_第11张图片

b35c28b3f4fba321a4033918e6724bc0.png

结论

4edd6cc9cd350fb74a8942e3886e3cb1.png

对于在目标域分布随时间不断变化的非平稳环境中持续的测试时自适应产生的错误积累和灾难性遗忘问题,作者提出了一种新的CoTTA方法,该方法由两部分组成。首先,通过使用权重平均和增广平均预测来减少误差积累,这两种预测通常更准确。其次,为了保留来自源模型的知识,随机地将一小部分权重恢复到源预训练的权重。所提出的方法可以集成到现成的预训练模型中,而不需要访问源数据,作者在4个分类任务和1个分割任务上验证了其有效性。

d7389988b2b47b0f3a552d4be9f4415b.png

Attention

8ce94d13663975ce6cdbfe672037053b.png

如果你和我一样是轨道交通、道路交通、城市规划相关领域的,可以加微信:Dr_JinleiZhang,备注“进群”,加入交通大数据交流群!希望我们共同进步!

你可能感兴趣的:(深度学习,人工智能,机器学习,计算机视觉,算法)