Skip to content

Commit fdc1e81

Browse files
committed
Adjust acqf objective logic
1 parent bcfc0b0 commit fdc1e81

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

baybe/acquisition/base.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,29 @@ def to_botorch(
115115
# Add acquisition objective / best observed value
116116
match objective:
117117
case SingleTargetObjective(NumericalTarget(mode=TargetMode.MIN)):
118+
# Adjust best_f
118119
if "best_f" in signature_params:
119120
additional_params["best_f"] = (
120121
bo_surrogate.posterior(train_x).mean.min().item()
121122
)
122123
if issubclass(acqf_cls, bo_acqf.MCAcquisitionFunction):
123124
additional_params["best_f"] *= -1.0
124125

126+
# Adjust objective
125127
if issubclass(
126-
acqf_cls, bo_acqf.AnalyticAcquisitionFunction
127-
) and not issubclass(acqf_cls, bo_acqf.PosteriorStandardDeviation):
128-
# Minimize acqfs in case the target should be minimized. PSTD is
129-
# exempt as the direction does not depend on the target type.
130-
additional_params["maximize"] = False
131-
elif issubclass(acqf_cls, bo_acqf.qNegIntegratedPosteriorVariance):
132-
# qNIPV is valid but does not require any adjusted params
128+
acqf_cls,
129+
(
130+
bo_acqf.qNegIntegratedPosteriorVariance,
131+
bo_acqf.PosteriorStandardDeviation,
132+
bo_acqf.qPosteriorStandardDeviation,
133+
),
134+
):
135+
# The active learning acqfs are valid but no changes based on the
136+
# target direction are required.
133137
pass
138+
elif issubclass(acqf_cls, bo_acqf.AnalyticAcquisitionFunction):
139+
# Minimize acqfs in case the target should be minimized.
140+
additional_params["maximize"] = False
134141
elif issubclass(acqf_cls, bo_acqf.MCAcquisitionFunction):
135142
additional_params["objective"] = LinearMCObjective(
136143
torch.tensor([-1.0])

0 commit comments

Comments
 (0)