SHAP模型:可解释机器学习模型

SHAP模型:可解释机器学习模型_第1张图片 小白进阶选手,如果写的内容有什么问题大家一起讨论学习呀 :)

模型介绍

首先个人理解SHAP模型是对机器学习模型进行解释的一个模型

SHAP模型:可解释机器学习模型_第2张图片

 上面这个图就是一个比较直观的解释

机器学习模型一般都是一个黑盒。比如某个模型要进行一些预测任务,首先对模型输入一些已知条件(Age=65,Sex=F,BP=180,BMI=40),然后模型根据输入进行训练,最终训练完的模型可以对该条件输出预测结果(Output=0.4)

所以这样模型只能得到最终的结果,至于模型内部是怎么计算的,输入的已知条件(Age=65,Sex=F,BP=180,BMI=40)是怎么对预测结果(Output=0.4)影响的,我们都没法知道

而SHAP模型就可以让我们知道这些已知条件到底对最终预测结果起到哪些影响(是对结果起到正向影响还是对结果起到了负向影响),除了SHAP模型,其实也有其他方法可以进行特征重要性的计算,比如下面这个表格里提到的,我们可以根据各种方法的优点选择适合的进行特征重要性计算

SHAP模型:可解释机器学习模型_第3张图片

而本文主要介绍的SHAP 属于模型事后解释的方法,它的核心思想是计算特征对模型输出的边际贡献,再从全局和局部两个层面对“黑盒模型”进行解释。SHAP构建一个加性的解释模型,所有的特征都视为“贡献者”

SHAP的全称是SHapley Additive exPlanation,SHAP是由Shapley value启发的可加性解释模型。而Shapley value起源于合作博弈论,那什么是合作博弈呢。比如说甲乙丙丁四个工人一起打工,甲和乙完成了价值100元的工件,甲、乙、丙完成了价值120元的工件,乙、丙、丁完成了价值150元的工件,甲、丁完成了价值90元的工件,那么该如何公平、合理地分配这四个人的工钱呢?Shapley提出了一个合理的计算方法,我们称每个参与者分配到的数额为Shapley value

结合文章一开始提到的预测任务,我认为就是已知条件(Age=65,Sex=F,BP=180,BMI=40)一起完成了预测结果(Output=0.4),那么该如何公平、合理地分配这四个已知条件对预测结果的贡献呢?此时SHAP模型就会给这四个已知条件都分配一个Shapley value值,根据这个值我们就可以很好的进行理解

SHAP可以具体解决的任务

  • 调试模型用
  • 指导工程师做特征工程
  • 指导数据采集的方向
  • 指导人们做决策
  • 建立模型和人之间的信任

这一部分在https://yyqing.me/post/2018/2018-09-25-kaggle-model-insights这个网站里讲的很详细

SHAP库可用的explainers

在SHAP中进行模型解释需要先创建一个explainer,SHAP支持很多类型的explainer

  • deep:用于计算深度学习模型,基于DeepLIFT算法,支持TensorFlow / Keras。
  • gradient:用于深度学习模型,综合了SHAP、集成梯度、和SmoothGrad等思想,形成单一期望值方程,但速度比DeepExplainer慢,并且做出了不同的假设。 此方法基于Integrated Gradient归因方法,并支持TensorFlow / Keras / PyTorch。
  • kernel:模型无关,适用于任何模型
  • linear:适用于特征独立不相关的线性模型
  • tree:适用于树模型和基于树模型的集成算法,如XGBoost,LightGBM或CatBoost

实验

在网上找了几个相关的实验跑一下加深印象,SHAP模型输出的可视化图真的是挺美观的

SHAP模型:可解释机器学习模型_第4张图片

 SHAP模型:可解释机器学习模型_第5张图片

感觉要把SHAP运用理解透,首先对于机器学习的一些模型需要运用的比较熟练

网上找的的几个代码都是回归类的问题,所以了解的也比较浅显,以后遇到其他问题也可以尝试用SHAP对模型进行解释看看(网上对SHAP解释深度学习模型的例子不算很多,比如有看到一个CV方向的例子,通过SHAP来解释深度学习模型每一层网络对最终检查结果的影响情况,以后可以尝试一下这个方面),然后不断补充这里的实验部分

问题1:

足球运动员身价估计

每个足球运动员在转会市场都有各自的价码,这个问题的目的是根据球员的各项信息和能力值来预测该球员的市场价值。

问题2:

波士顿房价估计

通过数据挖掘对影响波士顿房价的因素进行分析

具体代码:

https://colab.research.google.com/drive/1V6XUWCbR7cPKXfdCJXKC1LjNHlg0GpFE?usp=sharing

SHAP导出各种格式的图

注意画图前需要加:

shap.initjs()

不加就会报错:

Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

如果想要在论文中添加SHAP模型输出的图,只是在代码最后加上plt.savefig,保存下来的图像为空白图。python的shap库,底层仍然使用matplotlib,根据源码可以发现在调用shap.xxx_plot时可以选择传递一个参数就可以正常使用plt.savefig了,几种不同的shap.xxx_plot需要传递的参数可能是不一样的,下面就举几个例子,一般情况下加上matplotlib=True,show = False这两个参数的比较多

除了下面几个例子,瀑布图直接调用plt.savefig即可,shap.decision_plot则需要传入参数return_objects=True

shap.force_plot(explainer.expected_value, shap_values[j], data[cols].iloc[j],matplotlib=True,show = False)
plt.savefig('./result_300dpi.jpg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.jpg', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.png', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.png', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.tiff', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.tiff', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.svg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.svg', bbox_inches='tight', dpi=150)
shap.summary_plot(shap_values, data[cols],show = False)
plt.savefig('./result_300dpi.jpg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.jpg', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.png', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.png', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.tiff', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.tiff', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.svg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.svg', bbox_inches='tight', dpi=150)
shap.dependence_plot('age', shap_values, data[cols], interaction_index=None, show=False)
plt.savefig('./result_300dpi.jpg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.jpg', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.png', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.png', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.tiff', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.tiff', bbox_inches='tight', dpi=150)
plt.savefig('./result_300dpi.svg', bbox_inches='tight', dpi=300)
plt.savefig('./result_150dpi.svg', bbox_inches='tight', dpi=150)

References 

https://www.jianshu.com/p/324a7c982034

https://zhuanlan.zhihu.com/p/64799119

https://zhuanlan.zhihu.com/p/83412330

你可能感兴趣的:(机器学习,python,机器学习)