
Commit 6281868

Authored Mar 5, 2025
Merge: Vectorize fuzzy_row_match (#489)
- Use vectorized operations instead of the for loop.
- Fixed column validations.
- I tested that the result of the new version is always exactly equal to the old version.
- Added some basic pytests for the utility.
- Related to #344.

Here is a resulting test looking at the speedup:
[Figure: speedup of the vectorized implementation](https://github.com/user-attachments/assets/094bd96c-1e0f-4c4b-a10e-fdd5d680eb16)

- The speedup for the most realistic cases (`left_df` large versus `right_df`) approaches 4x from above.
- For less relevant cases (`left_df` and `right_df` comparable in size or overall very small), the speedup can even reach 40x.
2 parents 4ba501d + 1d0d922 commit 6281868
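To make the change concrete, here is a minimal, self-contained sketch of the broadcasting idea that replaces the per-row loop; the data and column name below are invented for illustration and are not part of the commit:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: the search space (left) and two measurements (right).
left_df = pd.DataFrame({"temperature": [10.0, 20.0, 30.0, 40.0]})
right_df = pd.DataFrame({"temperature": [19.2, 41.0]})

# Loop-based matching (schematically the old approach): one pass per right row.
loop_matches = []
for _, row in right_df.iterrows():
    abs_diff = (left_df["temperature"] - row["temperature"]).abs()
    loop_matches.append(abs_diff.idxmin())

# Vectorized matching: broadcasting builds the full (n_right, n_left) difference
# matrix at once, and argmin along axis 1 picks the closest left row per right row.
abs_diff = np.abs(
    right_df["temperature"].values[:, None] - left_df["temperature"].values[None, :]
)
vectorized_matches = left_df.index[abs_diff.argmin(axis=1)].tolist()

assert loop_matches == vectorized_matches  # both yield [1, 3]
```

Avoiding the Python-level loop over `right_df` is where the reported 4x to 40x speedup comes from.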

File tree

4 files changed, +229 −43 lines changed


CHANGELOG.md

+2
```diff
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `supports_batching` and `supports_pending_experiments`
 - `SHAPInsight` now allows explanation input that has additional columns compared to
   the background data (will be ignored)
+- `fuzzy_row_match` now uses vectorized operations, resulting in a speedup of matching
+  measurements to the search space between 4x and 40x
 
 ### Fixed
 - Incorrect optimization direction with `PSTD` with a single minimization target
```

baybe/exceptions.py

+22
```diff
@@ -1,5 +1,9 @@
 """Custom exceptions and warnings."""
 
+import pandas as pd
+from attr.validators import instance_of
+from attrs import define, field
+from typing_extensions import override
 
 ##### Warnings #####
 
@@ -11,6 +15,24 @@ class UnusedObjectWarning(UserWarning):
     """
 
 
+@define
+class SearchSpaceMatchWarning(UserWarning):
+    """
+    When trying to match data to entries in the search space, something unexpected
+    happened.
+    """
+
+    message: str = field(validator=instance_of(str))
+    data: pd.DataFrame = field(validator=instance_of(pd.DataFrame))
+
+    def __attrs_pre_init(self):
+        super().__init__(self.message)
+
+    @override
+    def __str__(self):
+        return self.message
+
+
 ##### Exceptions #####
 
 
```
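Because the new warning class carries the offending rows in its `data` attribute, downstream code can inspect exactly which inputs were problematic. A small sketch of that interface, using made-up data (how `fuzzy_row_match` emits the warning is shown in the next file):

```python
import warnings

import pandas as pd

from baybe.exceptions import SearchSpaceMatchWarning

# Hypothetical rows that failed to match, only to demonstrate the warning's interface.
unmatched = pd.DataFrame({"Temperature": [99.0]})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(SearchSpaceMatchWarning("1 row could not be matched.", unmatched))

w = caught[0].message   # the SearchSpaceMatchWarning instance
print(str(w))           # "1 row could not be matched." (via the __str__ override)
print(w.data)           # the offending rows as a DataFrame
```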

baybe/utils/dataframe.py

+75 −42
```diff
@@ -10,6 +10,7 @@
 import numpy as np
 import pandas as pd
 
+from baybe.exceptions import SearchSpaceMatchWarning
 from baybe.targets.base import Target
 from baybe.targets.binary import BinaryTarget
 from baybe.targets.enum import TargetMode
@@ -462,7 +463,7 @@ def fuzzy_row_match(
     right_df: pd.DataFrame,
     parameters: Sequence[Parameter],
 ) -> pd.Index:
-    """Match row of the right dataframe to the rows of the left dataframe.
+    """Match rows of the right dataframe to rows of the left dataframe.
 
     This is useful for matching measurements to entries in the search space, e.g. to
     detect which ones have been measured. For categorical parameters, there needs to be
@@ -476,57 +477,89 @@ def fuzzy_row_match(
 
     Args:
         left_df: The data that serves as lookup reference.
-        right_df: The data that should be checked for matching rows in the left
-            dataframe.
-        parameters: List of baybe parameter objects that are needed to identify
-            potential tolerances.
+        right_df: The data that is checked for matching rows in the left dataframe.
+        parameters: Parameter objects that identify the relevant column names and how
+            matching is performed.
 
     Returns:
         The index of the matching rows in ``left_df``.
 
     Raises:
-        ValueError: If some rows are present in the right but not in the left dataframe.
+        ValueError: If either ``left_df`` or ``right_df`` does not contain columns for
+            each entry in ``parameters``.
     """
-    # Assert that all parameters appear in the given dataframe
-    if not all(col in right_df.columns for col in left_df.columns):
+    # Separate column types
+    cat_cols = {p.name for p in parameters if (not p.is_numerical and p.is_discrete)}
+    num_cols = {p.name for p in parameters if (p.is_numerical and p.is_discrete)}
+    non_discrete_cols = {p.name for p in parameters if not p.is_discrete}
+
+    # Assert that all parameters appear in the given dataframes
+    if diff := (cat_cols | num_cols).difference(left_df.columns):
+        raise ValueError(
+            f"For fuzzy row matching, all discrete parameters need to have a "
+            f"corresponding column in the left dataframe. Parameters not found: {diff}"
+        )
+    if diff := (cat_cols | num_cols).difference(right_df.columns):
         raise ValueError(
-            "For fuzzy row matching all rows of the right dataframe need to be present"
-            " in the left dataframe."
+            f"For fuzzy row matching, all discrete parameters need to have a "
+            f"corresponding column in the right dataframe. Parameters not found: "
+            f"{diff}"
         )
 
-    # Iterate over all input rows
-    inds_matched = []
-    for ind, row in right_df.iterrows():
-        # Differentiate category-like and discrete numerical parameters
-        cat_cols = [p.name for p in parameters if not p.is_numerical]
-        num_cols = [p.name for p in parameters if (p.is_numerical and p.is_discrete)]
-
-        # Discrete parameters must match exactly
-        match = left_df[cat_cols].eq(row[cat_cols]).all(axis=1, skipna=False)
-
-        # For numeric parameters, match the entry with the smallest deviation
-        for col in num_cols:
-            abs_diff = (left_df[col] - row[col]).abs()
-            match &= abs_diff == abs_diff.min()
-
-        # We expect exactly one match. If that's not the case, print a warning.
-        inds_found = left_df.index[match].to_list()
-        if len(inds_found) == 0 and len(num_cols) > 0:
-            warnings.warn(
-                f"Input row with index {ind} could not be matched to the search space. "
-                f"This could indicate that something went wrong."
-            )
-        elif len(inds_found) > 1:
-            warnings.warn(
-                f"Input row with index {ind} has multiple matches with the search "
-                f"space. This could indicate that something went wrong. Matching only "
-                f"first occurrence."
-            )
-            inds_matched.append(inds_found[0])
-        else:
-            inds_matched.extend(inds_found)
+    provided_cols = {p.name for p in parameters}
+    allowed_cols = cat_cols | num_cols | non_discrete_cols
+    assert allowed_cols == provided_cols, (
+        f"There are parameter types that would be silently ignored: "
+        f"{provided_cols.difference(allowed_cols)}"
+    )
+
+    # Initialize the match matrix. We will later filter it down using other
+    # matrices (representing the matches for individual parameters) via logical 'and'.
+    match_matrix = pd.DataFrame(
+        True, index=right_df.index, columns=left_df.index, dtype=bool
+    )
+
+    # Match categorical parameters
+    for col in cat_cols:
+        # Per categorical parameter, this identifies matches between all elements of
+        # left and right and stores them in a matrix.
+        match_matrix &= right_df[col].values[:, None] == left_df[col].values[None, :]
+
+    # Match numerical parameters
+    for col in num_cols:
+        # Per numerical parameter, this identifies the rows with the smallest absolute
+        # difference and records them in a matrix.
+        abs_diff = np.abs(right_df[col].values[:, None] - left_df[col].values[None, :])
+        min_diff = abs_diff.min(axis=1, keepdims=True)
+        match_matrix &= abs_diff == min_diff
+
+    # Find the matching indices. If a right row is not matched to any of the rows in
+    # left, idxmax would return the first index of left_df. Hence, we remember these
+    # cases and drop them explicitly.
+    matched_indices = pd.Index(match_matrix.idxmax(axis=1).values)
+    mask_no_match = ~match_matrix.any(axis=1)
+    matched_indices = matched_indices[~mask_no_match]
+
+    # Warn if there are multiple or no matches
+    if no_match_indices := right_df.index[mask_no_match].tolist():
+        w = SearchSpaceMatchWarning(
+            f"Some input rows could not be matched to the search space. Indices with "
+            f"no matches: {no_match_indices}",
+            right_df.loc[no_match_indices],
+        )
+        warnings.warn(w)
+
+    mask_multiple_matches = match_matrix.sum(axis=1) > 1
+    if multiple_match_indices := right_df.index[mask_multiple_matches].tolist():
+        w = SearchSpaceMatchWarning(
+            f"Some input rows have multiple matches with the search space. "
+            f"Matching only first occurrence for these rows. Indices with multiple "
+            f"matches: {multiple_match_indices}",
+            right_df.loc[multiple_match_indices],
+        )
+        warnings.warn(w)
 
-    return pd.Index(inds_matched)
+    return matched_indices
 
 
 def pretty_print_df(
```
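For orientation, a sketch of how the vectorized `fuzzy_row_match` is typically called. The parameter classes and constructor arguments below follow the usual BayBE API but are assumptions, not part of this diff:

```python
import pandas as pd

from baybe.parameters import CategoricalParameter, NumericalDiscreteParameter
from baybe.utils.dataframe import fuzzy_row_match

# Assumed parameter definitions (names and values are made up for the example).
parameters = [
    CategoricalParameter(name="Solvent", values=["water", "ethanol"]),
    NumericalDiscreteParameter(name="Temperature", values=[10.0, 20.0, 30.0]),
]

# Search space entries (left) and two slightly off-grid measurements (right).
left_df = pd.DataFrame(
    {"Solvent": ["water", "water", "ethanol"], "Temperature": [10.0, 20.0, 30.0]}
)
right_df = pd.DataFrame({"Solvent": ["water", "ethanol"], "Temperature": [19.7, 30.2]})

# Categorical values must match exactly; numerical values snap to the closest grid
# point, so the measurements map to search space rows 1 and 2.
print(fuzzy_row_match(left_df, right_df, parameters))  # Index([1, 2], dtype='int64')
```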

tests/utils/test_dataframe.py

+130 −1
```diff
@@ -1,10 +1,24 @@
 """Tests for dataframe utilities."""
 
+from contextlib import nullcontext
+
 import numpy as np
 import pandas as pd
 import pytest
+from pandas.testing import assert_frame_equal
+from pytest import param
+
+from baybe.exceptions import SearchSpaceMatchWarning
+from baybe.utils.dataframe import (
+    add_noise_to_perturb_degenerate_rows,
+    add_parameter_noise,
+    fuzzy_row_match,
+)
 
-from baybe.utils.dataframe import add_noise_to_perturb_degenerate_rows
+
+@pytest.fixture()
+def n_grid_points():
+    return 5
 
 
 def test_degenerate_rows():
@@ -41,3 +55,118 @@ def test_degenerate_rows_invalid_input():
     # Add noise
     with pytest.raises(TypeError):
         add_noise_to_perturb_degenerate_rows(df)
+
+
+@pytest.mark.parametrize(
+    ("parameter_names", "noise", "duplicated"),
+    [
+        param(
+            ["Categorical_1", "Num_disc_1", "Some_Setting"],
+            False,
+            True,
+            id="discrete_num_noiseless_duplicated",
+        ),
+        param(
+            ["Categorical_1", "Num_disc_1", "Some_Setting"],
+            False,
+            False,
+            id="discrete_num_noiseless_unique",
+        ),
+        param(
+            ["Categorical_1", "Num_disc_1", "Some_Setting"],
+            True,
+            False,
+            id="discrete_num_noisy_unique",
+        ),
+        param(
+            ["Categorical_1", "Switch_1", "Some_Setting"],
+            False,
+            False,
+            id="discrete_cat",
+        ),
+        param(
+            ["Categorical_1", "Switch_1", "Conti_finite_1"],
+            False,
+            False,
+            id="hybrid_cat",
+        ),
+        param(
+            ["Categorical_1", "Num_disc_1", "Conti_finite_1"],
+            False,
+            False,
+            id="hybrid_num_noiseless_unique",
+        ),
+        param(
+            ["Categorical_1", "Num_disc_1", "Conti_finite_1"],
+            True,
+            False,
+            id="hybrid_num_noisy_unique",
+        ),
+        param(
+            ["Categorical_1", "Num_disc_1", "Conti_finite_1"],
+            False,
+            True,
+            id="hybrid_num_noiseless_duplicated",
+        ),
+    ],
+)
+def test_fuzzy_row_match(searchspace, noise, duplicated):
+    """Fuzzy row matching returns expected indices."""
+    left_df = searchspace.discrete.exp_rep.copy()
+    selected = np.random.choice(left_df.index, 4, replace=False)
+    right_df = left_df.loc[selected].reset_index(drop=True)
+
+    context = nullcontext()
+    if duplicated:
+        # Set one of the input values to exactly the midpoint between two values to
+        # cause a degenerate match
+        vals = searchspace.get_parameters_by_name(["Num_disc_1"])[0].values
+        right_df.loc[0, "Num_disc_1"] = vals[0] + (vals[1] - vals[0]) / 2.0
+        context = pytest.warns(SearchSpaceMatchWarning, match="multiple matches")
+
+    if noise:
+        add_parameter_noise(
+            right_df,
+            searchspace.discrete.parameters,
+            noise_type="relative_percent",
+            noise_level=0.1,
+        )
+
+    with context as c:
+        matched = fuzzy_row_match(left_df, right_df, searchspace.parameters)
+
+    if duplicated:
+        # Assert correct identification of problematic df parts
+        w = next(x for x in c if isinstance(x.message, SearchSpaceMatchWarning)).message
+        assert_frame_equal(right_df.loc[[0]], w.data)
+
+        # Ignore problematic indices for subsequent equality check
+        selected = selected[1:]
+        matched = matched[1:]
+
+    assert set(selected) == set(matched), (selected, matched)
+
+
+@pytest.mark.parametrize(
+    "parameter_names",
+    [
+        param(["Categorical_1", "Categorical_2", "Switch_1"], id="discrete"),
+        param(["Categorical_1", "Num_disc_1", "Conti_finite1"], id="hybrid"),
+    ],
+)
+@pytest.mark.parametrize("invalid", ["left", "right"])
+def test_invalid_fuzzy_row_match(searchspace, invalid):
+    """Returns expected errors when dataframes don't contain all expected columns."""
+    left_df = searchspace.discrete.exp_rep.copy()
+    selected = np.random.choice(left_df.index, 4, replace=False)
+    right_df = left_df.loc[selected].copy()
+
+    # Drop first column
+    if invalid == "left":
+        left_df = left_df.iloc[:, 1:]
+    else:
+        right_df = right_df.iloc[:, 1:]
+
+    match = f"corresponding column in the {invalid} dataframe."
+    with pytest.raises(ValueError, match=match):
+        fuzzy_row_match(left_df, right_df, searchspace.parameters)
```
