在现实世界的场景中,半监督学习应用的一个基本限制是假设未标记的测试数据仅包含在已标记的训练数据中先前遇到的类别。然而,在真实场景中,这种假设很少成立,因为在测试时可能会出现属于新类别的实例。因此,我们引入了一种新颖的开放世界半监督学习设置,明确了未标记的测试数据中可能出现新类别的概念。在这种新颖的设置下,目标是解决标记和未标记数据之间的类别分布不匹配,其中在测试时每个输入实例要么需要被分类到现有类别之一,要么需要被初始化为一个新的未知类别。为了解决这个具有挑战性的问题,我们提出了ORCA,一种端到端的深度学习方法,引入了不确定性自适应边界机制,以克服学习已见类别的判别特征比学习新类别的判别特征更快造成的偏见。通过这种方式,ORCA减少了现有类别内部方差与新类别之间的差距。在图像分类数据集和单细胞注释数据集上的实验证明,ORCA始终优于替代基线模型,在ImageNet数据集的已见类别上实现了25%的改进,对于新类别实现了96%的改进。
随着深度学习的发展,取得了显著的突破,当前的机器学习系统在具有大量标记数据的任务上表现出色(LeCun等,2015年; Silver等,2016年; Esteva等,2017年)。尽管有这些优势,但绝大多数模型都是为封闭世界的设置而设计的,这基于一个假设:训练和测试数据来自同一组预定义的类别(Bendale&Boult,2015年; Boult等,2019年)。然而,在现实世界中,这个假设很少成立,因为标记数据取决于对给定领域的完整了解。例如,生物学家可能预先标记已知的细胞类型(已见类别),然后想将该模型应用于新的组织,以识别已知的细胞类型,同时发现以前未知的细胞类型(未见类别)。同样,在社交网络中,人们可能希望将用户分类到预定义的兴趣组中,同时还可以发现用户的新的未知/未标记兴趣。因此,与通常假设的封闭世界相反,许多现实世界中的问题本质上是开放世界的 - 测试数据中可能出现从未在训练过程中看到(和标记)的新类别。
在这里,我们介绍开放世界半监督学习(open-world SSL)设置,它是半监督学习和新类别发现的一般化。在开放世界SSL下,我们有一个带有标签的训练数据集和一个无标签的数据集。带标签的数据集包含属于一组已见类别的实例,而无标签/测试数据集中的实例既属于已见类别,也属于未知数量的未见类别(图1)。在这种设置下,模型需要将实例分类到以前已见的类别中的一个,或者发现新的类别并将实例分配给它们。换句话说,开放世界SSL是在类分布不匹配的情况下进行的转导式学习设置,其中无标签测试集可能包含在训练过程中从未标记过的类别,即不属于带有标签的训练集的一部分。
图1:在开放世界SSL中,未标记的数据集可能包含在标记集中从未遇到过的类。给定未标记的测试集,模型需要将实例分配给先前在标记集中看到的类之一,或者形成一个新的类并将实例分配给它。
开放世界半监督学习(Open-world SSL)与最近两种工作基本不同,但又密切相关:鲁棒半监督学习(Robust SSL)和新类别发现(Novel Class Discovery)。鲁棒半监督学习假设标签和无标签数据之间存在类分布不匹配,但在这种情况下,模型只需要能够识别(拒绝)来自无标签数据中的新类别的实例,将其识别为“不属于任何已知类别”。相比之下,开放世界半监督学习旨在发现个别的新类别,并将实例分配给它们。新类别发现是一个聚类问题,其中假设无标签数据仅由新类别组成。而开放世界半监督学习更为一般化,因为无标签数据中的实例既可以来自已见类别,也可以来自新类别。将鲁棒半监督学习和新类别发现方法应用于开放世界半监督学习,可以采用多步骤的方法,首先使用鲁棒半监督学习来拒绝来自新类别的实例,然后对被拒绝的实例应用新类别发现方法来发现新类别。另一种方法是将所有类别都视为“新类别”,应用新类别发现方法,然后将一些类别匹配回标签数据集中的已见类别。然而,我们的实验表明,这样的临时方法在实践中效果不好。因此,有必要设计一种可以在端到端框架中解决这个实际问题的方法。
本文中,我们提出了ORCA(Open-world with unCertainty based Adaptive margin),它在新的开放世界半监督学习设置下运行。ORCA有效地将来自无标签数据的示例分配到先前见过的类别中,或通过组合相似实例来形成新的类别。ORCA是一个端到端的深度学习框架,其中我们的方法的关键是一种新颖的不确定性自适应边界机制,在训练过程中逐渐减少模型的可塑性并增加其可辨识性。这种机制有效地减少了由于学习先前见过的类别比新类别更快而引起的类内方差与新类别之间的不良差异,我们表明这是这种情况下的关键困难。然后,我们开发了一种特殊的模型训练过程,学习将数据点分类到一组先前见过的类别,并学习为每个新发现的类别使用额外的分类头。用于先前见过的类别的分类头用于将无标签示例分配给来自标记集的类别,而激活额外的分类头允许ORCA形成新的类别。ORCA不需要提前知道新类别的数量,可以在部署时自动发现它们。
我们在三个用于开放世界半监督学习设置的基准图像分类数据集以及生物学领域的单细胞注释数据集上评估了ORCA。由于没有现有的方法可以在开放世界半监督学习设置下运行,我们将现有的最先进的半监督学习、开放集识别和新类别发现方法扩展到开放世界半监督学习,并将它们与ORCA进行比较。实验结果表明,ORCA有效地解决了开放世界半监督学习设置的挑战,并且在所有基准测试中始终显著优于所有基线方法。具体而言,ORCA在ImageNet数据集的已见类别和新类别上分别实现了25%和96%的改进。此外,我们证明ORCA对未知数量的新类别、不同分布的已见和新类别、不平衡的数据分布、预训练策略和少量标记示例都具有鲁棒性。
我们在表1中总结了开放世界SSL和相关设置之间的相似性和差异。其他相关工作见附录A。
新类发现
在新类别发现任务中(Hsu等,2018;Han等,2020;Brbic等,2020;Zhong等,2021),任务是对未标记的数据集进行聚类,该数据集包含与标记数据集中完全不重叠的相似类别,用于学习更好的聚类表示。这些方法假设在测试时所有类别都是新类别。虽然这些方法能够发现新类别,但它们不识别已知类别。相反,我们的开放世界半监督学习更加通用,因为未标记的测试集包含了新类别,但也包含了之前在标记数据中见过的需要被识别的类别。原则上,我们可以通过在测试时将所有类别都视为“新类别”,然后将其中的一些类别与标记数据集中的已知类别进行匹配来扩展新类别发现方法。我们采用这样的方法作为我们的基线,但实验结果表明,它们在实践中表现不好。
半监督学习
SSL方法(Chapelle等,2009;Kingma等,2014;Laine和Aila,2017;Zhai等,2019;Lee,2013;Xie等,2020;Berthelot等,2019;2020;Sohn等,2020)假设封闭世界设置,即标记和未标记数据来自同一组类别。Robust SSL方法(Oliver等,2018;Chen等,2020b;Guo等,2020;Yu等,2020)放宽了SSL假设,假设来自新类别的实例可能出现在未标记的测试集中。Robust SSL的目标是拒绝来自新类别的实例,这些实例被视为分布之外的实例。而在开放世界半监督学习中,目标是发现个别的新类别,然后将实例分配给它们。为了将Robust SSL扩展到开放世界半监督学习,可以将聚类/新类别发现方法应用于被丢弃的实例。早期的研究(Miller和Browning,2003)考虑了用EM算法的扩展来解决这个问题。然而,我们的实验表明,通过丢弃实例,这些方法学习到的嵌入不允许准确地发现新类别。
开放集和开放世界识别
开放集识别(Scheirer等,2012;Geng等,2020;Bendale和Boult,2016;Ge等,2017;Sun等,2020a)考虑了归纳设置,在测试期间可能出现新类别,模型需要拒绝来自新类别的实例。为了将这些方法扩展到开放世界设置,我们包括了一个基线,该基线在被拒绝的实例上发现类别。然而,结果表明这样的方法不能有效地解决开放世界半监督学习的挑战。类似地,开放世界识别方法(Bendale和Boult,2015;Rudd等,2017;Boult等,2019)要求系统通过逐步学习和扩展已知类别集合来增量标记新类别。这些方法通过人机交互逐步标记新类别。相比之下,开放世界半监督学习在学习阶段利用未标记数据,不需要人机交互。
广义零样本学习
与开放世界半监督学习一样,广义零样本学习(GZSL)(Xian等,2017;Liu等,2018;Chao等,2016)假设在测试时存在标签集和新类别。然而,GZSL对于提供给辅助属性的先验知识有额外的假设,这些属性唯一地描述了每个个体类别,包括新类别。这种限制性假设严重限制了GZSL方法在实际应用中的应用。相比之下,开放世界半监督学习更加通用,因为它不对类别的任何先验信息进行假设。
在本节中,我们首先定义了开放世界半监督学习的设置。然后,我们概述ORCA框架,并详细介绍了我们框架的每个组件。
图2:ORCA框架概述。ORCA使用额外的分类头用于新的类别。ORCA的目标函数包括:(i) 带有不确定性自适应边界的有监督目标,(ii) 生成伪标签的成对目标,以及(iii) 正则化项。
在开放世界半监督学习中,我们假设处于传导学习设置中,即给定输入的数据集有标签部分 Dl = {(xi, yi)}n
i=1 和无标签部分 Du = {(xi)}m
i=1。我们将在标记数据中看到的类别集合称为 Cl,并将在无标签测试数据中看到的类别集合称为 Cu。我们假设存在类别转移,即 Cl ∩ Cu = ∅ 并且 Cl ≠ Cu。我们将 Cs = Cl ∩ Cu 视为已知类别的集合,并将 Cn = Cu\Cl 视为新类别的集合。
定义1(开放世界半监督学习)。在开放世界半监督学习中,模型需要将无标签数据集 Du 中的实例分配到先前已知的类别 Cs 中,或者形成一个新类别 c ∈ Cn,并将实例分配给它。
值得注意的是,开放世界半监督学习泛化了新类别发现和传统的(封闭世界的)半监督学习。新类别发现假设标记数据和无标签数据中的类别是不相交的,即 Cl ∩ Cu = ∅,而(封闭世界的)半监督学习则假设标记数据和无标签数据中的类别是相同的,即 Cl = Cu。
解决开放世界半监督学习的关键挑战是同时学习来自已知/标记类别和未知/无标记类别。这是一项具有挑战性的任务,因为模型在已知类别上学习鉴别性的表示比在新类别上学习得更快。这导致已知类别的内部方差较小,而新类别的内部方差较大。为了解决这个问题,我们提出了ORCA,一种通过使用不确定性自适应边界来减小已知类别和新类别之间内部方差的差距的方法。ORCA的关键洞察是使用无标记数据上的不确定性来控制已知类别的内部方差:如果无标记数据上的不确定性较高,我们将强制增加已知类别的内部方差,以减小已知类别和新类别之间方差的差距;而如果不确定性较低,则会降低已知类别的内部方差,以鼓励模型充分利用标记数据。通过使用不确定性自适应边界,我们控制了已知类别的内部方差,并确保已知类别的鉴别性表示不会相对于新类别学习得太快。
ORCA首先将具有标签的实例Xl = {xi ∈ RN}n i=1和无标签的实例Xu = {xi ∈ RN}m i=1输入到嵌入函数fθ : RN → RD中,以获得标记和未标记数据的特征表示Zl = {zi ∈ RD}n i=1和Zu = {zi ∈ RD}m i=1。其中,zi = fθ(xi)是每个实例xi ∈ Xl ∪Xu的特征表示。在主干网络的顶部,我们添加了一个分类头,由一个线性层组成,其参数是一个权重矩阵W : RD → R|Cl∪Cu|,并跟随一个softmax层。注意,分类头的数量设置为先前已知类别和预期的未知类别的数量。因此,前|Cl|个头用于将实例分类为先前已知的类别之一,而剩余的头用于将实例分配给新的未知类别。最终的类别/聚类预测通过计算ci = argmax(WT · zi) ∈ R得到。如果ci /∈ Cl,则xi属于新的未知类别。未知类别的数量|Cu|可以在算法中作为输入给出,这是聚类和新类别发现方法的典型假设。然而,如果未知类别的数量事先不知道,我们可以初始化ORCA时使用大量的预测头/新类别。然后,ORCA的目标函数通过不将任何实例分配给不需要的预测头,从而推断类别的数量,因此这些头永远不会激活。
ORCA的目标函数结合了三个组件(如图2所示):(i) 带有不确定性自适应边界的有监督目标函数,(ii) 成对目标函数,以及(iii) 正则化项。
其中,LS表示有监督目标函数,LP表示成对目标函数,R表示正则化项。η1和η2是设置为1的正则化参数,在我们的所有实验中都是如此。算法的伪代码在附录B中总结。我们在附录C中报告了对正则化参数的敏感性分析,并接下来讨论每个目标项的细节。
首先,具有不确定性自适应边界的有监督目标强制网络正确地将实例分配给先前观察到的类别,但控制学习此任务的速度,以便同时学习形成新类别。我们利用标记数据的分类注释{yi}n i=1,并优化权重W和主干网络θ。可以通过使用标准的交叉熵(CE)损失作为有监督的目标来利用分类注释:
然而,在标记数据上使用标准的交叉熵损失会导致已知类别(Cs)和新类别(Cn)之间出现不平衡的问题,即梯度只会对已知类别进行更新,而不会对新类别进行更新。这可能导致学习到的分类器对已知类别产生更大的梯度幅值(Kang等人,2019),从而导致整个模型对已知类别产生偏见。为了解决这个问题,我们引入了一个不确定性自适应边界机制,并提议对逻辑进行归一化,接下来我们将进行说明。
一个关键挑战是由于有监督的目标,已知类别的学习速度更快,因此它们往往与新类别相比具有较小的类内差异(Liu等人,2020)。通过在特征空间中对距离进行排序,成对目标生成未标记数据的伪标签,因此在类别之间的类内差异不平衡的情况下,会产生容易出错的伪标签。换句话说,来自新类别的实例将被分配给已知类别。为了减轻这种偏差,我们提出使用自适应边界机制来减小已知类别和新类别之间的类内差异。直观地说,在训练开始时,我们希望强制实施较大的负边界,以鼓励已知类别相对于新类别具有相同大的类内差异。接近训练结束时,当新类别的聚类已经形成时,我们将调整边界项接近于0,以便模型可以充分利用标记数据,即目标简化为Eq. (2)中定义的标准交叉熵。我们建议使用不确定性来捕捉类内差异。因此,我们使用不确定性估计来调整边界,以实现期望的行为 - 在训练的早期阶段,不确定性较大,从而导致较大的边界,而随着训练的进行,不确定性变小,导致较小的边界。
具体而言,具有不确定性自适应边界机制的有监督目标定义如下:
其中,¯u是不确定性,λ是定义其强度的正则化器。参数s是一个额外的缩放参数,用于控制交叉熵损失的温度(Wang等,2018)。为了估计不确定性¯u,我们依赖于从softmax函数的输出计算出的未标记实例的置信度。在二进制设置中,¯u可以通过以下近似得到:
¯u = 1/|Du| * ∑(x∈Du) Pr(Y=1|X=x) * Pr(Y=0|X=x)
这里,k是所有类别的索引。我们在多类别设置中使用与(Cao等,2020b)类似的公式来近似计算组不确定性。为了正确调整边界,我们需要限制分类器的幅度,因为不受限制的分类器幅度可能会对边界的调整产生负面影响。为了避免这个问题,我们对线性分类器的输入和权重进行归一化处理,即zi = zi / |zi|和Wj = Wj / |Wj|。
对于pairwise objective,其目标是学习预测实例对之间的相似性,以便将来自同一类别的实例归为一组。这个部分的目标是生成用于引导训练的未标记数据的伪标签。通过使用不确定性自适应边界来控制已知和新颖类别的内部差异,ORCA改进了伪标签的质量。
具体来说,通过控制相似性,我们可以将相似的实例聚类在一起,从而形成新的类别。这种方法在训练过程中,可以自动将未标记的实例分配给先前已知的类别或形成新的类别,从而在开放世界半监督学习中有效地解决了分类问题。通过利用相似性信息,模型可以更好地理解数据的分布,从而提高对新类别的识别能力。这样,我们可以在训练过程中逐渐学习到不同类别之间的相似性,从而更好地进行聚类和分类。
具体而言,我们将聚类学习问题转化为一种成对相似性预测任务。给定带标签的数据集 Xl 和未标记的数据集 Xu,我们的目标是微调我们的主干网络 fθ,并学习一个相似性预测函数,其参数为线性分类器 W,以便将属于同一类别的实例聚集在一起。为了实现这一点,我们依赖于标记集合上的真实标签和在未标记集合上生成的伪标签。具体而言,对于标记集合,我们已经知道哪些实例应该属于同一类别,因此我们可以使用真实标签。对于未标记集合,我们通过计算小批量中所有特征表示 zi 对之间的余弦距离来生成伪标签。然后,我们对计算出的距离进行排序,并为每个实例生成与其最相似邻居的伪标签。因此,我们仅为每个小批量内的每个实例生成来自最有信心的正对样本的伪标签。对于小批量中的特征表示 Zl ∪ Zu,我们将其最接近的集合表示为 Z‘l ∪ Z’u。需要注意的是,Z‘l 总是正确的,因为它是使用真实标签生成的。ORCA 中的pairwise objective 定义为二进制交叉熵损失 (BCE) 的修改形式:
在这里,σ表示softmax函数,用于将实例分配到先前观察到的类别或新的类别中。对于带标签的实例,我们使用真实标签来计算目标函数。对于未标记的实例,我们根据生成的伪标签来计算目标函数。我们只考虑最可信的对来生成伪标签,因为我们发现伪标签中的噪声增加会对聚类学习产生不利影响。与(Hsu et al., 2018; Han et al., 2020; Chang et al., 2017)不同的是,我们只考虑正对,并发现将负对包含在我们的目标函数中对学习没有益处(负对可以很容易地被识别)。我们的成对目标函数只考虑正对,与(Van Gansbeke et al., 2020)相关。然而,我们以在线方式更新距离和正对,因此在训练过程中受益于改进的特征表示。
最后,正则化项避免将所有实例都分配到同一类。在训练的早期阶段,网络可能会退化为一个平凡的解决方案,即将所有实例都分配到单个类别,即|Cu| = 1。为了避免这种情况,我们引入了一个Kullback-Leibler(KL)散度项来使Pr(y|x ∈ Dl ∪ Du)接近于先验概率分布P的标签y:
在这里,σ表示softmax函数。由于在大多数应用中,知道先验分布是一个很强的假设,所以我们在所有实验中使用最大熵正则化来对模型进行正则化。最大熵正则化在基于伪标签的半监督学习 (Arazo et al., 2020)、深度聚类方法 (Van Gansbeke et al., 2020) 和噪声标签训练 (Tanaka et al., 2018) 等任务中被用来防止类别分布过于平坦。在实验中,我们展示了这个项对ORCA的性能没有负面影响,即使在数据分布不均衡的情况下也是如此。
我们考虑在ORCA以及所有对比方法中使用自监督预训练,特别是在图像数据集上,我们使用自监督学习对ORCA和对比方法进行预训练。自监督学习形式化了一个预先任务(pretext/auxiliary task),该任务不需要任何手动标注,并且可以直接应用于带标签和无标签的数据。这个预先任务指导模型以完全无监督的方式学习有意义的表示。具体而言,我们采用SimCLR方法进行自监督预训练(引用来源为Chen等人,2020)。我们用一个预先目标对模型的主干部分fθ进行整个数据集Dl ∪ Du的预训练。在训练期间,我们冻结主干部分fθ的前几层,只更新最后几层和分类器W。我们对所有对比方法采用相同的SimCLR预训练协议。此外,我们还考虑了没有预训练的情况,即对于细胞类型注释任务,我们不使用任何预先任务,ORCA从随机初始化的权重开始。此外,我们在附录C中报告了使用不同预训练策略的结果,包括只对标记子集Dl进行预训练,并将SimCLR替换为RotationNet方法(引用来源为Kanezaki等人,2018)。
数据集:我们在四个不同的数据集上评估ORCA,包括三个标准的基准图像分类数据集CIFAR-10、CIFAR-100(Krizhevsky,2009)和ImageNet(Russakovsky等人,2015),以及来自生物学领域的高度不平衡的单细胞老化细胞图谱数据集(Consortium等人,2020)。对于单细胞数据集,我们考虑一个现实的跨组织细胞类型注释任务,其中未标记的数据来自不同的组织(Cao等人,2020a)(详见附录B)。对于ImageNet数据集,我们按照(Van Gansbeke等人,2020)的方法对100个类别进行了子采样。在所有数据集上,我们使用可控的未标记数据和新类别的比例。首先,我们将类别分为50%的已见类别和50%的新类别。然后,我们将50%的已见类别选择为标记数据集,其余为未标记数据集。我们在附录C中展示了不同已见类别和新类别比例以及10%标记样本的结果。
基准方法:鉴于开放世界半监督学习是一个新的设置,目前没有现成的基准方法可供使用。因此,我们将新类别发现、半监督学习和开放集识别方法扩展到开放世界半监督学习设置。新类别发现方法不能识别已见类别,即无法将未标记数据集中的类别匹配到已标记数据集中的已见类别。我们报告了它们在新类别上的性能,并通过以下方式将这些方法扩展为适用于已见类别:我们将已见类别视为新类别(这些方法有效地对未标记数据进行聚类),并使用匈牙利算法将一些发现的类别与已标记数据中的类别进行匹配,然后报告在已见类别上的性能。我们考虑了两种方法:DTC(Han等人,2019)和RankStats(Han等人,2020)。
另一方面,传统的半监督学习和开放集识别方法不能发现新类别。因此,我们将半监督学习和开放集识别方法扩展为适用于新类别,具体做法如下:我们使用半监督学习/开放集识别方法将数据点分类到已知类别,并估计出分布(OOD)样本。我们报告它们在已见类别上的性能,然后我们应用K-means聚类(Lloyd,1982)对OOD样本进行聚类,从而得到聚类结果(新类别)。通过这种方式,我们将两种半监督学习方法调整为适用于开放世界半监督学习设置:Deep Safe SSL(DS3L)(Guo等人,2020)和FixMatch(Sohn等人,2020),以及最近的深度学习开放集识别方法CGDL(Sun等人,2020a)。CGDL自动拒绝OOD样本。DS3L通过将低权重赋予OOD样本来考虑未标记数据中的新类别。为了扩展该方法,我们对具有最低权重的样本进行聚类。对于FixMatch,我们根据softmax置信度得分来估计OOD样本。对于这两种半监督学习方法,我们使用已知类别和新类别分区的真实信息来确定OOD样本的阈值。
在图像数据集上,我们对所有新类别发现和半监督学习基线进行SimCLR预训练,以确保ORCA的优势不是由于预训练造成的。唯一的例外是DTC,它有自己的专门预训练过程(Han等人,2019)。作为额外的对比基线,我们对SimCLR预训练后的表示进行了K-means聚类(Chen等人,2020a)。我们还进行了大量的消融实验,评估ORCA方法的优势。具体来说,我们将自适应边缘交叉熵损失替换为标准交叉熵损失的基线称为ORCA-ZM,用于评估自适应边缘的效果,我们将ORCA与固定负边缘(FNM)进行对比。我们发现边缘值为0.5可以获得最佳性能(见附录C),并在实验中使用该值。其他实现和实验细节可以在附录中找到。
表2:在三次运行中计算的平均准确率。星号(∗)表示原始方法无法识别已知类别(我们不得不进行扩展)。Dagger(†)表示原始方法无法检测到新类别(我们不得不进行扩展)。对于单细胞数据集,SimCLR和FixMatch不适用(NA)。改进是相对于最佳基线的相对改进。
备注:我们意识到对视觉数据集进行扰动的对比学习可以大大提升无监督学习的效果(Van Gansbeke等,2020)。我们故意避免使用这些技巧,因为它们可能不容易转移到其他领域。对于感兴趣的读者,可以自由地添加这些技巧并重新评估我们的模型。
在基准数据集上的评估。我们报告了在已知类和新类上的准确率,以及总体准确率。表2中的结果显示,ORCA始终以较大的优势优于所有基准线。例如,在CIFAR-100和ImageNet数据集的已知类上,ORCA分别比最佳基准线提高了21%和25%。在新类上,ORCA比基准线分别提高了51%(CIFAR-100),96%(ImageNet)和104%(单细胞数据集)。此外,将ORCA与ORCA-ZM和ORCA-FNM基准线进行比较清楚地展示了引入不确定性自适应边界对解决开放世界SSL的重要性。总的来说,我们的结果表明:(i)开放世界SSL设置很难,现有的方法不能很好地解决它,以及(ii)ORCA有效地解决了开放世界SSL的挑战,取得了显著的性能提升。
不确定性自适应边界的好处。我们进一步系统地评估引入不确定性自适应边界机制的效果。在CIFAR-100数据集上,我们在训练过程中将ORCA与ORCA-ZM和ORCA-FNM基准线进行比较(见图3)。我们报告准确率和不确定性,不确定性捕捉了类内方差,如公式(4)所定义。在第140个epoch,我们衰减学习率。结果显示,ORCA-ZM在训练过程中无法减小新类的类内方差,导致新类上的性能不尽人意。在已知类上,ORCA-ZM非常快速地达到了高性能,但在训练接近结束时,准确率开始下降。学习率触发性能下降的原因是小学习率可能会导致过拟合问题(Li等人,2019),这在ORCA-ZM中会引发问题,因为已知类和新类之间的方差差异以及随小学习率恶化的噪声伪标签问题(Song等人,2020)。这表明,没有不确定性自适应边界,模型会非常快地学习到已知类,但无法在新类上实现令人满意的性能。相比之下,ORCA有效地减小了已知类和新类的类内方差,并持续提高准确率。这个结果完全符合我们的主要想法,即在训练过程中逐渐增加已知类的区分度,以确保已知类和新类之间的类内方差相似。与ORCA-FNM相比,自适应边界在已知类上显示出明显的好处,实现了更低的类内方差和整个训练过程中更好的性能。总的来说,负边界确保了已知类的较大类内方差,使模型能够学习形成新类,而自适应边界则确保模型在训练过程中能充分利用标记数据。
与表2中的其他基准线相比,ORCA仅在12个epoch后就超过了它们的最终性能。此外,在附录C中,我们展示了不确定性自适应边界如何改善伪标签的质量,并展示了对不确定性强度参数λ的鲁棒性。综合起来,我们的结果强烈支持不确定性自适应边界的重要性。
图3显示了在CIFAR-100数据集上使用不确定性自适应边界对估计的不确定性(左图)和准确性(右图)的影响。在第140个epoch时,我们衰减学习率。
在未知新类数量的情况下进行评估。ORCA和其他基准线假设新类的数量是已知的。然而,在实际应用中,我们通常事先不知道类的数量。在这种情况下,我们可以先使用ORCA估计类的数量。为了评估在具有100个类的CIFAR-100数据集上的性能,我们首先使用(Han等人,2019)中提出的技术估计类的数量为124。然后,我们使用估计的类数量重新测试所有算法。ORCA通过不使用所有初始化的分类头自动修剪类的数量,并找到了114个新的类别。表3中的结果显示,ORCA在新类发现基准线上表现出色,与RankStats相比,性能提高了97%。此外,使用估计的类数量,ORCA的性能与事先已知类数量的情况相比,仅略有下降。我们还分析了ORCA未分配的14个额外头,并发现它们与小的聚类相关,即未分配头的平均样本数仅为16个。这表明额外的聚类包含正确类的更小子类,并属于有意义的聚类。在附录C中,我们进行了额外的消融研究,对大量类的情况进行了评估。
消融研究中的目标函数。ORCA的目标函数包括带有不确定性自适应边界的有监督目标,成对目标和正则化项。为了研究每个部分的重要性,我们进行了消融研究,通过去除以下部分修改ORCA:(i) 有监督目标(即w/o LS),(ii) 正则化项(即w/o R)。在第一种情况下,我们仅依赖于经过正则化的成对目标来解决问题,而在第二种情况下,我们使用未经正则化的有监督和成对目标。我们注意到,成对目标是必要的,以能够发现新类。在CIFAR-100数据集上的结果如表4所示,表明有监督目标LS和正则化R都是目标函数的必要部分。附录C中报告了带有不平衡数据分布的额外实验结果。
表3:在CIFAR-100数据集上进行了三次运行的平均准确率和归一化互信息(NMI),其中未知类别数。
表4:在CIFAR-100数据集上对目标函数组成部分进行了消融研究。我们报告了三次运行的平均准确率和归一化互信息(NMI)。
我们引入了开放世界半监督学习(open-world SSL)设置,其中在未标记的测试数据中可能出现新类别,模型需要将实例分配给在标记数据中出现过的类别,或者形成新类别并将实例分配给它们。为了解决这个问题,我们提出了ORCA,一种基于不确定性自适应边界机制的方法,在训练过程中控制已知类别和新类别的类内方差。我们进行了广泛的实验,结果表明ORCA有效地解决了开放世界半监督学习问题,并且比其他替代方法表现出更好的性能。我们的工作主张将传统的封闭世界设置转变为更现实的开放世界机器学习模型评估。