2019-12-30 13:04:12
人工智能顶会 ICLR 2020 将于明年 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行,不久之前,大会官方公布论文接收结果:在最终提交的 2594 篇论文中,有 687 篇被接收,接收率为 26.5%。本文介绍了华为诺亚方舟实验室被 ICLR 2020 接收的一篇满分论文。
论文地址:https://arxiv.org/pdf/1906.04477.pdf
因果研究作为下一个潜在的热点,已经吸引了机器学习/深度学习领域的的广泛关注,例如 Youshua Bengio 和 Fei-Fei Li 近期都有相关的工作。因果研究中一个经典的问题是「因果发现」问题——从被动可观测的数据中发现潜在的因果图结构。
在此论文中,华为诺亚方舟实验室因果研究团队将强化学习应用到打分法的因果发现算法中,通过基于自注意力机制的 encoder-decoder 神经网络模型探索数据之间的关系,结合因果结构的条件,并使用策略梯度的强化学习算法对神经网络参数进行训练,最终得到因果图结构。在学术界常用的一些数据模型中,该方法在中等规模的图上的表现优于其他方法,包括传统的因果发现算法和近期的基于梯度的算法。同时该方法非常灵活,可以和任意的打分函数结合使用。
模型定义和问题
我们假设以下常用的数据生成模型:给定一个有向无环图(DAG),每个节点对应一个随机变量,每个变量的观测值是图中父亲变量的函数加上一个独立的噪声,即
这里噪声 n_i 是联合独立的。如果所有的函数都是线性的且噪声是高斯的,则上述模型为标准的线性高斯模型。当函数为线性但噪声为非高斯函数时,上述模型为线性非高斯加性模型(LiNGAM),在一定的条件下是可以识别出真实的 DAG。
我们目前考虑所有的变量都是一维的实变量;给定一个合适的打分函数则可以直接扩展到多维变量的情形。在固定的函数和噪声分布下,我们的观测数据是根据上述模型在某个未知的 DAG 上独立采样得到。因果发现的目的就是使用这些观测的数据来推断真实的因果 DAG。
背景介绍
打分法是因果发现算法中一类常用的方法:给每个有向图打分(通常基于观测数据计算得到),然后在所有的 DAG 中进行搜索取得最好分数的 DAG:
尽管有很多已经深入研究的打分函数,例如基于线性高斯模型的 BIC/MDL 和 BGe 分数,但上述问题通常是 NP-hard 的,因为 DAG 条件是一个组合问题,并且可能的 DAG 数量的随着图节点的个数增加而超指数增加。为了解决这个问题,大多数已有方法都依赖于局部启发式算法。
例如,贪婪等价搜索(GES)在添加一条边时显式检查 DAG 约束是否满足。GES 在适当的假设和极限数据量的情况下可以找到具全局最优值,但在有限样本的情况下无法得到保证。
最近,也有工作在线性数据模型上对上述的无环条件提出了一个等价的可微分函数,再选择适当的损失函数(例如最小二乘损失),上述问题可以转换为关于带权值的邻接矩阵的连续优化问题。后续的工作也采用 ELBO 和 negative log-likelihood 作为损失函数,并使用神经网络对因果关系进行建模。但是很多已有的得分函数没有显式的表示或者是非常复杂的等价损失函数,这样和上述连续的方法结合会比较困难。
基于强化学习的因果发现算法
我们提出一种基于 RL 的方法来搜索 DAG,整体框架图如下所示。基于随机策略的 RL 可以在给定策略的不确定性信息的情况下自动确定要搜索的位置,同时可以通过奖励信号来及时更新。在合成数据集和真实数据集上的实验表明,基于强化学习的方法大大提高了搜索能力,并且不会影响打分函数的选择。
基于自注意力机制的 Encoder-Decoder 模型
如上图所示,我们采用 Transfomer 中基于自注意机制的 encoder,而 decoder 则是通过建立成对的 encoder 输出之间的关系来生成图的邻接矩阵。为了得到 0-1 的邻接矩阵,我们将每个 decoder 的输出通过 logistic-sigmoid 函数,然后使用 Bernoulli 分布进行采样。
我们也尝试了其他的 decoder,例如 bilinear model 以及 Transformer 中的 decoder。我们实验发现上图中 decoder 的效果最好,可能是因为它的参数量比较少、更容易训练来找到更好的 DAG,而基于自注意力机制的 encoder 已经提供了足够的交互来探索数据之间的因果关系。
Reward
传统的 GES 会在每次添加一条边时显式的检查图是否有环,我们使用打分函数和基于有环性质的惩罚项来设计 reward,并允许生成的图在每次迭代中变化多条边。具体的形式如下:
其中第一项是得分函数,用于衡量给定有向图和观测数据的匹配程度,其他两个正项则衡量某些「DAGness」(给定的有向图距无环的某种度量,例如所有环上的长度之和),lambda_1 和 lamba_2 是惩罚项的权重。通过选择适当的惩罚权重,最大化 reward 等价于之前打分法的问题的形式。但是两个问题等价并不意味着使用 RL 来最大化 reward 就可以直接取得很好的结果:实际中,我们发现较大的惩罚权重可能会妨碍 RL 的探索,得到的因果图的得分通常比较差,而较小的惩罚值将导致有环的图。同时,不同的打分函数可能具有非常不同的范围,而两个惩罚项的值与打分函数是没有关系的。因此,我们将所有的打分函数调整到一定范围,并为惩罚权重设计一种在线更新策略。详细内容可以参见论文的第 5 章。
Actor-Critic 优化参数
我们采用策略梯度和随机优化的方法来优化以下目标:
其中 A 中有向图对应的 0-1 邻接矩阵。我们使用 Actor-Critic 来进行训练,同时还加了熵正则项来鼓励探索。尽管策略梯度方法仅在一定条件下能保证局部收敛,但是通过惩罚项系数的设计,在我们的实验中 RL 算法得到的图都是无环的。
最终输出
由于我们关注的是寻找得分最好的 DAG,而不是 policy,因此我们记录了训练过程中生成的所有的有向图,并选择具有最佳 reward 的图作为输出结果。实际上由于有限的数据,图中会包含一些真图里边不存在的边,因此需要进一步的减枝处理。
我们可以根据损失函数或者打分函数,使用贪婪方法来进行减枝操作。我们删除一个父亲变量并计算相应的结果,如果损失函数或者打分函数效果没有变差或者是在预先设定的范围内,就接受减枝的操作并继续下去。对于线性模型,可以通过和阈值比较的方法来进行减枝。
实验结果
在此工作中,我们使用 BIC 打分函数,并假设附加性的高斯噪声(实际中噪声可能是非高斯的)。考虑两种情况:不同的噪声方差,等价于 negative log-likelihood 加上一个对边的个数的惩罚项作为打分函数;以及相等的噪声方差,将得到最小平方损失加上边的个数的惩罚项。它们分别表示为 RL-BIC 和 RL-BIC2。
我们的方法与传统方法(PC,GES,ICA-LiNGAM 和 CAM)以及最近基于梯度的方法(NOTEARS,DAG-GNN 和 GraN-DAG)在学术界常用的一些数据集上进行了比较。我们使用三个指标评估学到的图结构:错误发现率(FDR),正确率(TPR)和结构汉明距离(SHD)。SHD 是将得到的图转换为真实 DAG 的边添加,删除和反转操作的最少个数。
高斯和非高斯噪声的线性数据模型
我们首先考虑 12 个节点的有向图。图 2 显示了在一个线性高斯数据集上 RL-BIC2 的训练过程。我们采用 NOTEARS 和 DAG-GNN 在同样的数据集上使用的阈值来做减枝。在这个例子中,RL-BIC2 在训练过程中生成 683,784 个不同的图,远低于 12 个节点 DAG 的总数(约 5.22 * 10^26)。经过减枝的 DAG 和真实的图结构完全相同。
图 2:在线性高斯数据集上 RL-BIC2 的学习过程。
表 1 是我们在 LiNGAM 和线性高斯数据模型的实验结果。在该实验中,RL-BIC2 在两个数据模型上恢复了所有真实的因果图,而 RL-BIC 的表现稍差。尽管如此,在相同的 BIC 分数下,RL-BIC 在两个数据集上的表现均远好于 GES。
具有高斯过程的非线性模型
我们考虑一种非线性的数据模型,每个因果关系函数是从高斯过程中采样的一个函数。该问题被证明是可识别的,即可以从联合概率分布中识别出真实的图。我们使用和 GraN-DAG 一样的实验条件:10 个节点,40 条边的 DAG,并考虑 1000 个观测样本。实验结果如下表 3 所示。对于我们的方法,我们将高斯过程回归(GPR)与 RBF 核一起使用来建立因果关系模型。虽然观察到的数据是来自于高斯过程采样得到的函数,但这并不能保证具有相同核的 GPR 可以达到很好的结果。实际上,使用固定的核参数将导致严重的过度拟合,从而导致许多错误的边,这样训练结束最好 reward 对应的有向图通常不是 DAG。为此我们将数据归一化处理,并使用 median heuristics 来选择核参数。我们两种方法的表现都不错,其中 RL-BIC 的结果优于其他所有方法。
真实数据集
我们最后考虑 Sachs 数据集,通过蛋白质和磷脂的表达程度来发现蛋白质信号网络。我们将带有 RBF 内核的 GPR 应用于因果关系建模,对数据做归一化并使用基于 median heuristics 的核参数。我们使用和 CAM 及 Gran-DAG 中同样的减枝方法。实验结果见下表。与其他方法相比,RL-BIC 和 RL-BIC2 均取得了不错的结果。
结语
我们使用强化学习来搜索具有最佳分数的 DAG,其中 actor 是基于自注意力机制的 encoder-decoder 模型,而 reward 结合了预先给定的得分函数和两个惩罚项来得到无环图。在合成和真实数据集上,该方法均取得了很好的结果。在论文里,我们还展示了该方法在 30 节点的图上的效果,但是处理大规模的图(超过 50 个节点)仍然具有挑战性。尽管如此,许多实际的应用(例如 Sachs 数据集)的变量数都相对较少。此外,有可能将大的因果发现问题分解为较小的问题分别处理,基于先验知识或基于约束的方法也可以用来减少搜索空间。
当前的工作有几个未来改进的方向。在目前的实现中,打分函数的计算比训练神经网络会花费更多的时间,一个更有效率的打分函数将会大大提升目前算法的表现。其他 RL 算法也可以用来加速训练,例如 A3C。此外,我们观察到实验中使用的总迭代次数通常超过了需要的次数,我们也会研究如何进行 early stopping。