因果推断深度学习工具箱 - CounterFactual Regression with Importance Sampling Weights

文章名称

CounterFactual Regression with Importance Sampling Weights

核心要点

文章主要针对binary treatment的场景,能够用来估计CATE(当然也可以估计ATE)。作者基于CFR[1],提出利用上下文感知的重要性采样来取代CFR的固定权重,来平衡selection bias。相比于BNN和CFR利用频率统计得到的样本权重,文章提出的方法能够实现selection bias的平衡,弥补IPM loss较小平衡能力不足的问题。CFR-IS采用两阶段交替学习。首先,利用给定权重,训练类似BNN和CFR的loss。随后,通过最小化NLL得到更优的权重。

方法细节

问题引入

BNN和CFR主要利用IPM来平衡不同treatment下的分布差异,具体loss如下图所示。但是由于这种平衡是建立在的联合分布上的,的影响可能会被忽略,而且高维特征会导致有treatment引起的分布距离比较小,不能够提供足够的loss,来进行selection bias的平衡。

CFR loss

同时,BNN和CFR在构建factual loss(估计样本实际输出)的时候,采用了频率统计得到的权重,即图中的 ,其计算方法如下图所示。可以看出这个weight是一个频率统计值,本质是一个propensity score的倒数。
CFR weight

CFR weight(2)

而经过loss的改写,发现这部分权重的目标是平衡样本不均(参见引用[1]),并不能起到balancing当中的re-weigthing的作用。因此,总体作者认为对selection bias的矫正是不充分的。所以,提出利用重要性采样的方法来学习样本权重实现不同treatment下的covariates均衡(大家都是这条路,做法不同而已)。
CFR loss reformation

具体做法

因此,作者把两个不同的treatment下的分布,看做是两个不同分布的采样。为了对齐两个分布的学习效果,我们把counterfactual的covariates分布当做是目标分布,把实际观测到的样本分布当做采样分布。例如,当我们处理的数据是,的covariates分布就是采样分布,而是目标分布。

importance sampling

当控制住 之后,下图中因果图的后门被阻断(后门准则),那么 与 是独立的。
belif net

因此,得到不同treatment下 和 的联合分布的比值等于不同treatment下 的比值。这样我们构造了一个有covariates得到的隐向量 决定的重要性采样权重。
counterfactual IS

为了能够在观测数据上也表现得好(也就是预测好factual),作者在权重上加1,表示采样分布和目标分布是同一个。
weight

但是,我们发现直接估计这个weight不现实,因为是要估计一个隐向量在不同treatment下出现的概率的比值。无论是直接估计概率密度函数,还是用高斯建模概率的密度函数要么计算量大,要么假设太强,不准确。所以作者采用贝叶斯法则转化了weight的估计方式,如下图所示。其中, 表示propensity score,可以用LR或者神经网络得到。
weight reformation

propensity \pi

学习propensity的loss就是简单的NLL。作者采用交替优化CFR loss和propensity loss的方法进行学(也许可以一起学,类似Dragnnet)。
propensity loss

具体的网络结构如图所示,
network structure

代码实现

pseudo code

(留坑待填...)

心得体会

why IS work?

个人理解,IS就是把眼分布的数据用来换到目标分布来估计目标结果。这里weight是用在factual loss的那个部分,也就是说,我们假设样本可能来自counterfactual分布,在这种情况下还用观测结果作为事实来代表counterfactual的值,就需要用IS。并且IS之后,就可以把估计factual loss当做是在估计counterfactual loss。

add 1 to weight

在权重上+1,就把一个样本分成了两个。因为,。本质是表示如果这个样本实际就是从观测分布来的,那么就不需要加权,但需要被用来估计factual。

文章引用

[1] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.

你可能感兴趣的:(因果推断深度学习工具箱 - CounterFactual Regression with Importance Sampling Weights)