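SHAP is built on two core quantities: SHAP values and SHAP interaction values. The official library offers three main applications: the force plot, the summary plot, and the dependence plot. All three are derived from SHAP values and SHAP interaction values.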
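To make the idea of a SHAP value concrete, here is a minimal sketch that computes exact Shapley values for a toy three-feature function by enumerating every feature coalition. It does not use the shap library. The function `toy_model`, the all-zeros baseline, and the helper names are assumptions made up for illustration, and replacing "missing" features with a baseline value is only one simple convention.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    # Hypothetical model: an interaction term plus a linear term.
    return x[0] * x[1] + x[2]


def coalition_value(model, x, baseline, subset):
    # Features in `subset` keep their real value; the rest fall back to the baseline.
    mixed = [x[i] if i in subset else baseline[i] for i in range(len(x))]
    return model(mixed)


def shapley_values(model, x, baseline):
    m = len(x)
    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Classic Shapley weight for a coalition of this size.
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                with_i = coalition_value(model, x, baseline, s | {i})
                without_i = coalition_value(model, x, baseline, s)
                phi[i] += weight * (with_i - without_i)
    return phi


x = [2.0, 3.0, 5.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
base_value = toy_model(baseline)
```

The contributions add up to the gap between the baseline output and the prediction (`base_value + sum(phi)` equals `toy_model(x)`), and this additivity is what the plots below visualize.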
def waterfall(shap_values, max_display=10, show=True):
""" Plots an explanation of a single prediction as a waterfall plot.
The SHAP value of a feature represents the impact of the evidence provided by that feature on the model's
output. The waterfall plot is designed to visually display how the SHAP values (evidence) of each feature
move the model output from our prior expectation under the background data distribution, to the final model
prediction given the evidence of all the features. Features are sorted by the magnitude of their SHAP values
with the smallest magnitude features grouped together at the bottom of the plot when the number of features
in the model exceeds the max_display parameter.
shap_values : Explanation
A one-dimensional Explanation object that contains the feature values and SHAP values to plot.
max_display : int
The maximum number of features to plot.
show : bool
Whether matplotlib.pyplot.show() is called before returning. Setting this to False allows the plot
to be customized further after it has been created.
"""
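The layout the docstring describes can be sketched in a few lines: rank features by absolute SHAP value, fold everything beyond `max_display` into one "other features" bar, and accumulate the contributions from the base value up to the prediction. The helper `waterfall_rows` and the sample numbers are hypothetical; this is an illustration of the idea, not the library's plotting code.

```python
def waterfall_rows(base_value, shap_values, feature_names, max_display=10):
    """Return (label, start, end) for each bar, from the base value up to the prediction."""
    # Rank feature indexes by absolute SHAP value, largest first.
    order = sorted(range(len(shap_values)), key=lambda i: abs(shap_values[i]), reverse=True)
    if len(order) > max_display:
        # Keep one slot for the grouped remainder.
        shown = order[:max_display - 1]
        rest = order[max_display - 1:]
    else:
        shown, rest = order, []

    rows = []
    running = base_value
    if rest:
        # The smallest contributions are merged into a single bar at the bottom.
        total = sum(shap_values[i] for i in rest)
        rows.append((f"{len(rest)} other features", running, running + total))
        running += total
    # Remaining bars go from smallest to largest magnitude.
    for i in reversed(shown):
        rows.append((feature_names[i], running, running + shap_values[i]))
        running += shap_values[i]
    return rows


rows = waterfall_rows(10.0, [0.5, -2.0, 0.25, 1.0], ["a", "b", "c", "d"], max_display=3)
```

The `end` of the last row is the model prediction for the sample, because the bars sum to the difference between the prediction and the base value.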
shap.summary_plot(shap_values, features=None, feature_names=None, max_display=None, plot_type=None, color=None, axis_color='#333333', title=None, alpha=1, show=True, sort=True, color_bar=True, plot_size='auto', layered_violin_max_num_bins=20, class_names=None, class_inds=None, color_bar_label='Feature value', cmap=<matplotlib.colors.LinearSegmentedColormap object>, auto_size_plot=None, use_log_scale=False)
Create a SHAP beeswarm plot, colored by feature values when they are provided.
shap_values : numpy.array
For single output explanations this is a matrix of SHAP values (# samples x # features). For multi-output explanations this is a list of such matrices of SHAP values.
features : numpy.array or pandas.DataFrame or list
Matrix of feature values (# samples x # features) or a feature_names list as shorthand.
feature_names : list
Names of the features (length # features).
max_display : int
How many top features to include in the plot (default is 20, or 7 for interaction plots).
plot_type : "dot" (default for single output), "bar" (default for multi-output), "violin", or "compact_dot"
What type of summary plot to produce. Note that "compact_dot" is only used for SHAP interaction values.
plot_size : "auto" (default), float, (float, float), or None
What size to make the plot. By default the size is auto-scaled based on the number of features that are being displayed. Passing a single float will cause each row to be that many inches high. Passing a pair of floats will scale the plot by that number of inches. If None is passed then the size of the current figure will be left unchanged.
shap.force_plot(base_value, shap_values=None, features=None, feature_names=None, out_names=None, link='identity', plot_cmap='RdBu', matplotlib=False, show=True, figsize=(20, 3), ordering_keys=None, ordering_keys_time_format=None, text_rotation=0)
Visualize the given SHAP values with an additive force layout.
base_value : float
This is the reference value that the feature contributions start from. For SHAP values it should be the value of explainer.expected_value.
shap_values : numpy.array
Matrix of SHAP values (# features) or (# samples x # features). If this is a 1D array then a single force plot will be drawn; if it is a 2D array then a stacked force plot will be drawn.
features : numpy.array
Matrix of feature values (# features) or (# samples x # features). This provides the values of all the features, and should be the same shape as the shap_values argument.
feature_names : list
List of feature names (# features).
out_names : str
The name of the output of the model (plural to support multi-output plotting in the future).
link : "identity" or "logit"
The transformation used when drawing the tick mark labels. Using logit will change log-odds numbers into probabilities.
matplotlib : bool
Whether to use the default Javascript output, or the (less developed) matplotlib output. Using matplotlib can be helpful in scenarios where rendering Javascript/HTML is inconvenient.
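As a rough illustration of what the "logit" link does to the axis labels, the sketch below maps a log-odds value to a probability with the logistic function. It assumes the model output and the SHAP values are expressed in log-odds. The numbers and the helper name `logit_link_inverse` are hypothetical and not part of the shap API.

```python
import math


def logit_link_inverse(log_odds):
    # Logistic function: maps a log-odds value to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-log_odds))


base_value = -1.0               # hypothetical expected model output, in log-odds
shap_values = [0.8, -0.3, 1.5]  # hypothetical per-feature contributions, in log-odds

# Contributions are additive in log-odds space ...
margin = base_value + sum(shap_values)
# ... and only the displayed labels are converted to probabilities.
probability = logit_link_inverse(margin)
```

The contributions stay additive only in log-odds space. After the transformation the bars still sit at the same positions, but the distances between tick labels are no longer uniform in probability terms.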
def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
Return approximate SHAP values for the model applied to the data given by X.
X : list,
if framework == 'tensorflow': numpy.array, or pandas.DataFrame
if framework == 'pytorch': torch.tensor
A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
explain the model's output.
ranked_outputs : None or int
If ranked_outputs is None then we explain all the outputs in a multi-output model. If
ranked_outputs is a positive integer then we only explain that many of the top model
outputs (where "top" is determined by output_rank_order). Note that this causes a pair
of values to be returned (shap_values, indexes), where shap_values is a list of numpy
arrays for each of the output ranks, and indexes is a matrix that indicates for each sample
which output indexes were chosen as "top".
output_rank_order : "max", "min", or "max_abs"
How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
maximum absolute value.
Returns
array or list
For a model with a single output this returns a tensor of SHAP values with the same shape
as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
which is the same shape as X. If ranked_outputs is None then this list of tensors matches
the number of model outputs. If ranked_outputs is a positive integer a pair is returned
(shap_values, indexes), where shap_values is a list of tensors with a length of
ranked_outputs, and indexes is a matrix that indicates for each sample which output indexes
were chosen as "top".
return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
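The ranking step that the docstring describes for `ranked_outputs` and `output_rank_order` can be sketched in plain Python as follows. The helper `rank_outputs` and the sample outputs are hypothetical. The sketch only shows how the `indexes` matrix could be derived from raw model outputs, not how the explainer computes the SHAP values themselves.

```python
# Scoring rule for each supported ordering; a higher score means "more top".
SCORERS = {
    "max": lambda v: v,
    "min": lambda v: -v,
    "max_abs": abs,
}


def rank_outputs(model_outputs, ranked_outputs, output_rank_order="max"):
    """For each sample, return the indexes of the top `ranked_outputs` model outputs."""
    if output_rank_order not in SCORERS:
        raise ValueError("output_rank_order must be 'max', 'min', or 'max_abs'")
    score = SCORERS[output_rank_order]
    indexes = []
    for row in model_outputs:
        order = sorted(range(len(row)), key=lambda j: score(row[j]), reverse=True)
        indexes.append(order[:ranked_outputs])
    return indexes


outputs = [[0.1, 0.7, 0.2], [-0.9, 0.3, 0.5]]  # hypothetical outputs: 2 samples x 3 classes
top_two = rank_outputs(outputs, 2)
top_abs = rank_outputs(outputs, 1, output_rank_order="max_abs")
```

Each row of the result lists, for one sample, which output columns were selected, which is the role `indexes` plays in the returned `(shap_values, indexes)` pair.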
shap.summary_plot(shap_values, X, plot_type="bar")
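The summary plot draws the SHAP value of every feature for every sample, which gives a better view of the overall pattern and makes it possible to spot prediction outliers. Each row represents one feature, and the horizontal axis is the SHAP value. Each dot is one sample, colored by the feature value (red is high, blue is low). For example, this plot shows that higher values of the LSTAT feature lower the predicted house price.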
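With plot_type="bar", the summary plot reduces each feature to a single number, the mean absolute SHAP value over all samples, and sorts the features by it. A minimal sketch of that aggregation is below. The SHAP matrix holds made-up numbers, the feature names other than LSTAT are placeholders, and `mean_abs_shap` is a hypothetical helper.

```python
def mean_abs_shap(shap_matrix, feature_names):
    """Rank features by mean absolute SHAP value over all samples."""
    n_samples = len(shap_matrix)
    importance = {}
    for j, name in enumerate(feature_names):
        importance[name] = sum(abs(row[j]) for row in shap_matrix) / n_samples
    # Most important feature first.
    return sorted(importance.items(), key=lambda item: item[1], reverse=True)


# Hypothetical SHAP values: 3 samples x 3 features.
shap_matrix = [
    [-2.0, 0.5, 0.1],
    [1.0, -0.5, 0.3],
    [-3.0, 1.5, -0.2],
]
feature_names = ["LSTAT", "RM", "AGE"]
ranking = mean_abs_shap(shap_matrix, feature_names)
```

Taking the absolute value first matters: a feature that pushes some predictions up and others down would otherwise average out to roughly zero and look unimportant.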
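The figure below shows the SHAP dependence plot for the number of years on hormonal contraceptives (Hormonal.Contraceptives…years.). The plot includes all data points: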
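SHAP dependence plots can be used as an alternative to partial dependence plots (PDP) and accumulated local effects (ALE) plots. A SHAP dependence plot also shows variance along the y-axis: because the plotted feature interacts with other features, the points spread out vertically. The dependence plot can be improved by making these feature interactions visible.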
shap.dependence_plot('age', shap_values, data[cols], interaction_index=None, show=False)
shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(data[cols])
shap.summary_plot(shap_interaction_values, data[cols], max_display=4)
shap.dependence_plot('potential', shap_values, data[cols], interaction_index='international_reputation', show=False)
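To show what an interaction value represents, the sketch below computes a full interaction matrix for a toy function by brute force. It uses the Shapley interaction index for the off-diagonal entries, with each pairwise effect split equally between the two symmetric entries, and defines the diagonal as the feature's Shapley value minus its interactions. The toy model, the all-zeros baseline, and the helper names are assumptions. TreeExplainer uses a tree-specific algorithm rather than this enumeration.

```python
from itertools import combinations
from math import factorial


def toy_model(x):
    # Hypothetical model: features 0 and 1 interact, feature 2 acts alone.
    return x[0] * x[1] + x[2]


def interaction_values(model, x, baseline):
    m = len(x)

    def v(subset):
        # Features outside `subset` are replaced by their baseline value.
        return model([x[k] if k in subset else baseline[k] for k in range(m)])

    phi = [[0.0] * m for _ in range(m)]

    # Off-diagonal entries: pairwise interaction effects.
    for i, j in combinations(range(m), 2):
        others = [k for k in range(m) if k not in (i, j)]
        for size in range(len(others) + 1):
            weight = factorial(size) * factorial(m - size - 2) / (2 * factorial(m - 1))
            for subset in combinations(others, size):
                s = set(subset)
                delta = v(s | {i, j}) - v(s | {i}) - v(s | {j}) + v(s)
                phi[i][j] += weight * delta
        phi[j][i] = phi[i][j]

    # Diagonal entries: Shapley value minus all interactions of that feature.
    for i in range(m):
        others = [k for k in range(m) if k != i]
        shapley_i = 0.0
        for size in range(m):
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                shapley_i += weight * (v(s | {i}) - v(s))
        phi[i][i] = shapley_i - sum(phi[i][k] for k in others)
    return phi


x = [2.0, 3.0, 5.0]
baseline = [0.0, 0.0, 0.0]
matrix = interaction_values(toy_model, x, baseline)
total = sum(sum(row) for row in matrix)
```

In this toy case the product term is attributed entirely to the pair (0, 1), feature 2 has only a main effect, each row sums to that feature's SHAP value, and the whole matrix sums to the difference between the prediction and the baseline output. Coloring a dependence plot by `interaction_index` is a way to reveal the same kind of pairwise structure visually.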