因果推断深度学习工具箱 - Perfect Match: A Simple Method for Learning Representations For Counterfactual Infe...

文章名称

Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks

核心要点

现有的深度学习的overly complex,作者通过propensity matching的方法,用目标样本其他treatment下的最近邻样本,构造训练的mini-batch,通过数据增广的方式来解决观测数据下因果推断的2个基本问题,1)缺失的反事实;2)混淆偏差。
比较大的优势是,这种方法不同于介绍过的文章,可以应用于multiple treatment。

方法细节

问题引入

文章建立在potential outcome框架下,并且需要满足unconfoundness的假设,即。为了需要估计因果效应,比较老的方法采用直接建模的方式,比如,也就是我们常说的single learner(如果两个带有下表就是T-learner)。这种建模方式的弊端是高维的特征会淹没低纬度的干预。
如果采用T-learner,不会存在干预被淹没的问题,也比较灵活,却引入了模型误差带来的因果效应估计的偏差,并且牺牲了统计效率,不能够充分利用样本。

具体做法

首先,作者扩展了TARNET,把two heads扩展成为multiple heads,模仿TARNET解决treatment在维度较高的时候,被淹没的情况。但是这个改进非常subtile[汗]。
其次,作者利用propensity score做balancing,构造虚拟的随机实验mini-batch。其实是利用最近邻matching的方法,做数据增广,期望在梯度回传的时候减少overfit,来解决由于混淆变量引起的训练样本分布不均,以及预测时分布迁移的问题。
同时,作者定义(拓展)了一些评价指标,首先,利用真实值和估计值,拓展了PEHE到,其中,在multiple treatment的时候,采用的是pairwise的平均值。这种指标需要我们知道真实的各种counterfactual,除非模拟数据,不然是不现实的。因此,模型选择的部分,作者也提出了基于NN的,

$\epsilon_{PEHE}$
metrics for multiple treatments

NN-PEHE

最后作者也证明了为什么这样的训练数据下,利用SGD能够得到causal effect的一致性估计。证明的核心逻辑是,利用各种因果效应可以被识别的假设,推导出我们是在做条件期望的极限。当N趋于无穷大时,极大概率会有一个样本是和当前样本特征一模一样,但treatment不一样的。我们可以利用这样的样本估计因果效应。个人觉得,建立在positive的假设下,这个证明应该是没问题的。。

proof of consistency

代码实现

文章中的伪代码,思路上还是比较直接的,每个mini-batch,利用propensity score寻找最近的样本,返回mini-batch。后续直接用改进的TARNET进行训练。


pseudo code

To be continued...

心得体会

model selection criteria

文章另外比较大的贡献是提供了一些模型评价指标,可以用来做模型选择,并且公开了可以用来验证multiple treatment下模型性能的基准数据集。虽然个人觉得,其实就是作者训练的思路,有点作弊的嫌疑。但是,还是对观测数据下的模型筛选,提供了一个思路(虽然这个思路,很在就有了,参见reference[1],但是作者详细定义了指标,也与非nn的指标进行了统一)。

nearest neighbor matching

构造mini-batch的时候,可以采用多种matching的方法,包括最近邻,k近邻等等,甚至不用propensity score作为balancing score,这些方法都可以从传统的balancing里借鉴,甚至结合一些其他的balancing weighting学习的方法(后续会介绍,比如利用adversarial training)。这种trick也许在工业界,能有不错的效果。
同时,这种方法和另外一些新兴的imputing的方法有异曲同工之妙。

matching in minibatch&efficient in heavy overlap region

个人理解PM是matching的一种minibatch版本。在样本特征分布重合度较高的地方,会被加强。因为特征分布重合度较高意味着对每一个样本,有充足的其他treatment下样本可以用来学习反事实。最极端的情况是,正好有特征完全重合的样本,可以用来估计该样本的causal effect。之前介绍的propensity dropout也是希望充分利用overlap度较高的样本训练模型,从这个角度说,两偏文章分别利用了两种深度学习技巧augmentation和dropout来解决因果推断的基本问题,简单直接好理解,角度也比较新颖。
另外,考虑到神经网络需要大量的样本进行训练,propensity dropout确实也可能存在作者所说的样本利用率欠缺的问题,考虑到神经网络需要大量的样本进行训练。其实也就是深度神经网络的训练技巧,数据增广方法的各种花样也许都可以用来结合一下构造样本。
作者也提到mini-batch的方法类似于minibatch sampling strategy,只不过是用在了causal inference的场景。这种mini-batch的方法优于整体做augmentation,因为,整体augmentation之后,还需要再采样mini-batch,相同covariates的样本可能并不会被分到同一个mini-batch,反而没有起到虚拟随机实验的模拟效果。

simple to use

PM方法确实非常简单直接,因为不需要改变网络结构、损失函数,并且没有添加任何额外的计算,所以理论上是可以和任何神经网络相关的causal inference方法组合的。但是,由于训练是改变了样本周边的分布,相当于加权了和当前样本相关的周边的别的treatment的样本,如果和其他调整样本分布的方法,比如re-weighting的方法一起使用时,需要考虑re-weighting的学习过程是否收到影响。

文章引用

[1] Kapelner, A., Bleich, J., Levine, A., Cohen, Z., DeRubeis, R., & Berk, R. (2021). Evaluating the Effectiveness of Personalized Medicine With Software. Frontiers in Big Data, 4.
[2] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.
[3] https://github.com/d909b/perfect_match/tree/master/perfect_match/models

你可能感兴趣的:(因果推断深度学习工具箱 - Perfect Match: A Simple Method for Learning Representations For Counterfactual Infe...)