10
10
import numpy as np
11
11
import pandas as pd
12
12
13
+ from baybe .exceptions import SearchSpaceMatchWarning
13
14
from baybe .targets .base import Target
14
15
from baybe .targets .binary import BinaryTarget
15
16
from baybe .targets .enum import TargetMode
@@ -462,7 +463,7 @@ def fuzzy_row_match(
462
463
right_df : pd .DataFrame ,
463
464
parameters : Sequence [Parameter ],
464
465
) -> pd .Index :
465
- """Match row of the right dataframe to the rows of the left dataframe.
466
+ """Match rows of the right dataframe to rows of the left dataframe.
466
467
467
468
This is useful for matching measurements to entries in the search space, e.g. to
468
469
detect which ones have been measured. For categorical parameters, there needs to be
@@ -476,57 +477,89 @@ def fuzzy_row_match(
476
477
477
478
Args:
478
479
left_df: The data that serves as lookup reference.
479
- right_df: The data that should be checked for matching rows in the left
480
- dataframe.
481
- parameters: List of baybe parameter objects that are needed to identify
482
- potential tolerances.
480
+ right_df: The data that is checked for matching rows in the left dataframe.
481
+ parameters: Parameter objects that identify the relevant column names and how
482
+ matching is performed.
483
483
484
484
Returns:
485
485
The index of the matching rows in ``left_df``.
486
486
487
487
Raises:
488
- ValueError: If some rows are present in the right but not in the left dataframe.
488
+ ValueError: If either ``left_df`` or ``right_df`` does not contain columns for
489
+ each discrete parameter in ``parameters``.
489
490
"""
490
- # Assert that all parameters appear in the given dataframe
491
- if not all (col in right_df .columns for col in left_df .columns ):
491
+ # Separate column types
492
+ cat_cols = {p .name for p in parameters if (not p .is_numerical and p .is_discrete )}
493
+ num_cols = {p .name for p in parameters if (p .is_numerical and p .is_discrete )}
494
+ non_discrete_cols = {p .name for p in parameters if not p .is_discrete }
495
+
496
+ # Assert that all parameters appear in the given dataframes
497
+ if diff := (cat_cols | num_cols ).difference (left_df .columns ):
498
+ raise ValueError (
499
+ f"For fuzzy row matching, all discrete parameters need to have a "
500
+ f"corresponding column in the left dataframe. Parameters not found: { diff } )"
501
+ )
502
+ if diff := (cat_cols | num_cols ).difference (right_df .columns ):
492
503
raise ValueError (
493
- "For fuzzy row matching all rows of the right dataframe need to be present"
494
- " in the left dataframe."
504
+ f"For fuzzy row matching, all discrete parameters need to have a "
505
+ f"corresponding column in the right dataframe. Parameters not found: "
506
+ f"{ diff } )"
495
507
)
496
508
497
- # Iterate over all input rows
498
- inds_matched = []
499
- for ind , row in right_df .iterrows ():
500
- # Differentiate category-like and discrete numerical parameters
501
- cat_cols = [p .name for p in parameters if not p .is_numerical ]
502
- num_cols = [p .name for p in parameters if (p .is_numerical and p .is_discrete )]
503
-
504
- # Discrete parameters must match exactly
505
- match = left_df [cat_cols ].eq (row [cat_cols ]).all (axis = 1 , skipna = False )
506
-
507
- # For numeric parameters, match the entry with the smallest deviation
508
- for col in num_cols :
509
- abs_diff = (left_df [col ] - row [col ]).abs ()
510
- match &= abs_diff == abs_diff .min ()
511
-
512
- # We expect exactly one match. If that's not the case, print a warning.
513
- inds_found = left_df .index [match ].to_list ()
514
- if len (inds_found ) == 0 and len (num_cols ) > 0 :
515
- warnings .warn (
516
- f"Input row with index { ind } could not be matched to the search space. "
517
- f"This could indicate that something went wrong."
518
- )
519
- elif len (inds_found ) > 1 :
520
- warnings .warn (
521
- f"Input row with index { ind } has multiple matches with the search "
522
- f"space. This could indicate that something went wrong. Matching only "
523
- f"first occurrence."
524
- )
525
- inds_matched .append (inds_found [0 ])
526
- else :
527
- inds_matched .extend (inds_found )
509
+ provided_cols = {p .name for p in parameters }
510
+ allowed_cols = cat_cols | num_cols | non_discrete_cols
511
+ assert allowed_cols == provided_cols , (
512
+ f"There are parameter types that would be silently ignored: "
513
+ f"{ provided_cols .difference (allowed_cols )} "
514
+ )
515
+
516
+ # Initialize the match matrix. We will later filter it down using other
517
+ # matrices (representing the matches for individual parameters) via logical 'and'.
518
+ match_matrix = pd .DataFrame (
519
+ True , index = right_df .index , columns = left_df .index , dtype = bool
520
+ )
521
+
522
+ # Match categorical parameters
523
+ for col in cat_cols :
524
+ # Per categorical parameter, this identifies matches between all elements of
525
+ # left and right and stores them in a matrix.
526
+ match_matrix &= right_df [col ].values [:, None ] == left_df [col ].values [None , :]
527
+
528
+ # Match numerical parameters
529
+ for col in num_cols :
530
+ # Per numerical parameter, this identifies the rows with the smallest absolute
531
+ # difference and records them in a matrix.
532
+ abs_diff = np .abs (right_df [col ].values [:, None ] - left_df [col ].values [None , :])
533
+ min_diff = abs_diff .min (axis = 1 , keepdims = True )
534
+ match_matrix &= abs_diff == min_diff
535
+
536
+ # Find the matching indices. If a right row is not matched to any of the rows in
537
+ # left, idxmax would return the first index of left_df. Hence, we remember these
538
+ # cases and drop them explicitly.
539
+ matched_indices = pd .Index (match_matrix .idxmax (axis = 1 ).values )
540
+ mask_no_match = ~ match_matrix .any (axis = 1 )
541
+ matched_indices = matched_indices [~ mask_no_match ]
542
+
543
+ # Warn if there are multiple or no matches
544
+ if no_match_indices := right_df .index [mask_no_match ].tolist ():
545
+ w = SearchSpaceMatchWarning (
546
+ f"Some input rows could not be matched to the search space. Indices with "
547
+ f"no matches: { no_match_indices } " ,
548
+ right_df .loc [no_match_indices ],
549
+ )
550
+ warnings .warn (w )
551
+
552
+ mask_multiple_matches = match_matrix .sum (axis = 1 ) > 1
553
+ if multiple_match_indices := right_df .index [mask_multiple_matches ].tolist ():
554
+ w = SearchSpaceMatchWarning (
555
+ f"Some input rows have multiple matches with the search space. "
556
+ f"Matching only first occurrence for these rows. Indices with multiple "
557
+ f"matches: { multiple_match_indices } " ,
558
+ right_df .loc [multiple_match_indices ],
559
+ )
560
+ warnings .warn (w )
528
561
529
- return pd . Index ( inds_matched )
562
+ return matched_indices
530
563
531
564
532
565
def pretty_print_df (
0 commit comments