Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MIN mode via acquisition function #340

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement MIN mode of NumericalTarget via acquisition function
This avoids inverting the computational representation, which gives
inverted predictions back to the user calling the surrogate.
AdrianSosic committed Aug 30, 2024
commit 72a2175283a0f859d3ba4e057fbab476b0f7fcd0
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
in the recommender but in the surrogate
- Fallback models created by `catch_constant_targets` are stored outside the surrogate
- `to_tensor` now also handles `numpy` arrays
- `MIN` mode of `NumericalTarget` is now implemented via the acquisition function
instead of negating the computational representation

### Fixed
- `CategoricalParameter` and `TaskParameter` no longer incorrectly coerce a single
29 changes: 26 additions & 3 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,8 @@
from attrs import define

from baybe.objectives.base import Objective
from baybe.objectives.desirability import DesirabilityObjective
from baybe.objectives.single import SingleTargetObjective
from baybe.searchspace.core import SearchSpace
from baybe.serialization.core import (
converter,
@@ -19,6 +21,8 @@
)
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.base import SurrogateProtocol
from baybe.targets.enum import TargetMode
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import classproperty, match_attributes
from baybe.utils.boolean import is_abstract
from baybe.utils.dataframe import to_tensor
@@ -53,14 +57,16 @@ def to_botorch(
The required structure of `measurements` is specified in
:meth:`baybe.recommenders.base.RecommenderProtocol.recommend`.
"""
import botorch.acquisition as botorch_acqf_module
import botorch.acquisition as bacqf
import torch
from botorch.acquisition.objective import LinearMCObjective

# Get computational data representations
train_x = searchspace.transform(measurements, allow_extra=True)
train_y = objective.transform(measurements)

# Retrieve corresponding botorch class
acqf_cls = getattr(botorch_acqf_module, self.__class__.__name__)
acqf_cls = getattr(bacqf, self.__class__.__name__)

# Match relevant attributes
params_dict = match_attributes(
@@ -81,8 +87,25 @@ def to_botorch(
self.get_integration_points(searchspace) # type: ignore[attr-defined]
)

params_dict.update(additional_params)
# Add acquisition objective
match objective:
case SingleTargetObjective(NumericalTarget(mode=TargetMode.MIN)):
if issubclass(acqf_cls, bacqf.AnalyticAcquisitionFunction):
additional_params["maximize"] = False
elif issubclass(acqf_cls, bacqf.MCAcquisitionFunction):
additional_params["objective"] = LinearMCObjective(
torch.tensor([-1.0])
)
else:
raise ValueError(
f"Unsupported acquisition function type: {acqf_cls}."
)
case SingleTargetObjective() | DesirabilityObjective():
pass
case _:
raise ValueError(f"Unsupported objective type: {objective}")

params_dict.update(additional_params)
return acqf_cls(**params_dict)


7 changes: 0 additions & 7 deletions baybe/targets/numerical.py
Original file line number Diff line number Diff line change
@@ -150,13 +150,6 @@ def transform(self, data: pd.DataFrame) -> pd.DataFrame: # noqa: D102
transformed = pd.DataFrame(
func(data, *self.bounds.to_tuple()), index=data.index
)

# Otherwise, simply negate all target values for ``MIN`` mode.
# For ``MAX`` mode, nothing needs to be done.
# For ``MATCH`` mode, the validators avoid a situation without specified bounds.
elif self.mode is TargetMode.MIN:
transformed = -data

else:
transformed = data.copy()