37
37
EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS
38
38
"""Supported explainer types for :class:`baybe.insights.shap.SHAPInsight`"""
39
39
40
- SHAP_PLOTS = {"bar" , "beeswarm" , "force" , "heatmap" , "scatter" }
40
+ SHAP_PLOTS = {"bar" , "beeswarm" , "force" , "heatmap" , "scatter" , "waterfall" }
41
41
"""Supported plot types for :meth:`baybe.insights.shap.SHAPInsight.plot`"""
42
42
43
43
@@ -265,15 +265,15 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
265
265
The computed Shapley explanation.
266
266
267
267
Raises:
268
- ValueError: If the columns of the given dataframe cannot be aligned with the
269
- columns of the explainer background dataframe .
268
+ ValueError: If not all the columns of the explainer background dataframe
269
+ are present in the given data .
270
270
"""
271
271
if data is None :
272
272
data = self .background_data
273
- elif set (self .background_data .columns ) != set (data .columns ):
273
+ elif not set (self .background_data .columns ). issubset (data .columns ):
274
274
raise ValueError (
275
- "The provided dataframe must have the same column names as used by "
276
- "the explainer object ."
275
+ "The provided dataframe must contain all columns that were used for "
276
+ "the background data ."
277
277
)
278
278
279
279
# Align columns with background data
@@ -302,6 +302,7 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
302
302
# (`base_values` can be a scalar or vector)
303
303
# TODO: https://github.com/shap/shap/issues/3958
304
304
idx = self .background_data .columns .get_indexer (data .columns )
305
+ idx = idx [idx != - 1 ] # Additional columns in data are ignored.
305
306
for attr in ["values" , "data" , "base_values" ]:
306
307
try :
307
308
setattr (explanations , attr , getattr (explanations , attr )[:, idx ])
@@ -327,7 +328,9 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
327
328
328
329
def plot (
329
330
self ,
330
- plot_type : Literal ["bar" , "beeswarm" , "force" , "heatmap" , "scatter" ],
331
+ plot_type : Literal [
332
+ "bar" , "beeswarm" , "force" , "heatmap" , "scatter" , "waterfall"
333
+ ],
331
334
data : pd .DataFrame | None = None ,
332
335
/ ,
333
336
* ,
@@ -367,16 +370,20 @@ def plot(
367
370
plot_func = getattr (shap .plots , plot_type )
368
371
369
372
# Handle plot types that only explain a single data point
370
- if plot_type == "force" :
373
+ if plot_type in [ "force" , "waterfall" ] :
371
374
if explanation_index is None :
372
375
warnings .warn (
373
376
f"When using plot type '{ plot_type } ', an 'explanation_index' must "
374
377
f"be chosen to identify a single data point that should be "
375
378
f"explained. Choosing the first entry at position 0."
376
379
)
377
380
explanation_index = 0
381
+
378
382
toplot = self .explain (data .iloc [[explanation_index ]])
379
- kwargs ["matplotlib" ] = True
383
+ toplot = toplot [0 ]
384
+
385
+ if plot_type == "force" :
386
+ kwargs ["matplotlib" ] = True
380
387
else :
381
388
toplot = self .explain (data )
382
389
0 commit comments