论文
arxiv.org/pdf/2306.06894.pdfhttps://arxiv.org/pdf/2306.06894.pdf
代码(使用的是华为mindspore框架)
Shuijing2018/GLAC_Mindspore (github.com)https://github.com/Shuijing2018/GLAC_Mindspore
最近的一项研究表明,通过利用未标记数据,可以在类别转移条件下构造LAC在测试分布上的无偏风险估计量(URE)。本研究的动机是,尽管在标记的数据中无法观察到增强类的实例,但它们的分布信息可能包含在未标记的数据中,并通过区分已知类与未标记数据的分布来估计。这种URE对于学习任务是有利的,因为它可以导致基于经验风险最小化的理论基础方法。然而,前人研究[4]导出的LAC的URE仅局限于特定类型的one- vs -rest损失函数,在实际中损失需要随数据集变化时不够灵活。
为了解决上述问题,我们在本文中做出了以下贡献:
1 作者提出了一个广义无偏风险估计量(URE),给定未标记的数据用于增强类的学习,它可以在保持理论保证的同时配备任意损失函数。
2 作者对估计误差界限进行了理论分析,保证了经验风险极小值收敛于真实风险极小值。
3 作者提出了一个新的风险惩罚正则化项,可用于缓解以往研究中经常遇到的负经验风险问题。
在LAC的训练阶段,我们得到一个标记训练集 ,从已知类在的分布 上独立采样,其中是特征空间, 是 个已知类的标签。然而,在测试阶段,我们需要预测从测试分布中采样的未标记数据,其中可能出现训练阶段未观察到的增强类。由于扩充类的具体划分是未知的,因此通常将它们预测为单个类。这样,检验分布Pte的标签空间为,其中ac表示由增广类组成的类。引入类移位条件来描述已知类与增广类分布之间的关系:
其中,为混合比例。此外,从测试分布中采样的一组未标记数据 将用于模型训练。
在LAC的类别转移条件下,测试分布上的预期分类风险可表示为:
L为损失函数
最近的研究[4]只考虑了1 - vs -rest (OVR)损失函数作为Eq.(2)中的分类损失,对于k类分类,其形式如下:
式中表示二值损失函数。通过将OVR损失 代入分类风险,我们得到
先前的研究[4]表明,在LAC设置下,从检验分布中提取未标记的数据,可以得到的等效表达式:
可以验证(数学证明),,因此它的经验版本是一个无偏风险估计量(URE)。我们可以看到,这个URE只局限于多类分类的OVR损失,在实际需要随数据集变化损失时不够灵活。
定理一
在式(1)中的类转移条件下,分类风险 可等价表示为:
证明见原文附录a。
推论一
如果在我们推导出的风险中,使用one- vs -rest (OVR)损失LOVR作为分类损失,那么我们可以准确地恢复之前研究[4]推导出的风险。
这里省略了推论1的证明,因为通过直接将插入来验证它非常简单。推论1表明我们提出的URE是对Zhang等人[4]提出的URE的推广,可以与任意损失函数相容。
给定从已知类的分布中得到的一组标记数据,从检验分布得到的一组未标记数据,我们可以得到如下的URE,即的经验近似:
这样,我们就可以通过直接最小化来学习一个有效的多类分类器。由于对损失函数和模型没有限制,所以对于未标记数据的LAC,我们可以使用任何损失和任何模型 。
风险惩罚正则项略
数据集
我们使用从UCI机器学习知识库下载的6个常规尺度数据集,包括Har、Msplice、Normal、Optdigits、Texture和Usps。由于它们不是大规模的数据集,我们在这些数据集上训练一个线性模型。我们还使用了四个大规模的基准数据集,包括MNIST2[17]、Fashion-MNIST3[18]、Kuzushiji-MNIST4[19]和SVHN5[20]。
对于MINST, Fashion-MNIST和Kuzushiji-MNIST,我们训练了一个三层(d−500−k)的多层感知器(MLP)模型,并使用ReLU激活函数。
对于SVHN数据集,我们训练了一个VGG16模型[21]。
表1列出了所有使用数据集的简要特征。对于每个正则尺度数据集,选择一半的类作为增强类,其余的类被认为是已知类。此外,已标记、未标记和测试样例的数量分别设置为500、1000和1000。对于大规模数据集,我们选择6个类作为已知类,其他类作为增强类。
对于MNIST、Fashion-MNIST和Kuzushiji-MNIST,标记、未标记和测试示例的数量分别设置为24000(每个已知类4000个)、10000(每个类1000个)和1000(每个类100个)。
对于SVHN,标记、未标记和测试示例的数量分别设置为24000(每个已知类4000个)、25000(每个类2500个)和1000(每个类100个)。为了全面评估我们提出的方法的性能,我们报告了10次试验的标准差均值,评估指标包括准确性、Macro-F1和AUC。