Skip to content

Commit 70b4978

Browse files
authoredFeb 26, 2025
Merge: Minor SHAP improvements (#494)
- added support for the `waterfall` plot type - I've loosened the restriction that the set of `data` and `background_data` columns needs to exactly match. This makes no sense to me as it prohibits e.g. patterns like `insight.explain(campaign.measurements.sample(frac=0.1))` because there are also 3 meta columns present. The return oder is still rearranged to match the one in `data` and additional columns are ignored
2 parents 613bd3a + 59e87a9 commit 70b4978

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed
 

‎CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
- `BCUT2D` encoding for `SubstanceParameter`
1010
- Stored benchmarking results now include the Python environment and version
1111
- `qPSTD` acquisition function
12+
- `SHAPInsight` now supports the `waterfall` plot type
1213

1314
### Changed
1415
- Acquisition function indicator `is_mc` has been removed in favor of new indicators
1516
`supports_batching` and `supports_pending_experiments`
17+
- `SHAPInsight` now allows explanation input that has additional columns compared to
18+
the background data (will be ignored)
1619

1720
### Fixed
1821
- Incorrect optimization direction with `PSTD` with a single minimization target

‎baybe/insights/shap.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
EXPLAINERS = SHAP_EXPLAINERS | NON_SHAP_EXPLAINERS
3838
"""Supported explainer types for :class:`baybe.insights.shap.SHAPInsight`"""
3939

40-
SHAP_PLOTS = {"bar", "beeswarm", "force", "heatmap", "scatter"}
40+
SHAP_PLOTS = {"bar", "beeswarm", "force", "heatmap", "scatter", "waterfall"}
4141
"""Supported plot types for :meth:`baybe.insights.shap.SHAPInsight.plot`"""
4242

4343

@@ -265,15 +265,15 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
265265
The computed Shapley explanation.
266266
267267
Raises:
268-
ValueError: If the columns of the given dataframe cannot be aligned with the
269-
columns of the explainer background dataframe.
268+
ValueError: If not all the columns of the explainer background dataframe
269+
are present in the given data.
270270
"""
271271
if data is None:
272272
data = self.background_data
273-
elif set(self.background_data.columns) != set(data.columns):
273+
elif not set(self.background_data.columns).issubset(data.columns):
274274
raise ValueError(
275-
"The provided dataframe must have the same column names as used by "
276-
"the explainer object."
275+
"The provided dataframe must contain all columns that were used for "
276+
"the background data."
277277
)
278278

279279
# Align columns with background data
@@ -302,6 +302,7 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
302302
# (`base_values` can be a scalar or vector)
303303
# TODO: https://github.com/shap/shap/issues/3958
304304
idx = self.background_data.columns.get_indexer(data.columns)
305+
idx = idx[idx != -1] # Additional columns in data are ignored.
305306
for attr in ["values", "data", "base_values"]:
306307
try:
307308
setattr(explanations, attr, getattr(explanations, attr)[:, idx])
@@ -327,7 +328,9 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
327328

328329
def plot(
329330
self,
330-
plot_type: Literal["bar", "beeswarm", "force", "heatmap", "scatter"],
331+
plot_type: Literal[
332+
"bar", "beeswarm", "force", "heatmap", "scatter", "waterfall"
333+
],
331334
data: pd.DataFrame | None = None,
332335
/,
333336
*,
@@ -367,16 +370,20 @@ def plot(
367370
plot_func = getattr(shap.plots, plot_type)
368371

369372
# Handle plot types that only explain a single data point
370-
if plot_type == "force":
373+
if plot_type in ["force", "waterfall"]:
371374
if explanation_index is None:
372375
warnings.warn(
373376
f"When using plot type '{plot_type}', an 'explanation_index' must "
374377
f"be chosen to identify a single data point that should be "
375378
f"explained. Choosing the first entry at position 0."
376379
)
377380
explanation_index = 0
381+
378382
toplot = self.explain(data.iloc[[explanation_index]])
379-
kwargs["matplotlib"] = True
383+
toplot = toplot[0]
384+
385+
if plot_type == "force":
386+
kwargs["matplotlib"] = True
380387
else:
381388
toplot = self.explain(data)
382389

‎tests/insights/test_shap.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,7 @@ def test_invalid_explained_data(ongoing_campaign, explainer_cls, use_comp_rep):
140140
use_comp_rep=use_comp_rep,
141141
)
142142
df = pd.DataFrame({"Num_disc_1": [0, 2]})
143-
with pytest.raises(
144-
ValueError,
145-
match="The provided dataframe must have the same column names as used by "
146-
"the explainer object.",
147-
):
143+
with pytest.raises(ValueError, match="must contain all columns that were used"):
148144
shap_insight.explain(df)
149145

150146

0 commit comments

Comments
 (0)