©作者 | 小舟
单位 | 电子科技大学
研究方向 | 计算机视觉
本文提出了一个 OOD Detection 的新方法,思想上没有什么创新,结合了现有 OOD 方法的思路,同时引入有监督对比学习的思想,将 OOD 数据视为一个单独的类,与数据集中的长尾样本计算有监督对比学习进行优化。同时,对网络结构中的 BN 和预测层进行了一定的调整。
论文标题:
Partial and Asymmetric Contrastive Learning for Out-of-Distribution Detection in Long-Tailed Recognition
论文链接:
https://arxiv.org/abs/2207.01160
代码链接:
https://github.com/amazon-research/long-tailed-ood-detection
首先介绍一下 OOD Detection 的定义,这一类问题是为了解决网络实际应用过程中,经常会遇到有些目标不属于训练集范围,即不是我们期望检测的目标物体的情况,而现有的通用分类模型,通常会对 OOD 目标生成一个高置信度的预测结果,这不是我们期望看到的景象,尤其是在自动驾驶领域,把 OOD 物体识别成 ID 物体可能会造成严重后果。
因此 OOD Detection 旨在保留判别模型分类能力的同时,执行 OOD 目标的检测功能。
常用的 OOD Detection 策略有如下几种:
Softmax-based: 这类方法利用 pre-trained model 输出的最大 softmax 概率进行统计分析,统计发现 OOD 样本和 ID 样本 softmax 概率的分布情况,试图将二者的分布差距加大,然后选取合适的阈值来判断一个样本属于 OOD 还是 ID。这类方法简单且有效,不用修改分类模型的结构,也不需要训练一个 OOD 样本分类器。
Uncertainty: 由于模型的概率输出并不能直接表示模型的置信度(confidence)。因此这类方法让模型学习一个对输入样本的不确定性属性。面对测试数据,如果模型输入为 ID 样本,则不确定性低,相反,如果模型输入为 OOD 样本,则不确定性高。这类方法需要修改模型的网络结构来学习不确定性属性。
Generative Model: 这类方法主要利用 Variational Autoencoder 的 reconstruction error 或者其他度量方式来判断一个样本是否属于 ID 或 OOD 样本。主要的假设是,Autoencoder 的隐含空间(latent space) 能够学习出 ID 数据的明显特征(silent vector),而对于 OOD 样本则不行,因此OOD样本会产生较高的 reconstruction error. 这类方法只关注 OOD 检测性能,不关注 ID 数据本来的任务。
Classifier: 这类方法比较直接,使用分类器对提取的特征进行分类来判断是否为OOD样本。这类方法简单直接,也取得了不错的效果,有的方法修改网络结构为一个 n+1 类分类器,n 为原本分类任务的类别数,第 n+1 类则是 OOD 类;有的方法直接取提取特征来进行分类,不需要修改网络的结构。
该介绍来自:
https://zhuanlan.zhihu.com/p/102870562
现有的 OOD 方法在均衡的 ID 数据集上取得了较好的效果,而当 ID 数据集为长尾分布数据集时,会导致 OOD Detection 和 ID Classification 精度的双下降。如下表所示,多种 OOD Detection 方法在长尾 ID 数据集上均出现了一定的预测性能的衰减。
不仅如此,将 OOD Detection 方法与长尾分类方法的简单组合,取得的效果也不是特别理想:
不仅如此,将 OOD Detection 方法与长尾分类方法的简单组合,取得的效果也不是特别理想:
为了探究原因,作者对在长尾 ID 数据分布上的 OOD Detection 任务(OE)进行实验,使用 ResNet18 网络作为框架,提取网络的倒数第二个输出层特征,进行可视化:
上图中,类别字母排序靠前者具有更多样本,可视化结果显示,对于 OE 模型而言,其对于 OOD 数据和长尾数据的特征具有很高的相似性,这意味着模型将长尾数据与 OOD 数据混淆了。因此,可以尝试通过增大长尾数据与 OOD 数据特征之间的差异性,尝试改善这一问题。
1. 在原有 OOD Detection 问题之上引入有监督对比损失,增大长尾数据与 OOD 数据特征之间的差异性;
2. 提出网络结构微调策略 APF,对分类器的 BN 和预测层进行微调,提升 ID 分类和 OOD Detection 性能。
OE 方法希望通过约束增大模型对 OOD 预测结果的不确定性,该系列方法认为 OOD 数据只包含身份标签,即标签只用来指示该数据是否属于 OOD 数据,而不关心 OOD 数据内部的类别组成,为了增大预测结果的不确定性,希望模型对 OOD 数据的预测结果服从均匀分布,因此损失函数表示如下:
其中,为 ID 数据使用交叉熵损失进行类别预测:
为 OOD 数据使用 KL 损失,希望预测结果服从均匀分布 u:
3.2 SPL损失的应用
作者通过对前面特征可视化结果的分析,认为 OOD Dection 模型对于长尾 ID 数据分布中的优势类别可以较好地与 OOD 数据区分开来,而劣势类别则难以区分,因此,作者认为,不应该对优势类别添加对比损失,这会导致模型更加注重对优势类别和 OOD 数据之间的区分度,导致劣势类别被忽略,因为 OOD 数据和优势类别数据在总的训练集中占据绝对多数。同时,不应该对 OOD 数据之间施加对比损失,因为 OOD 数据由许多类别组成假如将 OOD 数据视为一个正类使用有监督对比学习进行优化,在理论上存在问题。
因此,作者提出一种有偏不对称的有监督对比学习:
具体来说,对比损失只应用于长尾数据集,对于长尾数据集中的某一个类别,其正对来自与其类别相同的图像,而副对由长尾数据集中的其他类别和 OOD 数据组成,因此 OOD 数据承担的角色只是负对。
这一损失的具体表现形式为:
其中:
因此,模型总损失为:
3.3 针对BN和预测层的Auxiliary Branch Finetuning (ABF) 策略
除了前面提到的有偏非对称有监督对比学习之外,作者还通过借鉴前人的分析,对 BN 和预测层进行改进。
BN:BN 的问题来自于其数据规范化过程本身,BN 在进行推断时,会使用所有 Batch 数据的均值和方差进行规范化,对于 OOD Detection 任务而言,意味着在推断时,使用了 ID 数据和 OOD 数据的混合信息,这对于 OOD 数据的推断是有利的,但是对于 ID 数据的分类而言是不利的,作者认为,BN 在进行 OOD Detection 时应该使用 ID 和 OOD 混合数据信息,而进行·ID 分类时应该只使用 ID 数据信息。
预测层:预测层通常是神经网络最后一个全连接层,目前有一些长尾分布的研究表明,预测层的参数对于长尾数据分类结果具有重要作用,作者认为在进行 ID 数据分类时,最好使用只在 ID 数据上进行训练的预测层。
因此,作者首先使用前述损失函数训练一个深度神经网络,在训练完成后,将该网络记为网络 1,然后固定网络中除 BN 和预测层之外的所有网络参数,并以网络 1 的参数作为 BN 和预测层的初始化,使用 Google 在 2021 年所提出来的 LA 交叉熵损失。
在 ID 数据上进行微调,得到网络 2。在测试时,由网络 1 执行 OOD 检测,由网络 2 执行 ID 分类:
实验
4.1 数据集和网络框架
作者使用的 ID 数据集为 CIFAR10-LT, CIFAR100-LT 和 ImageNet-LT,对于前面两个数据集,OOD 数据集是 TinyImages80M,对于 ImageNet-LT,使用的是自己构建的 ImageNet-Extra 数据集,该数据集包含 577711 张来自 ImageNet-22k 上与 ImageNet1k 不重合的数据。
而对于 OOD 测试集,作者对于不同数据集使用了不同的 OOD 测试集。对于 CIFAR10 和 CIFAR100,使用 ReaNet18,对于 ImageNet,使用 ResNet50。
4.2 评价指标
AUROC:表示 OOD 检测时正例比负例具有更高检测得分的概率,理想结果为100%;
AUPR:表示 OOD 检测的平均召回率;
FPR@TPRn(FPR95):表示 OOD 检测真阳性率为 n 时的假阳性率,n 常取 95%;
ACC@TPRn(ACC95):表示 OOD 检测真阳性率为 n 时,n 常取 95%,ID 数据的分类准确率;
ACC@FPRn:表示当 n 个 ID 样本被错误检测为 OOD 样本时,剩余 ID 样本的分类准确率,当 n=0 时,也记为 ACC。
4.3 实验结果
4.3.1 定量实验结果
作者比较了他们的方法在 CIFAR10-LT 上,在不同 OOD 测试集上的实验结果:
同时,计算了 ACC@FPRn 指标当 n 取不同的值时,两个模型的性能表现:
以及对比了他们的方法和其他一些方法的性能差异:
此外,作者还展示了他们的方法相比于 OE,在 ImageNet-LT 上在 ID 头部和尾部数据分类性能上的提升:
4.3.2 消融实验
作者比较了他们所提出的方法在有偏、不对称、ABF 策略上的消融研究结果:
同时,作者研究了进行对比损失计算的尾部样本比例 k 对于模型性能的影响,发现他们的方法在 k 为 50% 时可以取得最好的效果,并且在较大区间内具有鲁棒性:
此外,作者对于微调策略进行了研究,结果显示,单独微调 BN 或分类层对于性能都有提升,并且微调二者的效果会比微调整个网络的效果更好:
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·