因果推断深度学习工具箱 - Reducing Selection Bias in Counterfactual Reasoning for Individual Treatment Effect...

(代码实现的坑待填...,日更太难了...)

文章名称

Reducing Selection Bias in Counterfactual Reasoning for Individual Treatment Effects Estimation

核心要点

文章仍然关注binary treatment情境下的CATE估计。作者通过AE结合利用Pearson Correlation Coefficient的正则化,鼓励模型对covariates进行分解,从而学习两组不同的变量,一组和outcome的treatment assignment相关(group A),另一组与selection bias和outcome prediction都相关(group BC),最终用group BC来同时平衡selection bias并预测outcome。

方法细节

问题引入

文章来自NeurIPS 2019 CausalML Workshop。相比于通过balancing with representation learning,其实很多时候,我们把一些只影响potential outcome估计的covariates也当做是confounder来做adjustment,导致在学习样本平衡的时候存在噪声,因果效应的估计能力变差。从因果图的角度我们可以把confounder分为3类,第一类是只影响treatment assignment的;第二类是confounder,不但印象treatment assignment,也影响outcome;第三类则只影响outcome,具体因果图,如下图所示。作者期望把第一类和第二、三类covariates区分开,从而减少第一类covariates对potential outcome预测带来的噪声(因为我们不关心是不是哪些虽然影响策略分配,但完全不影响outcome的特征,他们不会带来偏差)。这个因果分解的思路最先出现在引用文章[1]里(后面会讲,其实这个思路还不完整,后续会介绍更完善的covariates分解的相关文章),不同的是这篇文章把BC合并在了一起,并且使用了不同的正则化方法 -- Pearson Correlation Coefficient。


covariates de-correlation

具体做法

实际的网络结构如图所示。首先,通过一个AE,学习样本表示,样本表示由两部分向量组成。随后,利用学到的传递给outcome预测网络,进行不同counterfactual的预测。不知道有没有同学有似曾相识的感觉。大概还是自监督学习还没有兴起的时候(约2018-19年),曾经流行用AE在大量的无标签样本上进行重构损失的训练,然后利用训练的得到的隐向量,也就是这里的,来辅助做downstream的无监督学习(表示学习)。这种类型无监督结合有监督的方法在NLP,CV都有使用,比如做文本分类。后来还延伸出了很多方法,诸如先做无监督主题模型,学到的主题向量做文本分类(扯远了,回到正题...)。本质是通过引入covariates de-correlation的辅助任务,来消除selection bias,只是这个辅助任务比其他的任务要聪明,因为不但纠正了偏差,同时减少了噪声,同时符合因果图的理念(后边会看到更精妙的,比如去掉无意间引入的collidor)。

Network Architecture,RSB-Net

然而,仅凭这样的网络,是不可能达到很好估计causal effect的效果的,不然不就没有causal什么事儿了... 回想,causal inference的两个主要问题,1)missing counterfactual;2)selection bias。这两个问题还是需要通过loss function来解决。方法的整体loss如下,其中, 是无监督表示学习的重构损失; 没啥好说的,是factual的估计损失(也就是观测数据预测的准不准); 是分布距离损失,用来度量不同treatment下covariates分布的差异性,这个在之前介绍BNN的那篇完章里有些(理论证明的坑还没有填上...,容证明再飞一会儿...);而 就是文章的核心要点Pearson Correlation Coefficient。
loss function

重构损失 ,是标准的 损失,度量covariates的重构能力,保证AE能够充分学习(这里也许可以采用其他的AE,当然已经有用VAE做的了)。
预估损失 ,是BNN中提到的加权 损失。
prediction loss

prediction weight

分布差异损失,也是BNN中的Integral Probability Metric Loss。

IPM loss

de-correlation损失函数,是利用两个不同向量组(A和BC)的皮尔逊相关系数作为损失函数,当这个损失达到最小的时候,两个向量组线性无关。其中,指的是向量中的第个元素。是指第个样本的隐向量表示,是所有样本的平均,其他同理。

PCC loss

代码实现

文章伪代码参见下图(实际代码的坑后续再填...)。


pseudo code

心得体会

unsupervised assassinated supervised learning

文章用到的类似无监督辅助有监督学习的思路,来帮助更准确的估计potential outcome。本质是寻找了更多的内在信息或结构,来引导potential outcome不要走偏(消除selection bias)。这个和自监督中寻找相关性的思路很吻合,也许自监督与causal inference结合的方法已经在路上了。

linear independent

文章虽然通过PCC让两个向量组A和BC线性无关,但是在现实世界里covariates之间的非线性关系是存在的,也是神经网络的优势之一。所以,这种损失的de-correlation性能可能比较有限。

文章引用

[1] Negar Hassanpour and Russell Greiner. Counterfactual regression with importance sampling weights. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, IJCAI-19, pages 5880–5887, 7 2019.
[2] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.

你可能感兴趣的:(因果推断深度学习工具箱 - Reducing Selection Bias in Counterfactual Reasoning for Individual Treatment Effect...)