
0x01. Background





0x02. Simulated data

We describe the SCM as follows. $f_{p1,p2}$ is the term added on top of the noise in the model: $Vision = N_V + f_{p1,p2}(Treatment, Condition)$. We sample the intrinsic noise terms $N_T, N_C, N_V$ of the three observed variables; the target variable Vision is then $N_V$ plus the contribution of its input nodes.

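$Treatment = N_T$ takes the value 0, 1 or 2 with equal probability (about 33% each): 33% of users do nothing, 33% choose option 1, and 33% choose option 2. This is independent of whether the patient has the rare condition.

$Condition = N_C \sim \mathrm{Bernoulli}(0.01)$: whether the patient has the rare condition.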

$$
\begin{aligned}
Vision &= N_V + f_{p1,p2}(Treatment, Condition) \\
&= N_V - P_1 (1 - Condition)(1 - Treatment)(2 - Treatment) \\
&\quad + 2 P_2 (1 - Condition)\, Treatment\, (2 - Treatment) \\
&\quad + P_2 (1 - Condition)(3 - Treatment)(1 - Treatment)\, Treatment \\
&\quad - 2 P_2\, Condition\, Treatment\, (2 - Treatment) \\
&\quad - P_2\, Condition\, (3 - Treatment)(1 - Treatment)\, Treatment
\end{aligned}
$$

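$P_1$ is a constant that controls how much the raw vision decreases when a patient does not have the rare condition and takes no medication.

$P_2$ is a constant that controls how much the raw vision increases or decreases, depending on whether the patient has the condition and which type of drops they use. More specifically: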

if Condition = 0 and Treatment = 1 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 2 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 1 then Vision = N_V - 2P_2

elif Condition = 1 and Treatment = 2 then Vision = N_V + 2P_2

elif Condition = 0 and Treatment = 0 then Vision = N_V - 2P_1

elif Condition = 1 and Treatment = 0 then Vision = N_V
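A short sketch that evaluates the closed-form expression $f_{p1,p2}$ for every (Condition, Treatment) pair makes it easy to see which offset is added to $N_V$ in each case. The helper name `f_p1_p2` is hypothetical, and the values P_1 = 0.2 and P_2 = 0.15 are assumed here to match the data-generation code in this post.

```python
from itertools import product

P_1 = 0.2   # assumed value, same as in the data-generation code
P_2 = 0.15  # assumed value, same as in the data-generation code


def f_p1_p2(treatment, condition, p1=P_1, p2=P_2):
    """Offset added to N_V, transcribed term by term from the Vision equation."""
    t, c = treatment, condition
    return (
        -p1 * (1 - c) * (1 - t) * (2 - t)
        + 2 * p2 * (1 - c) * t * (2 - t)
        + p2 * (1 - c) * (3 - t) * (1 - t) * t
        - 2 * p2 * c * t * (2 - t)
        - p2 * c * (3 - t) * (1 - t) * t
    )


# Evaluate the offset for all six (Condition, Treatment) combinations.
offsets = {(c, t): f_p1_p2(t, c) for c, t in product((0, 1), (0, 1, 2))}
for (c, t), offset in offsets.items():
    print(f"Condition={c}, Treatment={t}: Vision = N_V {offset:+.2f}")
```

With these constants the offsets for Treatment = 0, 1, 2 are -0.4, +0.3, -0.3 when Condition = 0, and 0, -0.3, +0.3 when Condition = 1. The polynomial factors act as switches: each term vanishes for every Treatment value except the one it is meant to cover.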

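For a rare event such as having the condition (Condition = 1, which occurs with a low probability of 1%), a large number of samples is needed to train a model that reflects these cases accurately. That is why we generate the patient database with 10000 samples here.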
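A quick back-of-the-envelope check of the sample size, assuming the number of rare-condition patients follows a Binomial(n, 0.01) distribution. The smaller database size of 100 is a hypothetical value chosen only for comparison.

```python
import math

n_samples = 10_000   # size of the patient database used in this post
p_condition = 0.01   # probability of the rare condition

# The number of rare-condition patients is Binomial(n_samples, p_condition).
expected_rare = n_samples * p_condition
std_rare = math.sqrt(n_samples * p_condition * (1 - p_condition))
print(f"expected rare-condition patients: {expected_rare:.0f} (std {std_rare:.2f})")

# Hypothetical smaller database, chosen only for comparison.
n_small = 100
p_no_rare_patient = (1 - p_condition) ** n_small
print(f"P(no rare-condition patient | n={n_small}) = {p_no_rare_patient:.3f}")
```

With 10000 samples the database contains about 100 rare-condition patients, spread over three treatment options. With only 100 samples there would be a chance of roughly one in three of seeing no such patient at all.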


from scipy.stats import bernoulli
import numpy as np
from random import randint
import pandas as pd

n_unobserved = 10000
# Noise terms of the three observed variables.
unobserved_data = {
    'N_T': np.array([randint(0, 2) for _ in range(n_unobserved)]),  # 0, 1 or 2, uniformly
    'N_vision': np.random.uniform(0.4, 0.6, size=(n_unobserved,)),  # raw vision
    'N_C': bernoulli.rvs(0.01, size=n_unobserved),                  # rare condition, 1%
}
P_1 = 0.2
P_2 = 0.15

def create_observed_medical_data(unobserved_data, name):
    condition = unobserved_data['N_C']
    treatment = unobserved_data['N_T']
    # Vision = N_V + f_{p1,p2}(Treatment, Condition), one term per case.
    vision = (unobserved_data['N_vision']
              + (-P_1) * (1 - condition) * (1 - treatment) * (2 - treatment)
              + (2 * P_2) * (1 - condition) * treatment * (2 - treatment)
              + P_2 * (1 - condition) * treatment * (1 - treatment) * (3 - treatment)
              + 0 * condition * (1 - treatment) * (2 - treatment)  # Condition = 1, no treatment: unchanged
              + (-2 * P_2) * condition * treatment * (2 - treatment)
              + (-P_2) * condition * treatment * (1 - treatment) * (3 - treatment))
    observed_medical_data = pd.DataFrame({
        'Condition': condition,
        'Treatment': treatment,
        'Vision': vision,
    })
    # Save the data so that later sections can load it from disk.
    observed_medical_data.to_csv(name, index=False)
    return observed_medical_data

medical_data = create_observed_medical_data(unobserved_data, 'patients_database.csv')


num_samples = 1
original_vision = np.random.uniform(0.4, 0.6, size=num_samples)

def generate_specific_patient_data(num_samples, name):
    # A patient who has the rare condition (Condition = 1) and took option 2.
    return create_observed_medical_data({
        'N_T': np.full((num_samples,), 2),
        'N_C': bernoulli.rvs(1, size=num_samples),
        'N_vision': original_vision,
    }, name)

specific_patient_data = generate_specific_patient_data(num_samples, 'newly_come_patients.csv')

0x03. Loading the regular data


import pandas as pd

medical_data = pd.read_csv('patients_database.csv')
medical_data.head()


   Condition  Treatment    Vision
0          0          2  0.223475
1          0          2  0.197306
2          0          0  0.101252
3          0          1  0.703056
4          0          0  0.020249

medical_data.iloc[0:100].plot(figsize=(15, 10))


0x04. Modeling


import networkx as nx
import dowhy.gcm as gcm

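# Counterfactuals require an invertible SCM, so that the noise can be recovered from observations.
# Graph: Treatment -> Vision <- Condition.
causal_model = gcm.InvertibleStructuralCausalModel(nx.DiGraph([('Treatment', 'Vision'), ('Condition', 'Vision')]))
# Let DoWhy choose a causal mechanism for each node based on the data.
gcm.auto.assign_causal_mechanisms(causal_model, medical_data)


# Fit the mechanisms to the patient database.
gcm.fit(causal_model, medical_data)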


0x05. Loading the anomalous data

specific_patient_data = pd.read_csv('newly_come_patients.csv')


   Condition  Treatment    Vision
0          1          2  0.857103

0x06. Answering Alice's counterfactual question



counterfactual_data1 = gcm.counterfactual_samples(causal_model,
                                                  {'Treatment': lambda x: 1},
                                                  observed_data = specific_patient_data)

counterfactual_data2 = gcm.counterfactual_samples(causal_model,
                                                  {'Treatment': lambda x: 0},
                                                  observed_data = specific_patient_data)

import matplotlib.pyplot as plt

df_plot2 = pd.DataFrame()
df_plot2['Vision after option 2'] = specific_patient_data['Vision']
df_plot2['Counterfactual vision (option 1)'] = counterfactual_data1['Vision']
df_plot2['Counterfactual vision (No treatment)'] = counterfactual_data2['Vision']

df_plot2.plot.bar(title="Counterfactual outputs")
plt.ylabel('Eyesight quality')
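Because the data in this post are simulated, the counterfactuals can also be worked out by hand from the generating equation, which gives a reference point for the bars in the plot. The sketch below follows the usual three steps (abduction, action, prediction). It assumes the ground-truth constants P_1 = 0.2 and P_2 = 0.15 and the observed row from 0x05; it does not call DoWhy, the helper name `offset` is hypothetical, and the fitted model's estimates need not match these values exactly.

```python
P_1, P_2 = 0.2, 0.15  # ground-truth constants from the data-generation code


def offset(treatment, condition):
    """f_{p1,p2}(Treatment, Condition) from the Vision equation."""
    t, c = treatment, condition
    return (-P_1 * (1 - c) * (1 - t) * (2 - t)
            + 2 * P_2 * (1 - c) * t * (2 - t)
            + P_2 * (1 - c) * (3 - t) * (1 - t) * t
            - 2 * P_2 * c * t * (2 - t)
            - P_2 * c * (3 - t) * (1 - t) * t)


# Observed patient (values from the table in 0x05).
condition, treatment, vision = 1, 2, 0.857103

# 1. Abduction: recover this patient's noise term N_V from the observation.
n_v = vision - offset(treatment, condition)

# 2. Action and 3. Prediction: change the treatment, keep N_V and Condition fixed.
counterfactual_vision = {t: n_v + offset(t, condition) for t in (0, 1, 2)}
for t, v in counterfactual_vision.items():
    print(f"Treatment={t}: Vision = {v:.6f}")
```

Under the true equation this patient's raw vision is about 0.557. It would have stayed at that level with no treatment and dropped to about 0.257 with option 1, so option 2 was the better choice for a patient with the rare condition.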

