Skip to content

Commit 821d7b1

Browse files
authored
Fix wrong desirabilities in predictive strategies (#506)
* fix bug * fix tests * now it runs ;) * rename it
1 parent 9a5f76c commit 821d7b1

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

bofire/data_models/domain/features.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -783,24 +783,39 @@ def get_keys_by_objective(
783783
def __call__(
784784
self,
785785
experiments: pd.DataFrame,
786+
experiments_adapt: Optional[pd.DataFrame] = None,
786787
predictions: bool = False,
787788
) -> pd.DataFrame:
788789
"""Evaluate the objective for every feature.
789790
790791
Args:
791792
experiments (pd.DataFrame): Experiments for which the objectives should be evaluated.
793+
experiments_adapt (pd.DataFrame, optional): Experimental values which are used to update the objective
794+
parameters on the fly. This is for example needed when a `MovingMaximizeSigmoidObjective` is used
795+
as this depends on the best experimental value achieved so far. For this reason `experiments_adapt`
796+
has to be provided if `predictions=True` ie. that the objectives of candidates are evaluated.
797+
Defaults to None.
792798
predictions (bool, optional): If True use the prediction columns in the dataframe to calc the
793-
desirabilities `f"{feat.key}_pred`.
799+
desirabilities `f"{feat.key}_pred`, furthermore `experiments_adapt` has to be provided.
794800
795801
Returns:
796802
pd.DataFrame: Objective values for the experiments of interest.
797803
798804
"""
805+
if predictions and experiments_adapt is None:
806+
raise ValueError(
807+
"If predictions are used, `experiments_adapt` has to be provided.",
808+
)
809+
else:
810+
experiments_adapt = (
811+
experiments if experiments_adapt is None else experiments_adapt
812+
)
813+
799814
desis = pd.concat(
800815
[
801816
feat(
802817
experiments[f"{feat.key}_pred" if predictions else feat.key],
803-
experiments[f"{feat.key}_pred" if predictions else feat.key],
818+
experiments_adapt[feat.key].dropna(),
804819
)
805820
for feat in self.features
806821
if feat.objective is not None

bofire/strategies/predictives/predictive.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,10 @@ def predict(self, experiments: pd.DataFrame) -> pd.DataFrame:
156156
predictions=predictions,
157157
outputs=self.domain.outputs,
158158
)
159-
desis = self.domain.outputs(predictions, predictions=True)
160-
predictions = pd.concat((predictions, desis), axis=1)
159+
objectives = self.domain.outputs(
160+
predictions, experiments_adapt=self.experiments, predictions=True
161+
)
162+
predictions = pd.concat((predictions, objectives), axis=1)
161163
predictions.index = experiments.index
162164
return predictions
163165

tests/bofire/data_models/domain/test_outputs.py

+27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ConstrainedObjective,
1313
MaximizeObjective,
1414
MaximizeSigmoidObjective,
15+
MovingMaximizeSigmoidObjective,
1516
Objective,
1617
TargetObjective,
1718
)
@@ -300,6 +301,32 @@ def test_outputs_call(features, samples):
300301
]
301302

302303

304+
def test_outputs_call_adapt_experiment():
305+
outputs = Outputs(
306+
features=[
307+
ContinuousOutput(key="of1", objective=MaximizeObjective()),
308+
ContinuousOutput(
309+
key="of2",
310+
objective=MovingMaximizeSigmoidObjective(tp=0, steepness=10, w=1.0),
311+
),
312+
],
313+
)
314+
candidates = pd.DataFrame(
315+
columns=["of1_pred", "of2_pred"], data=[[1.0, 5.0], [2.0, 5.0]]
316+
)
317+
318+
experiments = pd.DataFrame(columns=["of1", "of2"], data=[[1.0, 5.0], [2.0, 6.0]])
319+
320+
with pytest.raises(
321+
ValueError,
322+
match="If predictions are used, `experiments_adapt` has to be provided.",
323+
):
324+
outputs(candidates, predictions=True)
325+
326+
outputs(experiments)
327+
outputs(candidates, experiments_adapt=experiments, predictions=True)
328+
329+
303330
def test_categorical_objective_methods():
304331
obj = ConstrainedCategoricalObjective(
305332
categories=["a", "b"],

0 commit comments

Comments
 (0)