Skip to content

Commit 27d5903

Browse files
authored
Refactor adapter model (#260)
This PR moves the `AdapterModel` (i.e., the connecting layer between baybe surrogates and botorch models) to the surrogate package, where it actually belongs. Additionally, the surrogates are equipped with a `to_botorch` method that simplifies the model translation and can be customized per subclass. In particular, GP surrogates are no longer wrapped using the adapter but now expose their internal botorch model instance directly.
2 parents 914363f + dcaad6e commit 27d5903

File tree

6 files changed

+21
-6
lines changed

6 files changed

+21
-6
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased]
8+
### Added
9+
- `Surrogate` base class now exposes a `to_botorch` method
10+
811
### Changed
912
- Passing an `Objective` to `Campaign` is now optional
13+
- `GaussianProcessSurrogate` models are no longer wrapped when cast to BoTorch
1014

1115
### Removed
1216
- Support for Python 3.9 removed due to new [BoTorch requirements](https://github.com/pytorch/botorch/pull/2293)

baybe/acquisition/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,13 @@ def to_botorch(
4343
"""Create the botorch-ready representation of the function."""
4444
import botorch.acquisition as botorch_analytical_acqf
4545

46-
from baybe.acquisition._adapter import AdapterModel
47-
4846
acqf_cls = getattr(botorch_analytical_acqf, self.__class__.__name__)
4947
params_dict = filter_attributes(object=self, callable_=acqf_cls.__init__)
5048

5149
additional_params = {
5250
p: v
5351
for p, v in {
54-
"model": AdapterModel(surrogate),
52+
"model": surrogate.to_botorch(),
5553
"best_f": train_y.max().item(),
5654
"X_baseline": to_tensor(train_x),
5755
}.items()

baybe/acquisition/_adapter.py baybe/surrogates/_adapter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Adapter for making BoTorch's acquisition functions work with BayBE models."""
1+
"""Adapter functionality for making BayBE surrogates BoTorch-ready."""
22

33
from collections.abc import Callable
44
from typing import Any
@@ -19,7 +19,7 @@ class AdapterModel(Model):
1919
surrogate model usable in conjunction with BoTorch acquisition functions.
2020
2121
Args:
22-
surrogate: The internal surrogate model
22+
surrogate: The internal surrogate model.
2323
"""
2424

2525
def __init__(self, surrogate: Surrogate):

baybe/surrogates/base.py

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from baybe.surrogates.utils import _prepare_inputs, _prepare_targets
2626

2727
if TYPE_CHECKING:
28+
from botorch.models.model import Model
2829
from torch import Tensor
2930

3031
# Define constants
@@ -55,6 +56,12 @@ class Surrogate(ABC, SerialMixin):
5556
"""Class variable encoding whether or not the surrogate supports transfer
5657
learning."""
5758

59+
def to_botorch(self) -> Model:
60+
"""Create the botorch-ready representation of the model."""
61+
from baybe.surrogates._adapter import AdapterModel
62+
63+
return AdapterModel(self)
64+
5865
def posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
5966
"""Evaluate the surrogate model at the given candidate points.
6067

baybe/surrogates/gaussian_process/core.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323

2424
if TYPE_CHECKING:
25+
from botorch.models.model import Model
2526
from torch import Tensor
2627

2728

@@ -59,6 +60,11 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
5960
"""Create a Gaussian process surrogate from one of the defined presets."""
6061
return make_gp_from_preset(preset)
6162

63+
def to_botorch(self) -> Model: # noqa: D102
64+
# See base class.
65+
66+
return self._model
67+
6268
def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
6369
# See base class.
6470
posterior = self._model.posterior(candidates)

tests/test_imports.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_imports(module: str):
5858

5959
WHITELISTS = {
6060
"torch": [
61-
"baybe.acquisition._adapter",
6261
"baybe.acquisition.partial",
62+
"baybe.surrogates._adapter",
6363
"baybe.utils.botorch_wrapper",
6464
"baybe.utils.torch",
6565
],

0 commit comments

Comments
 (0)