导语
感谢大家关注图智决策,本公众号为纯学术交流平台,主要为大家分享有关图神经网络、图上组合优化与决策、图上算法推理等前沿研究话题相关的论文及其解读。
本期话题为图上的多任务学习,笔者将带领大家一起解读来自ICLR2022的一篇Spotlight论文,论文题为《关系型多任务学习:建模数据和任务之间的关系》(RELATIONAL MULTI-TASK LEARNING: MODELING RELATIONS BETWEEN DATA AND TASKS),作者来自斯坦福大学计算机科学系Jure Leskovec团队,包括共同第一作者Kaidi Cao和Jiaxuan You,以及Jure Leskovec教授。用一句话总结这篇文章的工作的亮点就是:文章提出了一个构建在数据和任务上的知识图模型MetaLink可以用于多种场景下的多任务学习。
研究领域:多任务学习,图神经网络
曾利 | 作者
Frank | 审核
文章字数:9000字 建议阅读时长:20分钟
原文链接:https://openreview.net/forum?id=8Py-W8lSUgy
代码链接:https://github.com/snap-stanford/GraphGym/tree/meta_link
汇报PPT:请在关注图智决策公众号后,发送文字消息 metalink 获得下载链接
多任务学习中的一个关键假设是,在推理时,多任务模型只能访问给定的数据,而不能访问来自其他任务的数据标签。本文提供了一个扩展多任务学习的方法,以利用来自其他辅助任务的数据点标签,并且这种方式提高了新任务的性能。在这里,我们介绍了一种新的关系多任务学习场景,其中我们利用辅助任务的数据点标签来对新任务进行更准确的预测。我们开发了MetaLink,其中我们的关键创新是构建一个连接数据点和任务的知识图,从而允许我们利用辅助任务中的标签。知识图由两种类型的节点组成:(1)数据节点,其中节点特征是由神经网络计算的数据嵌入;(2)任务节点,每个任务的最后一层权重作为节点特征。此知识图中的边捕获数据任务关系,边标签捕获特定任务上的数据点的标签。在MetaLink下,我们将新任务重新定义为数据节点和任务节点之间的链接标签预测问题。MetaLink框架为从辅助任务标签到感兴趣任务的知识转移建模提供了灵活性。我们在生化和视觉领域的6个基准数据集上评估MetaLink。实验表明,MetaLink可以成功地利用不同任务之间的关系,在所提出的关系型多任务学习场景下优于最先进的方法,ROC AUC提高了27%。
为了对多任务学习的概念有个更加深入的理解,我们首先对多任务学习及其相关概念进行一下辨析。
图1 单任务学习与多任务学习的区别
来源:https://github.com/mbs0221/Multitask-Learning/blob/master/pdf/Multi-Task%20Learning-Theory,%20Algorithms,%20and%20Applications.pdf
如图1所示,传统的单任务学习场景下(图1a),在针对不同学习任务,需要建立各自的模型,逐一学习,且各个模型之间无法共享参数。而多任务学习(图1b)不同于单任务学习,它使用共享表示(模型)将一个问题与其他相关问题同时进行学习。
图2 不同学习方法的区别与联系
图2显示了与多任务学习的相关概念及其联系。从图中可以看出迁移学习的概念最广泛、其次为多任务学习、多标签学习、多类别学习等。
其中迁移学习的特点为:
(1)需要定义来源域和目标域
(2)模型在来源域上学习
(3)模型需要在目标域上进行泛化
多任务学习的特点为:
(1)对任务之间的关联关系建模
(2)同步完成多个任务上的学习
(3)不同任务的数据/特征不尽相同
多标签学习的特点为:
(1)对标签之间的关联关系建模
(2)同步完成多个标签上的学习
(3)不同标签具有相同的的数据/特征
多分类学习的特点为:
(1)对每个类别进行独立学习
(2)不同类之间是互斥的
关于如何从多个任务中进行学习这一话题,研究者从不同视角进行了探索,提出了包括多任务学习(Caruana, 1997)、元学习(Finn et al., 2017)和少样本学习(Vinyals et al., 2016; Cao et al., 2020a;b)等多种学习模式。虽然这些学习模式促进了多种考虑任务之间关系的学习模型(Chen et al., 2019; Zamir et al., 2018; Sener & Koltun, 2018; Lin et al.,2019; Ma et al., 2020),但这些模型却无法捕捉真实世界机器学习应用程序的全部复杂性。具体地说,当从多个任务中学习时,这些方法假设测试数据在对一个新任务进行预测时,无法使用来自其他任务的标签。然而,在许多实际应用中,这种假设过于简化了,忽略了对潜在有用知识的利用。
例如,多任务学习研究同时学习多个预测任务,以利用任务之间的共同特性。测试时,多任务模型可以在感兴趣的任务(即目标任务)上预测给定数据点所对应的标签,例如,预测化合物是否无毒。同时,人们还可以访问其他辅助任务上数据点的标签信息,例如,化合物在两次毒理学测试中有阳性结果。这样的辅助任务标签可以极大地帮助改进目标任务上的预测结果。
然而,目前的深度学习体系结构无法模拟辅助任务/标签与目标任务之间的这种知识迁移。仅仅将已知的标签拼接到输入特性有其局限性,特别是因为这样的标签通常比较难以获取,且尚不清楚如何将这些标签用于新任务和不可见任务。另一种思路是通过生成模型来建模这种灵活且有条件的推理任务 (Dempster et al., 1977; Koller & Friedman, 2009)。虽然生成模型功能强大,但它们是出了名的数据饥渴,因此为高维数据构建和训练生成模型非常困难(Turhan & Bilge, 2018)。
图3:MetaLink架构示意图
在关系型多任务场景下,模型学会在预测时加入辅助知识,以提升数据效率。具体而言,给定任务子集 上的观察值 及其标签,目标是建立一个模型,利用辅助任务标签 并对新任务进行预测(由于文章主要以二分类任务为例,因此标签为0或者1)。通常采用的方法是建立一个多头深度神经网络,为每个单独的任务 设置一个预测头。然而,这种方法不能使用辅助标签。相比之下,我们提出的MetaLink将每个任务的最后一层权重重新解释为任务节点,并创建一个知识图,其中途中的节点包括数据节点和任务节点,边上的标签则表示该任务与数据是否有关联(任务是否可以使用数据点上的标签)。在预测给定任务的数据点标签时,MetaLink可以使用来自其他任务上的标签,从而提高了在目标任务上的预测性能。
在这里,我们提出了一种新的多任务学习场景,称为关系型多任务学习。在我们的场景中,我们区分目标任务(例如,预测分子毒性)和辅助任务,前者是我们旨在预测的任务,后者是在推断时数据标签可用的任务。注意,在我们的场景下,每个数据点可能带有用于辅助任务的不同子集的标签。因此,模型的目标是通过在辅助任务的某些子集上利用给定数据点的标签来实现强大的预测性能。
为了解决关系型多任务学习问题,我们提出了一个名为MetaLink的通用判别模型,它可以显式地融合辅助任务中的知识。我们的核心创新是构建一个连接不同任务 和数据点 的知识图(图4)。 我们方法的第一步是获取取输入数据点 和特征提取器(即神经网络) ,得到其嵌入。然后我们构建了由两种节点组成的知识图:数据节点和任务节点。如果数据点 参与了任务 ,并且在任务上用到了的标签,则数据节点 与任务节点建立一条连边。我们将特征神经网络的最后一层的嵌入向量作为数据节点的初始输入特征,将的最后一层的权重作为任务节点的初始输入特征。
根据我们的知识图,我们将多任务学习问题重新表述为数据节点和任务节点之间的链路标签预测问题。这意味着在推断时,MetaLink能够使用给定数据点的所有信息(包括辅助任务上的标签)来预测新任务上的标签。我们通过图神经网络(GNN)解决了这个链路标签预测学习任务(Hamilton et al., 2017; He et al., 2019; Xue et al., 2021)。与之前的工作,如ML-GCN (Chen et al., 2019) 仅对任务之间的关系建模所不同的是,MetaLink允许对数据-任务、数据-数据和任务-任务三种关系进行灵活地自动建模。
我们在生物化学和视觉领域的6个基准数据集上在不同的场景下对MetaLink进行了评估。实验表明,MetaLink可以成功地利用不同任务之间的关系,且在我们新提出的关系型多任务学习场景下优于目前最先进的方法,ROC AUC提高高达27%。
图4:我们的MetaLink框架可以建模四种不同的多任务学习场景:其中表示数据节点而表示任务节点。蓝色表示在训练阶段可以使用到的数据/任务,白色表示仅在测试阶段可以使用到的数据/任务。在模型推理阶段(训练阶段和测试阶段),实线表示数据-任务对的标签已知,而虚线则表示数据-任务对的标签需要预测。
在这里,我们首先正式介绍关系型多任务学习的场景。假设我们有个机器学习任务,其中是介于1和之间的整数。我们提出从以下两个维度对不同的多任务学习场景进行分类:
(1)任务是否是关系型任务,即在推理时是否可以使用辅助任务标签;
(2)任务是否为元学习任务,即测试时的任务是否在训练时被看到。
因此可以根据以上两个标准划将多任务学习分为四种可能的场景(如图4所示),具体如下:
设表示输入, 表示与任务相关的标签,即~ 。标准的有监督多任务学习可以表示为:
训练集:
测试集:
其中表示输入和输出之间的连接。训练集和测试集是没有交集的数据点。为了简便起见,后续将用表示
在关系学习场景中,除了输入之外,我们假设在进行预测时还可以访问辅助任务标签。 和表示是与任务子集相关的整数的划分。具体来说,是指输入可以访问的任务索引, 是我们希望预测的任务索引,这两个集合是不重叠的,即 ,并且它们根据输入的不同而不同,即任务子集和可以不同。那么此时输入-输出对的形式为:
训练集:
训练集:
在元学习场景中,我们想要学习如何在测试时预测未知任务。形式上,设 , 表示可见任务(在训练时使用)和不可见任务(在测试时使用)的分区,其中;我们可以访问一批带有标签的样本作为支持集,而学习的目标是可以对查询集中的样本进行正确预测。
训练集:给定
预测
测试集:给定
预测
关系型元学习场景结合了关系型场景和元学习场景的特点。与元学习场景类似,我们的目标是预测测试时看不见的任务 ;与此同时,与关系型学习场景类似,我们还假设此时存在数量有限个带有标签数的辅助任务 。形式上,此时具有包括有一个支持集和查询集:
训练集:
给定
预测
测试集:
给定
预测
四、MetaLink模型框架
接下来,我们将描述MetaLink框架,该框架允许我们利用单个框架设置上述四种多任务学习场景。特别地,MetaLink将它们表述为异构知识图上的链路标签预测任务,这样,MetaLink就可以有效利用数据和任务的关系信息。
我们首先回顾一下神经网络的一般公式。给定数据点及其标签 ,一个深度学习模型可以表述为参数化的嵌入函数(可以是深度神经网络)和任务头,任务头只包含一个权重矩阵,而 将数据点映射到向量嵌入空间,而任务头负责将嵌入映射到预测值(即本文主要研究一维回归或者二元分类问题), 则,若任务头 包含多层转换时,我们有,其中可以是一个任意可微函数。在多任务学习环境中,研究者通常将一个神经网络分配多个任务头。假设我们有个任务,那么就会有 个任务头,使得 。
在这里我们观察到任务头的权重和数据点的特征嵌入在多任务预测任务中是对称的(由于点积的存在)。因此,与以往将权重 作为神经网络中的参数所不同,我们提出将权重作为另一种类型的输入来支持预测任务。从本质上讲,我们重新构造了一个任务头,将其从变为了 。在这个新的视角中,任务权重 和数据的特征嵌入 都被视为输入,这使我们能够构建一个更复杂的预测模型 ,其中包含两个主要步骤,即和 。一般来说, 具有与相似的模型复杂性, 而则为模型提供了额外的表达能力。
在MetaLink框架中,我们提出了一个基于任务权重 和特征嵌入的知识图。通过构建这个知识图,我们可以简洁地表示数据点和任务之间的关系,以及不同的多任务学习场景。具体地说,知识图帮助我们轻松地表达任何数据-任务关系(例如,一个数据点在给定的任务上有一个标签)、数据-数据关系(例如,两个数据点相似与否)或任务-任务关系(例如,不同任务的层次结构)。此外,知识图极大地简化了我们在第3节中概述的所有多任务学习场景;事实上,所有的场景都可以在不同知识图构建方法上的链路标签预测任务,具体如图4所示。
我们定义知识图为,其中为节点集, 为边集。我们定义了两种类型的节点:数据节点,任务节点 。然后,我们可以定义数据和任务节点之间的边为,在数据节点之间的连边为,任务节点之间的连边为。MetaLink框架可以处理所有三种类型的边缘;但是由于大多数基准数据集没有关于数据-数据或任务-任务之间的关系信息,因此我们将在重点讨论数据-任务关系 。具体来说,我们基于任务标签定义 ,即如果标签存在,我们将数据节点连接到任务节点。
给出构建的知识图的方法后,接下来我们来讨论MetaLink如何从构建的知识图中进行学习。
初始化节点/连边的特征
首先,我们初始化知识图的特征。具体来说,我们将数据节点特征初始化为由特征提取器计算出来的数据嵌入向量 ,即 。同时将已知的任务节点的特征初始化为任务头的权重 ,即 。在元学习场景下,未知的任务节点将会在测试时出现,因此我们将用常向量来初始化这些节点;由于节点的节点特征的构建过程是归纳式的,因此这样的初始化方法可以有效保证MetaLink可以泛化到未知任务。
通过异构图神经网络来进行预测
我们利用图神经网络(GNN)来实现构建在数据节点和任务节点上预测模型。GNN的目标是学习基于局部网络邻域迭代聚合的节点嵌入向量 而的第 次迭代,或者说第 层,可以写成:
其中为迭代第次后的节点嵌入, 为按照初始化步骤所描述的节点的初始化特征, 表示节点 的直接邻居,为聚合函数的缩写, 为消息函数的缩写。我们在已经构建的知识图之上执行个GNN层。在更新数据和任务节点嵌入后,我们可以通过对给定任务进行预测,其形式为:
一般来说,MetaLink可以使用任何符合公式1定义的GNN架构。我们在MetaLink中使用GraphSAGE层(Hamilton et al., 2017) (其中, 和为可以训练的参数):
接下来,我们讨论了在MetaLink中所使用的特殊GNN设计,该设计已被证明是成功的。
MetaLink中的特殊GNN设计细节
我们对方程1中的公式作了三个扩展。首先,由于在我们的知识图中有两种类型的节点,我们为不同的消息类型定义了不同的消息传递函数,即从数据节点到任务节点的消息,以及从任务节点到数据节点的消息。其次,我们在消息计算中加入边的特征。这些设计对我们的表述尤其重要,因为任务标签值被包括在边得特征中,并且应该在GNN消息传递过程中被考虑。具体地,我们将式2扩展为:
其中表示消息类型(即从任务到数据,还是从数据到任务),是一个额外的可训练权值,允许任务标签参与消息传递。最后,我们让每个GNN层进行预测,并将其相加作为最终预测;这种方法从节点邻居不同跳的混合信息中得到最终的预测结果。我们观察到这种多层集成技术可以帮助MetaLink做出稳健的预测。
在这里,我们详细描述了如何将MetaLink应用于算法1中的关系元学习场景。在训练时,由于大多数现有的多标签数据集不是为元学习场景设计的,我们通过采样一个支持集和查询集。我们确保抽样的元任务和辅助任务没有交集。首先利用特征提取器进行数据嵌入,并利用该嵌入对数据节点进行初始化。要初始化任务节点,我们可以:(1)如果任务是元学习任务,则使用常向量来初始化;若不是元学习任务,则使用训练过的权重。最后我们基于支持集上的标签值和查询集上的标签值来构建数据-任务对之间的连边,基于以上步骤现在我们构建了知识图的所有组成部分,就可以应用预测图模型来学习表达数据和任务节点嵌入并进行预测。
在测试时,我们使用与训练时相同的流程来构建知识图并运行推理。有关关系型学习或元学习场景的流程,请参考附录A。
图5:MetaLink在关系型元学习下的训练流程
在这里,我们通过实验证明了我们提出的MetaLink可以灵活地处理不同的学习场景,并且可以有效利用辅助任务中的知识。我们首先评估了我们在Tox21 (Huang et al., 2016)、Sider(Kuhn et al., 2016)、ToxCast(Richard et al., 2016)和MS-COCO (Lin et al. 2014)数据集上的算法,这些数据集对关系多任务学习具有各种可控设置。为了进一步证明MetaLink的优势,我们还包括一个经过充分研究的任务:少样本学习的实验。我们的核心算法是使用PyTorch开发的(Paszke et al., 2017)。我们每个实验使用一个NVIDIA RTX 8000 GPU,最耗时的一个(MS-COCO)需要不到24小时。
数据集描述
我们使用四个广泛使用的多标签数据集模拟四个关系多任务学习场景(图2),比如Tox21(Huang et al., 2016)包含12个不同的毒理学实验,每个样本都有二元标签(活性/非活性)。Sider(Kuhn et al., 2016)是一个上市药物和药物不良反应(ADR)的数据库,分为27个任务。ToxCast(Richard et al., 2016)包含约8K对分子图和对应的617维二进制向量,代表不同的实验结果。Microsoft COCO (Common Objects in Context) (Lin et al.,2014)原本是一个大规模的对象检测、分割数据集。通过计算每个类型的对象是否存在于一个场景中作为一个单一的任务,它也作为默认的大规模数据集,用于在视觉中对标多标签分类。有80个二分类任务,平均每张图像有2.9个正标签。
实验设置
我们在图2中描述的所有四个学习场景上评估MetaLink,并将详细的场景总结如下。
有监督多任务学习场景:为了进行公平的比较,我们在关系设置中对带有未知标签的同一组任务进行评估。
关系型场景:我们假设每个样本可以访问每个数据集中20%的所有任务的标签。我们评估剩下的带有未知标签的任务。
元学习场景:训练时我们使用20%的任务,只在测试时评估hold out任务。我们使用256-shot设置,这意味着在测试时,我们使用256个数据点作为支持集来初始化不可见任务的原型。之所以shot的数量要比通常使用的少样本学习(1-shot,5-shot)多得多,是因为在某些任务中,积极的标签有时是稀疏的。
关系元学习场景: 训练时我们使用20%的任务,与元学习场景相同。在测试时,我们假设每个不可见任务可以访问可见任务的20%的标签。
基准线
尽管我们工作的主要动机是利用辅助任务标签,但我们仍然在标准有监督学习场景时包含了一些基线,以便对最新的结果进行基准测试。最简单的方法是
(1)经验风险最小化(Empirical risk minimization, ERM):在标准监督设置下训练具有交叉熵损失的网络;
(2)为分子设计的各种常用的图神经网络架构,MPNN (Gilmer et al.,2017)、DMPNN (Yang et al.,2019)、MGCN (Lu et al.,2019)、AttentiveFP (Xiong et al.,2019);
(3) GROVER (Rong et al.,2020)将消息传递网络集成到变压器式架构中。通过利用预训练,它在上述分子数据集上获得了最先进的结果;
(4) Baseline++ (Chen et al., 2018):由于之前没有针对多任务学习的寻址元设置的工作,我们将Baseline++从少镜头学习调整到该设置。我们首先在训练集中训练一个特征提取器。在测试时,我们使用支持集训练一个线性分类层。
生化数据集上的实验结果
MS-COCO数据集上的实验结果
MetaLink真的学会利用任务之间的相关性了吗?
图6: Sider数据集上27个任务之间的Pearon相关热力图
为了更好地理解我们算法的改进,我们首先在Sider上绘制皮尔逊相关热图(图6)。我们可以找到与其他任务相比平均相关性最高和最低的前3个任务。我们分别报告了MetaLink在这两个任务子集上的性能(表3)。我们观察到,MetaLink在相关性较高的任务上表现出更大的改进。这个实验验证了MetaLink能够像预期的那样学习利用任务之间的相关性。
MetaLink在使用不同比例的辅助任务标签时表现如何?
我们改变每个测试点附加标签的比例,并在下图中报告结果。当我们逐渐增加每个数据集中辅助标签的数量时,我们观察到一致的改进。实验表明,只要在数据集中添加辅助任务标签,MetaLink就能成功地利用边缘信息。对于最初添加的几个标签,改进通常是显著的。
由于关系多任务场景很新颖,所以我们可以比较的基线非常少。本节的主要目的是展示MetaLink在一个经过充分研究的问题中的优势:少样本学习。注意,上面的元学习场景和这里的少样本学习场景之间有细微的区别。对于少样本学习,对于所有与一个输入相关的链接标签预测任务,将只有一个正链接。通过使用交叉熵损失对所有链接标签预测进行建模,可以很容易地合并这种归纳偏差。
实验场景设置
我们在两个标准基准上评估性能:miniImageNet (Vinyals et al., 2016)和分级tiered-ImageNet (Ren et al., 2018)。我们与MatchNet (Vinyals et al., 2016)、Baseline++ (Chen et al., 2018)、MetaOptNet(Lee et al., 2019)和Meta-Baseline (Chen et al., 2020b)进行比较,它们只假设输入是一个矢量。
结果
下表显示了我们的MetaLink优于标准的少镜头学习基准测试。注意,如果我们设置KG Layer = 0,建议的MetaLink退化为元学习场景。实验清楚地证明了在最后一层构建知识图的好处。此外,作为一项消融研究,我们操纵了KG层的数量,发现在少样本图像识别场景中,叠加2 KG层比1 KG层有改进,这意味着非线性是有用的。我们没有观察到超过3层的进一步改进。
多任务学习是一种共同优化一组具有共享参数的任务的学习范式。人们普遍认为,不同任务之间的关系可以提高整体表现。一些工作将其视为一个多目标优化问题,并引入了多种基于梯度的方法来减少任务之间的负迁移(Fliege & Vaz, 2016; Lin et al., 2019)。其他工作使用特定的启发式为不同的任务分配(自适应)权重(Kendall et al., 2018; Chen et al., 2018)。我们的实证研究也与多标签学习密切相关,在多标签学习中,问题通常被分解为多个二元分类任务(Tsoumakas & Katakis, 2007)。有一个学习利用任务之间关系的工作线(Haller et al., 2021; Zamir et al., 2020). Wang et al. (2016)利用递归神经网络将标签转换为嵌入的标签向量,学习标签之间的相关性。最新的工作是ML-GCN (Chen et al., 2019),它使用GCN将标签图映射为一组相互依赖的分类器。虽然我们工作的主要动机也是关于利用任务之间的相关性,但我们的问题表述是新的,因此产生了新的算法。
之前有研究研究数据点或任务的图结构,图结构已被证明在某些任务中是有效的。Satorras & Estrach (2018)探索了仅用于少数镜头学习的数据点上的图神经表示。此外,一些作品研究了如何通过仅在任务节点/分类器上构建图来在任务之间传递知识(Liu et al., 2019; Chen et al., 2020a)。沿着这个方向,最近的工作不是构建一个完全连通的图,而是利用辅助任务结构/知识图来构建图(Chen et al., 2019; Lee et al., 2018)。相比之下,并且与上面的论文正交,我们将重点放在对最后一层进行重新解释的数据-任务关系建模上。此外,仅使用数据-任务关系,我们仍然能够通过高阶消息传递隐式地捕获数据-数据、任务-任务关系。
我们介绍了关系多任务场景,在这种场景中,方法需要学会利用辅助任务上的标签来预测新任务。这些场景在生物医学领域很有影响,因为不同任务的标签通常很少可用。为了解决这些场景下的学习任务,我们建议使用MetaLink,它足够通用,允许我们在单个框架中制定上述各种场景。我们证明了MetaLink可以成功地利用任务之间的关系,在提出的关系多任务学习场景下优于最先进的方法,ROC AUC提高了27%。我们将重点限制在建模数据-任务关系上,因为大多数基准数据集不包含数据-数据或任务-任务关系的信息,尽管MetaLink具有足够的表达能力来建模这种关系。我们将MetaLink的扩展留给以后更复杂的关系或任务。