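In this case study, we use real-world data from a 401(k) analysis to show how a causal inference library can be used to estimate the average treatment effect (ATE) and the conditional ATE (CATE).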
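For reference, the two estimands can be written as below. The notation is an assumption introduced here and is not taken from the original text: $T$ is the binary treatment (`e401`), $Y$ is the outcome (`net_tfa`), and $X$ is the variable the effect is conditioned on (later in this post, the income group).

```latex
\begin{align}
\mathrm{ATE} &= \mathbb{E}\left[Y \mid do(T = 1)\right] - \mathbb{E}\left[Y \mid do(T = 0)\right] \\
\mathrm{CATE}(x) &= \mathbb{E}\left[Y \mid do(T = 1),\, X = x\right] - \mathbb{E}\left[Y \mid do(T = 0),\, X = x\right]
\end{align}
```

Read this way, the ATE is one number for the whole population, while the CATE is one number per subgroup $x$.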
Variable Name | Type | Details |
e401 | Treatment | eligibility for the 401(k) plan |
net_tfa | Outcome | net financial assets (in USD) |
age | Covariate | Age |
inc | Covariate | income (in USD) |
fsize | Covariate | family size |
educ | Covariate | education (in years) |
male | Covariate | male? |
db | Covariate | defined benefit pension |
marr | Covariate | married? |
twoearn | Covariate | two earners |
pira | Covariate | participation in IRA |
hown | Covariate | home owner? |
hval | Covariate | home value (in USD) |
hequity | Covariate | home equity (in USD) |
hmort | Covariate | home mortgage (in USD) |
nohs | Covariate | no high-school? (one-hot encoded) |
hs | Covariate | high-school? (one-hot encoded) |
smcol | Covariate | some-college? (one-hot encoded) |
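The dataset is publicly available online from the `hdm` R package (https://rdrr.io/cran/hdm/man/pension.html). To make experimenting easier, I downloaded the data to a local file and work from that copy. For details on how the data was obtained, see my blog post "Tutorial: obtaining the hdm data with R".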
import pandas as pd

df = pd.read_csv("data/pension.csv")
df.head()
 | ira | a401 | hval | hmort | hequity | nifa | net_nifa | tfa | net_tfa | tfa_he | tw | age | inc | fsize | educ | db | marr | male | twoearn | dum91 | e401 | p401 | pira | nohs | hs | smcol | col | icat | ecat | zhat | net_n401 | hown | i1 | i2 | i3 | i4 | i5 | i6 | i7 | a1 | a2 | a3 | a4 | a5 |
0 | 0 | 0 | 69000 | 60150 | 8850 | 100 | -3300 | 100 | -3300 | 5550 | 53550 | 31 | 28146 | 5 | 12 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 3 | 2 | 0.273178 | -3300 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
1 | 0 | 0 | 78000 | 20000 | 58000 | 61010 | 61010 | 61010 | 61010 | 119010 | 124635 | 52 | 32634 | 5 | 16 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 4 | 4 | 0.386641 | 61010 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
2 | 1800 | 0 | 200000 | 15900 | 184100 | 7549 | 7049 | 9349 | 8849 | 192949 | 192949 | 50 | 52206 | 3 | 11 | 0 | 1 | 1 | 1 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 6 | 1 | 0.533650 | 8849 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
3 | 0 | 0 | 0 | 0 | 0 | 2487 | -6013 | 2487 | -6013 | -6013 | -513 | 28 | 45252 | 4 | 15 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 5 | 3 | 0.324319 | -6013 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
4 | 0 | 0 | 300000 | 90000 | 210000 | 10625 | -2375 | 10625 | -2375 | 207625 | 212087 | 42 | 33126 | 3 | 12 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 4 | 2 | 0.602807 | -2375 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
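Because the CSV is a local export, it can help to confirm that every column used later is actually present before building the model. The sketch below is one way to do that. `missing_columns` is a hypothetical helper (not part of pandas or DoWhy), and the tiny `toy` frame only stands in for the real `df`.

```python
import pandas as pd

# Columns the case study relies on (taken from the variable table above).
REQUIRED_COLUMNS = ["e401", "net_tfa", "age", "inc", "fsize", "educ", "male", "db",
                    "marr", "twoearn", "pira", "hown", "hval", "hequity", "hmort",
                    "nohs", "hs", "smcol"]


def missing_columns(frame, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from `frame` (hypothetical helper)."""
    return [name for name in required if name not in frame.columns]


# Tiny stand-in frame; with the real data, pass the `df` loaded from the CSV instead.
toy = pd.DataFrame({"e401": [0, 1], "net_tfa": [-3300, 61010], "inc": [28146, 32634]})
print(missing_columns(toy, required=["e401", "net_tfa", "inc", "age"]))
```

With the real data, `missing_columns(df)` should return an empty list; anything else points to a problem with the export.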
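import networkx as nx
import dowhy.gcm as gcm

treatment_var = "e401"
outcome_var = "net_tfa"
# The 16 covariates listed in the variable table above.
covariates = ["age","inc","fsize","educ","male","db",
              "marr","twoearn","pira","hown","hval",
              "hequity","hmort","nohs","hs","smcol"]

# Treatment -> outcome, plus every covariate as a common cause of both.
edges = [(treatment_var, outcome_var)]
edges.extend([(covariate, treatment_var) for covariate in covariates])
edges.extend([(covariate, outcome_var) for covariate in covariates])

causal_graph = nx.DiGraph(edges)
gcm.util.plot(causal_graph, figure_size=[20, 20])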
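The edge list above encodes one assumption: every covariate is a common cause of both eligibility and net financial assets. The sketch below makes that explicit in plain Python, without networkx, by listing the direct causes of each node. `parents` is a hypothetical helper, and only two covariates are used to keep the example short.

```python
treatment_var = "e401"
outcome_var = "net_tfa"
# Two covariates only, to keep the illustration short.
covariates = ["age", "inc"]

edges = [(treatment_var, outcome_var)]
edges.extend((c, treatment_var) for c in covariates)
edges.extend((c, outcome_var) for c in covariates)


def parents(node, edge_list):
    """Return the sorted direct causes of `node` (hypothetical helper)."""
    return sorted(src for src, dst in edge_list if dst == node)


# A confounder here is a variable that points into both treatment and outcome.
confounders = sorted(set(parents(treatment_var, edges)) & set(parents(outcome_var, edges)))
print(parents(treatment_var, edges))
print(parents(outcome_var, edges))
print(confounders)
```

In this graph, the confounders are exactly the covariates, which is why they all need to be accounted for when estimating the effect of `e401`.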
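import matplotlib.pyplot as plt

cols = [treatment_var, outcome_var]
# The loop body is missing from the source; one histogram per variable is assumed here.
for i, col in enumerate(cols):
    plt.subplot(1, len(cols), i + 1)
    plt.hist(df[col])
    plt.xlabel(col)
plt.show()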
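causal_model = gcm.StructuralCausalModel(causal_graph)
# Binary treatment: classifier-based mechanism.
causal_model.set_causal_mechanism(treatment_var, gcm.ClassifierFCM(gcm.ml.create_random_forest_classifier()))
# Continuous outcome: regression function plus additive noise.
causal_model.set_causal_mechanism(outcome_var, gcm.AdditiveNoiseModel(gcm.ml.create_random_forest_regressor()))
# Covariates are root nodes, so they are sampled from their empirical distribution.
for covariate in covariates:
    causal_model.set_causal_mechanism(covariate, gcm.EmpiricalDistribution())

# Cast the treatment to str so that it is handled as a categorical variable.
df = df.astype({treatment_var: str})
gcm.fit(causal_model, df)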
Fitting causal mechanism of node smcol: 100%|██████████| 18/18 [00:06<00:00, 2.68it/s]
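To see why a fitted additive noise model is useful here, consider the toy example below. It is a sketch under strong assumptions, not a description of DoWhy's internals: the data is synthetic, the true effect is fixed at 2.0, and ordinary least squares stands in for the random forest regressor. The naive group comparison is biased by the confounder, while regenerating the outcome under a coin-flip treatment gives an estimate close to the true effect.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
x = rng.normal(size=n)                               # confounder
t = (rng.normal(size=n) + x > 0).astype(float)       # treatment depends on x
y = 2.0 * t + 3.0 * x + rng.normal(size=n)           # true effect of t is 2.0

# Naive comparison mixes the effect of t with the effect of x.
naive = y[t == 1].mean() - y[t == 0].mean()

# Fit y = f(t, x) + noise with least squares (stand-in for the random forest).
design = np.column_stack([t, x, np.ones(n)])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
noise = y - design @ coef

# Intervene: assign treatment by coin flip, keep x and the noise, regenerate y.
t_new = rng.integers(0, 2, size=n).astype(float)
y_new = np.column_stack([t_new, x, np.ones(n)]) @ coef + noise
ate = y_new[t_new == 1].mean() - y_new[t_new == 0].mean()
print(round(naive, 2), round(ate, 2))
```

The randomized treatment breaks the link between `x` and `t`, so the difference in means is no longer contaminated by the confounder.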
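Before computing the CATE, we first divide households into bins of equal width in income percentile (five bins of 20 percentage points each). This lets us study the effect on different income groups.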
import numpy as np
percentages = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
bin_edges = [0]
bin_edges.extend(np.quantile(df.inc, percentages[1:]).tolist())
bin_edges[-1] += 1 # adding 1 to the last edge as last edge is excluded by np.digitize
groups = [f'{percentages[i]*100:.0f}%-{percentages[i+1]*100:.0f}%' for i in range(len(percentages)-1)]
group_index_to_group_label = dict(zip(range(1, len(bin_edges)+1), groups))
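The bins have equal width in percentile, not in dollars. The sketch below repeats the same binning steps on synthetic, right-skewed incomes (an assumption made only for illustration; the real analysis uses `df.inc`). Each bin receives roughly 20% of the households even though the income ranges differ widely.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic, right-skewed incomes; the real analysis uses df.inc instead.
income = rng.lognormal(mean=10.0, sigma=0.8, size=10_000)

percentages = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
bin_edges = [0]
bin_edges.extend(np.quantile(income, percentages[1:]).tolist())
bin_edges[-1] += 1  # the right edge is exclusive, so nudge it past the maximum

group_index = np.digitize(income, bin_edges)               # values 1..5
shares = np.bincount(group_index, minlength=6)[1:] / income.size
widths = np.diff(bin_edges)
print(shares)   # share of households per bin
print(widths)   # income range covered by each bin
```

Without the `+= 1` adjustment, the household with the highest income would fall into a sixth group that has no label.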
def estimate_cate():
    # Randomize eligibility for every household, then compare outcomes.
    # The last argument is cut off in the source; observed_data=df is assumed here.
    samples = gcm.interventional_samples(causal_model,
                                         {treatment_var: lambda x: np.random.choice(['0', '1'])},
                                         observed_data=df)
    eligible = samples[treatment_var] == '1'
    ate = samples[eligible][outcome_var].mean() - samples[~eligible][outcome_var].mean()
    result = dict(ate = ate)

    # Assign each household to its income group.
    group_indices = np.digitize(samples['inc'], bin_edges)
    samples['group_index'] = group_indices

    # CATE: the same difference in means, computed within each income group.
    for group_index in group_index_to_group_label:
        group_samples = samples[samples['group_index'] == group_index]
        eligible_in_group = group_samples[treatment_var] == '1'
        cate = group_samples[eligible_in_group][outcome_var].mean() - group_samples[~eligible_in_group][outcome_var].mean()
        result[group_index_to_group_label[group_index]] = cate

    return result

group_to_median, group_to_ci = gcm.confidence_intervals(estimate_cate, num_bootstrap_resamples=100)
print(group_to_median)
print(group_to_ci)
{'ate': 6519.046476486404, '0%-20%': 3985.972442541254, '20%-40%': 3109.9999288096888, '40%-60%': 5731.625707624532, '60%-80%': 7605.467796966453, '80%-100%': 11995.55917989574}
{'ate': array([4982.99412698, 8339.97497725]), '0%-20%': array([2630.16909916, 5676.94495668]), '20%-40%': array([1252.7312225 , 5215.15452742]), '40%-60%': array([3533.43542901, 8243.86661569]), '60%-80%': array([ 4726.56666574, 10603.23313684]), '80%-100%': array([ 4981.36999637, 19280.14639468])}
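The two dictionaries above hold a median estimate and an interval for each quantity. The sketch below shows the general idea behind such a summary: run the estimator many times, then take the median and a percentile interval. It is a simplified illustration, not DoWhy's implementation. `median_and_interval` and `toy_estimator` are hypothetical names, and a 95% level is assumed.

```python
import numpy as np

rng = np.random.default_rng(0)


def median_and_interval(estimator, num_runs=100, confidence_level=0.95):
    """Call `estimator` repeatedly; return the median and a percentile interval.

    Hypothetical helper for illustration only.
    """
    estimates = np.array([estimator() for _ in range(num_runs)])
    alpha = (1.0 - confidence_level) / 2.0
    lower, upper = np.quantile(estimates, [alpha, 1.0 - alpha])
    return np.median(estimates), (lower, upper)


def toy_estimator():
    """Stand-in estimator: difference in means on freshly drawn noisy data."""
    treated = rng.normal(loc=5.0, scale=2.0, size=500)
    control = rng.normal(loc=3.0, scale=2.0, size=500)
    return treated.mean() - control.mean()


median, (lower, upper) = median_and_interval(toy_estimator)
print(median, lower, upper)
```

An interval that lies entirely above zero, as in this toy case, is what supports calling an effect positive.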
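As the confidence interval [4982.99, 8339.97] shows, the average treatment effect of 401(k) eligibility on net financial assets is positive. Now let's plot the CATEs for the different income groups to get a clearer picture.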
fig = plt.figure(figsize=(8,4))
for x, group in enumerate(groups):
    ci = group_to_ci[group]
    # One vertical bar per income group, spanning its confidence interval.
    plt.plot((x, x), (ci[0], ci[1]), 'o-', color='orange')
ax = fig.axes[0]
plt.xticks(range(len(groups)), groups)
plt.xlabel('Income group')
plt.ylabel('ATE of 401(k) eligibility on net financial assets')
plt.show()