Skip to content

Commit ea011ea

Browse files
authoredFeb 10, 2025··
Merge: Refactor ACQF indicators (#479)
Fixes #467 and #458 BayBE acquisition functions have been refactored - Removed `is_mc` - Added `supports_batching` and `supports_pending_experiments` indicators - Changed tests using these properties + some changes to have more readable output - Fixed bug with `PSTD` where the direction was negated depending on target type. However, for `PSTD` the direction is set by the user and means to change whether points with large or small std are returned - ie it is fully independent of the target type. So it needs to be exempt from the overwriting of the `maximize` setting that is done for other analytical acqfs - Enabled the new `qPSTD` function which requires a higher botorch version, conveniently also including the upgrade to support higher scipy versions, version pins have been adjusted accordingly.
2 parents ae64ca1 + fdc1e81 commit ea011ea

File tree

12 files changed

+169
-68
lines changed

12 files changed

+169
-68
lines changed
 

‎.lockfiles/py310-dev.lock

+21-11
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ botocore==1.35.68
5353
# via
5454
# boto3
5555
# s3transfer
56-
botorch==0.11.3
56+
botorch==0.13.0
5757
# via baybe (pyproject.toml)
5858
cachecontrol==0.14.0
5959
# via pip-audit
@@ -196,7 +196,7 @@ googleapis-common-protos==1.63.2
196196
# via
197197
# opentelemetry-exporter-otlp-proto-grpc
198198
# opentelemetry-exporter-otlp-proto-http
199-
gpytorch==1.12
199+
gpytorch==1.14
200200
# via
201201
# baybe (pyproject.toml)
202202
# botorch
@@ -259,7 +259,9 @@ ipywidgets==8.1.3
259259
isoduration==20.11.0
260260
# via jsonschema
261261
jaxtyping==0.2.33
262-
# via linear-operator
262+
# via
263+
# gpytorch
264+
# linear-operator
263265
jedi==0.19.1
264266
# via ipython
265267
jinja2==3.1.4
@@ -354,7 +356,7 @@ lifelines==0.29.0
354356
# via ngboost
355357
lime==0.2.0.1
356358
# via shap
357-
linear-operator==0.5.2
359+
linear-operator==0.6
358360
# via
359361
# botorch
360362
# gpytorch
@@ -401,8 +403,8 @@ mordredcommunity==2.0.6
401403
# via scikit-fingerprints
402404
mpmath==1.3.0
403405
# via
404-
# botorch
405406
# gpytorch
407+
# linear-operator
406408
# sympy
407409
msgpack==1.0.8
408410
# via cachecontrol
@@ -411,7 +413,9 @@ multipledispatch==1.0.0
411413
mypy==1.11.0
412414
# via baybe (pyproject.toml)
413415
mypy-extensions==1.0.0
414-
# via mypy
416+
# via
417+
# mypy
418+
# typing-inspect
415419
myst-parser==4.0.0
416420
# via baybe (pyproject.toml)
417421
nbclient==0.10.0
@@ -452,7 +456,6 @@ numpy==1.26.4
452456
# baybe (pyproject.toml)
453457
# altair
454458
# autograd
455-
# botorch
456459
# contourpy
457460
# descriptastorus
458461
# e3fp
@@ -716,6 +719,8 @@ pyparsing==3.1.2
716719
# pip-requirements-parser
717720
pyproject-api==1.7.1
718721
# via tox
722+
pyre-extensions==0.0.32
723+
# via botorch
719724
pyreadline3==3.4.1 ; sys_platform == 'win32'
720725
# via humanfriendly
721726
pyro-api==0.1.2
@@ -934,7 +939,9 @@ terminado==0.18.1
934939
# jupyter-server
935940
# jupyter-server-terminals
936941
threadpoolctl==3.5.0
937-
# via scikit-learn
942+
# via
943+
# botorch
944+
# scikit-learn
938945
tifffile==2024.12.12
939946
# via scikit-image
940947
tinycss2==1.3.0
@@ -1010,9 +1017,7 @@ traitlets==5.14.3
10101017
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'
10111018
# via torch
10121019
typeguard==2.13.3
1013-
# via
1014-
# jaxtyping
1015-
# linear-operator
1020+
# via jaxtyping
10161021
types-python-dateutil==2.9.0.20240316
10171022
# via arrow
10181023
types-pytz==2024.1.0.20240417
@@ -1025,15 +1030,20 @@ typing-extensions==4.12.2
10251030
# altair
10261031
# anyio
10271032
# async-lru
1033+
# botorch
10281034
# cattrs
10291035
# formulaic
10301036
# funcy-stubs
10311037
# huggingface-hub
10321038
# ipython
10331039
# mypy
10341040
# opentelemetry-sdk
1041+
# pyre-extensions
10351042
# streamlit
10361043
# torch
1044+
# typing-inspect
1045+
typing-inspect==0.9.0
1046+
# via pyre-extensions
10371047
tzdata==2024.1
10381048
# via pandas
10391049
uri-template==1.3.0

‎CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
### Added
99
- `BCUT2D` encoding for `SubstanceParameter`
1010
- Stored benchmarking results now include the Python environment and version
11+
- `qPSTD` acquisition function
12+
13+
### Changed
14+
- Acquisition function indicator `is_mc` has been removed in favor of new indicators
15+
`supports_batching` and `supports_pending_experiments`
16+
17+
### Fixed
18+
- Incorrect optimization direction with `PSTD` with a single minimization target
1119

1220
## [0.12.2] - 2025-01-31
1321
### Changed

‎baybe/acquisition/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
qLogNoisyExpectedImprovement,
1414
qNegIntegratedPosteriorVariance,
1515
qNoisyExpectedImprovement,
16+
qPosteriorStandardDeviation,
1617
qProbabilityOfImprovement,
1718
qSimpleRegret,
1819
qThompsonSampling,
@@ -21,6 +22,7 @@
2122

2223
PM = PosteriorMean
2324
PSTD = PosteriorStandardDeviation
25+
qPSTD = qPosteriorStandardDeviation
2426
qSR = qSimpleRegret
2527
EI = ExpectedImprovement
2628
qEI = qExpectedImprovement
@@ -43,6 +45,7 @@
4345
# Posterior Statistics
4446
"PosteriorMean",
4547
"PosteriorStandardDeviation",
48+
"qPosteriorStandardDeviation",
4649
# Simple Regret
4750
"qSimpleRegret",
4851
# Expected Improvement
@@ -67,6 +70,7 @@
6770
# Posterior Statistics
6871
"PM",
6972
"PSTD",
73+
"qPSTD",
7074
# Simple Regret
7175
"qSR",
7276
# Expected Improvement

‎baybe/acquisition/acqfs.py

+12
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,13 @@ class PosteriorStandardDeviation(AcquisitionFunction):
174174
with minimal posterior standard deviation."""
175175

176176

177+
@define(frozen=True)
178+
class qPosteriorStandardDeviation(AcquisitionFunction):
179+
"""Monte Carlo based posterior standard deviation."""
180+
181+
abbreviation: ClassVar[str] = "qPSTD"
182+
183+
177184
########################################################################################
178185
### Simple Regret
179186
@define(frozen=True)
@@ -307,6 +314,11 @@ def _non_botorch_attrs(cls) -> tuple[str, ...]:
307314
flds = fields(qThompsonSampling)
308315
return (flds.n_mc_samples.name,)
309316

317+
@override
318+
@classproperty
319+
def supports_batching(cls) -> bool:
320+
return False
321+
310322

311323
# Collect leftover original slotted classes processed by `attrs.define`
312324
gc.collect()

‎baybe/acquisition/base.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,18 @@ class AcquisitionFunction(ABC, SerialMixin):
4444
"""An alternative name for type resolution."""
4545

4646
@classproperty
47-
def is_mc(cls) -> bool:
48-
"""Flag indicating whether this is a Monte-Carlo acquisition function."""
47+
def supports_batching(cls) -> bool:
48+
"""Flag indicating whether batch recommendation is supported."""
4949
return cls.abbreviation.startswith("q")
5050

51+
@classproperty
52+
def supports_pending_experiments(cls) -> bool:
53+
"""Flag indicating whether pending experiments are supported.
54+
55+
This is based on the same mechanism underlying batched recommendations.
56+
"""
57+
return cls.supports_batching
58+
5159
@classproperty
5260
def _non_botorch_attrs(cls) -> tuple[str, ...]:
5361
"""Names of attributes that are not passed to the BoTorch constructor."""
@@ -95,7 +103,7 @@ def to_botorch(
95103
self.get_integration_points(searchspace) # type: ignore[attr-defined]
96104
)
97105
if pending_experiments is not None:
98-
if self.is_mc:
106+
if self.supports_pending_experiments:
99107
pending_x = searchspace.transform(pending_experiments, allow_extra=True)
100108
additional_params["X_pending"] = to_tensor(pending_x)
101109
else:
@@ -107,18 +115,29 @@ def to_botorch(
107115
# Add acquisition objective / best observed value
108116
match objective:
109117
case SingleTargetObjective(NumericalTarget(mode=TargetMode.MIN)):
118+
# Adjust best_f
110119
if "best_f" in signature_params:
111120
additional_params["best_f"] = (
112121
bo_surrogate.posterior(train_x).mean.min().item()
113122
)
114123
if issubclass(acqf_cls, bo_acqf.MCAcquisitionFunction):
115124
additional_params["best_f"] *= -1.0
116125

117-
if issubclass(acqf_cls, bo_acqf.AnalyticAcquisitionFunction):
118-
additional_params["maximize"] = False
119-
elif issubclass(acqf_cls, bo_acqf.qNegIntegratedPosteriorVariance):
120-
# qNIPV is valid but does not require any adjusted params
126+
# Adjust objective
127+
if issubclass(
128+
acqf_cls,
129+
(
130+
bo_acqf.qNegIntegratedPosteriorVariance,
131+
bo_acqf.PosteriorStandardDeviation,
132+
bo_acqf.qPosteriorStandardDeviation,
133+
),
134+
):
135+
# The active learning acqfs are valid but no changes based on the
136+
# target direction are required.
121137
pass
138+
elif issubclass(acqf_cls, bo_acqf.AnalyticAcquisitionFunction):
139+
# Minimize acqfs in case the target should be minimized.
140+
additional_params["maximize"] = False
122141
elif issubclass(acqf_cls, bo_acqf.MCAcquisitionFunction):
123142
additional_params["objective"] = LinearMCObjective(
124143
torch.tensor([-1.0])

‎baybe/recommenders/pure/bayesian/botorch.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _recommend_discrete(
114114
The dataframe indices of the recommended points in the provided
115115
experimental representation.
116116
"""
117-
if batch_size > 1 and not self.acquisition_function.is_mc:
117+
if batch_size > 1 and not self.acquisition_function.supports_batching:
118118
raise IncompatibleAcquisitionFunctionError(
119119
f"The '{self.__class__.__name__}' only works with Monte Carlo "
120120
f"acquisition functions for batch sizes > 1."
@@ -168,8 +168,7 @@ def _recommend_continuous(
168168
Returns:
169169
A dataframe containing the recommendations as individual rows.
170170
"""
171-
# For batch size > 1, this optimizer needs a MC acquisition function
172-
if batch_size > 1 and not self.acquisition_function.is_mc:
171+
if batch_size > 1 and not self.acquisition_function.supports_batching:
173172
raise IncompatibleAcquisitionFunctionError(
174173
f"The '{self.__class__.__name__}' only works with Monte Carlo "
175174
f"acquisition functions for batch sizes > 1."
@@ -234,8 +233,8 @@ def _recommend_hybrid(
234233
Returns:
235234
The recommended points.
236235
"""
237-
# For batch size > 1, this optimizer needs a MC acquisition function
238-
if batch_size > 1 and not self.acquisition_function.is_mc:
236+
# For batch size > 1, the acqf needs to support batching
237+
if batch_size > 1 and not self.acquisition_function.supports_batching:
239238
raise IncompatibleAcquisitionFunctionError(
240239
f"The '{self.__class__.__name__}' only works with Monte Carlo "
241240
f"acquisition functions for batch sizes > 1."

‎docs/conf.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@
126126
(r"py:class", "baybe.utils.basic._U"),
127127
(r"ref:obj", "baybe.surrogates.base.ModelContext"),
128128
# Ignore custom class properties
129-
(r"py:obj", "baybe.acquisition.acqfs.*.is_mc"),
129+
(r"py:obj", "baybe.acquisition.acqfs.*.supports_batching"),
130+
(r"py:obj", "baybe.acquisition.acqfs.*.supports_pending_experiments"),
130131
# Other
131132
(r"py:obj", "baybe.utils.basic.UnspecifiedType.UNSPECIFIED"),
132133
]
@@ -301,8 +302,6 @@ def autodoc_process_docstring(app, what, name, obj, options, lines):
301302

302303

303304
def autodoc_skip_member(app, what, name, obj, skip, options):
304-
"""Skip the docstring for the is_mc classproperty."""
305-
# Note that we cannot do `return name == "is_mc"` since this messes up other members
306-
# that need to be skipped.
307-
if name == "is_mc":
305+
"""Skip the docstring for the acqf classproperties."""
306+
if name in ["supports_batching", "supports_pending_experiments"]:
308307
return True

‎docs/userguide/active_learning.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ including a few guidelines.
1919
## Local Uncertainty Reduction
2020
In BayBE, there are two types of acquisition function that can be chosen to search for
2121
the points with the highest predicted model uncertainty:
22-
- [`PosteriorStandardDeviation`](baybe.acquisition.acqfs.PosteriorStandardDeviation) (`PSTD`)
22+
- [`PosteriorStandardDeviation`](baybe.acquisition.acqfs.PosteriorStandardDeviation) (`PSTD`)
23+
/ [`qPosteriorStandardDeviation`](baybe.acquisition.acqfs.qPosteriorStandardDeviation) (`qPSTD`)
2324
- [`UpperConfidenceBound`](baybe.acquisition.acqfs.UpperConfidenceBound) (`UCB`) /
2425
[`qUpperConfidenceBound`](baybe.acquisition.acqfs.qUpperConfidenceBound) (`qUCB`)
2526
with high `beta`:
2627
Increasing values of `beta` effectively eliminate the effect of the posterior mean on
2728
the acquisition value, yielding a selection of points driven primarily by the
2829
posterior variance. However, we generally recommend to use this acquisition function
2930
only if a small exploratory component is desired – otherwise, the
30-
[`PosteriorStandardDeviation`](baybe.acquisition.acqfs.PosteriorStandardDeviation)
31+
[`qPosteriorStandardDeviation`](baybe.acquisition.acqfs.qPosteriorStandardDeviation)
3132
acquisition function is what you are looking for.
3233

3334
## Global Uncertainty Reduction

‎pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ keywords = [
2929
dynamic = ['version']
3030
dependencies = [
3131
"attrs>=24.1.0",
32-
"botorch>=0.9.3,<1",
32+
"botorch>=0.13.0,<1",
3333
"cattrs>=24.1.0",
3434
"exceptiongroup",
3535
"funcy>=1.17,<2",
@@ -40,7 +40,7 @@ dependencies = [
4040
"pandas>=1.4.2,<3",
4141
"scikit-learn>=1.1.1,<2",
4242
"scikit-learn-extra>=0.3.0,<1",
43-
"scipy>=1.10.1,<1.15", # See https://github.com/pytorch/botorch/commit/37f04d11193704f4ece222b029df103edb3e6642
43+
"scipy>=1.10.1",
4444
"setuptools-scm>=7.1.0",
4545
"torch>=1.13.1,<3",
4646
"typing_extensions>=4.7.0",

‎tests/hypothesis_strategies/acquisition.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
qLogNoisyExpectedImprovement,
1616
qNegIntegratedPosteriorVariance,
1717
qNoisyExpectedImprovement,
18+
qPosteriorStandardDeviation,
1819
qProbabilityOfImprovement,
1920
qSimpleRegret,
2021
qUpperConfidenceBound,
@@ -51,6 +52,7 @@ def _qNIPV_strategy(draw: st.DrawFn):
5152
st.builds(UpperConfidenceBound, beta=finite_floats(min_value=0.0)),
5253
st.builds(PosteriorMean),
5354
st.builds(PosteriorStandardDeviation, maximize=st.sampled_from([True, False])),
55+
st.builds(qPosteriorStandardDeviation),
5456
st.builds(LogExpectedImprovement),
5557
st.builds(qExpectedImprovement),
5658
st.builds(qProbabilityOfImprovement),

‎tests/test_iterations.py

+76-33
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# TODO: This file needs to be refactored.
22
"""Tests various configurations for a small number of iterations."""
33

4+
from contextlib import nullcontext
5+
46
import pytest
7+
from botorch.exceptions import UnsupportedError
58
from pytest import param
69

710
from baybe.acquisition import qKG, qNIPV, qTS, qUCB
811
from baybe.acquisition.base import AcquisitionFunction
9-
from baybe.exceptions import UnusedObjectWarning
12+
from baybe.exceptions import InvalidSurrogateModelError, UnusedObjectWarning
1013
from baybe.kernels.base import Kernel
1114
from baybe.kernels.basic import (
1215
LinearKernel,
@@ -77,17 +80,19 @@
7780
in [SearchSpaceType.CONTINUOUS, SearchSpaceType.HYBRID, SearchSpaceType.EITHER]
7881
]
7982

80-
valid_active_learning_acqfs = [
83+
acqfs_extra = [ # Additionally tested acqfs with extra configurations
8184
qNIPV(sampling_fraction=0.2, sampling_method="Random"),
8285
qNIPV(sampling_fraction=0.2, sampling_method="FPS"),
8386
qNIPV(sampling_fraction=1.0, sampling_method="FPS"),
8487
qNIPV(sampling_n_points=1, sampling_method="Random"),
8588
qNIPV(sampling_n_points=1, sampling_method="FPS"),
8689
]
87-
valid_mc_acqfs = [
88-
a() for a in get_subclasses(AcquisitionFunction) if a.is_mc
89-
] + valid_active_learning_acqfs
90-
valid_nonmc_acqfs = [a() for a in get_subclasses(AcquisitionFunction) if not a.is_mc]
90+
acqfs_batching = [
91+
a() for a in get_subclasses(AcquisitionFunction) if a.supports_batching
92+
] + acqfs_extra
93+
acqfs_non_batching = [
94+
a() for a in get_subclasses(AcquisitionFunction) if not a.supports_batching
95+
]
9196

9297
# List of all hybrid recommenders with default attributes. Is extended with other lists
9398
# of hybrid recommenders like naive ones or recommenders not using default arguments
@@ -202,35 +207,40 @@
202207
]
203208

204209
test_targets = [
205-
["Target_max"],
206-
["Target_min"],
207-
["Target_match_bell"],
208-
["Target_match_triangular"],
209-
["Target_max_bounded", "Target_min_bounded"],
210+
param(["Target_max"], id="Tmax"),
211+
param(["Target_min"], id="Tmin"),
212+
param(["Target_match_bell"], id="Tmatch_bell"),
213+
param(["Target_match_triangular"], id="Tmatch_triang"),
214+
param(["Target_max_bounded", "Target_min_bounded"], id="Tmax_bounded_Tmin_bounded"),
210215
]
211216

212217

213218
@pytest.mark.slow
214219
@pytest.mark.parametrize(
215-
"acqf", valid_mc_acqfs, ids=[a.abbreviation for a in valid_mc_acqfs]
220+
"acqf", acqfs_batching, ids=[a.abbreviation for a in acqfs_batching]
216221
)
217222
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
218-
def test_mc_acqfs(campaign, n_iterations, batch_size, acqf):
219-
if isinstance(acqf, qKG):
220-
pytest.skip(f"{acqf.__class__.__name__} only works with continuous spaces.")
221-
if isinstance(acqf, qTS) and batch_size > 1:
222-
pytest.skip(f"{acqf.__class__.__name__} only works with batch size 1.")
223-
224-
run_iterations(campaign, n_iterations, batch_size)
223+
@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"])
224+
def test_batching_acqfs(campaign, n_iterations, batch_size, acqf):
225+
context = nullcontext()
226+
if campaign.searchspace.type not in [
227+
SearchSpaceType.CONTINUOUS,
228+
SearchSpaceType.HYBRID,
229+
] and isinstance(acqf, qKG):
230+
# qKG does not work with purely discrete spaces
231+
context = pytest.raises(UnsupportedError)
232+
233+
with context:
234+
run_iterations(campaign, n_iterations, batch_size)
225235

226236

227237
@pytest.mark.slow
228238
@pytest.mark.parametrize(
229-
"acqf", valid_nonmc_acqfs, ids=[a.abbreviation for a in valid_nonmc_acqfs]
239+
"acqf", acqfs_non_batching, ids=[a.abbreviation for a in acqfs_non_batching]
230240
)
231241
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
232242
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
233-
def test_nonmc_acqfs(campaign, n_iterations, batch_size):
243+
def test_non_batching_acqfs(campaign, n_iterations, batch_size):
234244
run_iterations(campaign, n_iterations, batch_size)
235245

236246

@@ -256,13 +266,20 @@ def test_kernel_factories(campaign, n_iterations, batch_size):
256266
ids=[c.__class__ for c in valid_surrogate_models],
257267
)
258268
def test_surrogate_models(campaign, n_iterations, batch_size, surrogate_model):
269+
context = nullcontext()
259270
if batch_size > 1 and isinstance(surrogate_model, IndependentGaussianSurrogate):
260-
pytest.skip("Batch recommendation is not supported.")
261-
run_iterations(campaign, n_iterations, batch_size)
271+
context = pytest.raises(InvalidSurrogateModelError)
272+
273+
with context:
274+
run_iterations(campaign, n_iterations, batch_size)
262275

263276

264277
@pytest.mark.slow
265-
@pytest.mark.parametrize("recommender", valid_initial_recommenders)
278+
@pytest.mark.parametrize(
279+
"recommender",
280+
valid_initial_recommenders,
281+
ids=[c.__class__ for c in valid_initial_recommenders],
282+
)
266283
def test_initial_recommenders(campaign, n_iterations, batch_size):
267284
with pytest.warns(UnusedObjectWarning):
268285
run_iterations(campaign, n_iterations, batch_size)
@@ -275,35 +292,61 @@ def test_targets(campaign, n_iterations, batch_size):
275292

276293

277294
@pytest.mark.slow
278-
@pytest.mark.parametrize("recommender", valid_discrete_recommenders)
295+
@pytest.mark.parametrize(
296+
"recommender",
297+
valid_discrete_recommenders,
298+
ids=[c.__class__ for c in valid_discrete_recommenders],
299+
)
279300
def test_recommenders_discrete(campaign, n_iterations, batch_size):
280301
run_iterations(campaign, n_iterations, batch_size)
281302

282303

283304
@pytest.mark.slow
284-
@pytest.mark.parametrize("recommender", valid_continuous_recommenders)
285-
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
305+
@pytest.mark.parametrize(
306+
"recommender",
307+
valid_continuous_recommenders,
308+
ids=[c.__class__ for c in valid_continuous_recommenders],
309+
)
310+
@pytest.mark.parametrize(
311+
"parameter_names", [["Conti_finite1", "Conti_finite2"]], ids=["conti_params"]
312+
)
286313
def test_recommenders_continuous(campaign, n_iterations, batch_size):
287314
run_iterations(campaign, n_iterations, batch_size)
288315

289316

290317
@pytest.mark.slow
291-
@pytest.mark.parametrize("recommender", valid_hybrid_recommenders)
318+
@pytest.mark.parametrize(
319+
"recommender",
320+
valid_hybrid_recommenders,
321+
ids=[c.__class__ for c in valid_hybrid_recommenders],
322+
)
292323
@pytest.mark.parametrize(
293324
"parameter_names",
294325
[["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1", "Conti_finite2"]],
326+
ids=["hybrid_params"],
295327
)
296328
def test_recommenders_hybrid(campaign, n_iterations, batch_size):
297329
run_iterations(campaign, n_iterations, batch_size)
298330

299331

300-
@pytest.mark.parametrize("recommender", valid_meta_recommenders, indirect=True)
332+
@pytest.mark.parametrize(
333+
"recommender",
334+
valid_meta_recommenders,
335+
ids=[c.__class__ for c in valid_meta_recommenders],
336+
indirect=True,
337+
)
301338
def test_meta_recommenders(campaign, n_iterations, batch_size):
302339
run_iterations(campaign, n_iterations, batch_size)
303340

304341

305-
@pytest.mark.parametrize("acqf", [qTS(), qUCB()])
306-
@pytest.mark.parametrize("surrogate_model", [BetaBernoulliMultiArmedBanditSurrogate()])
342+
@pytest.mark.parametrize(
343+
"acqf", [qTS(), qUCB()], ids=[qTS.abbreviation, qUCB.abbreviation]
344+
)
345+
@pytest.mark.parametrize(
346+
"surrogate_model",
347+
[BetaBernoulliMultiArmedBanditSurrogate()],
348+
ids=["bernoulli_bandit_surrogate"],
349+
)
307350
@pytest.mark.parametrize(
308351
"parameter_names",
309352
[
@@ -314,7 +357,7 @@ def test_meta_recommenders(campaign, n_iterations, batch_size):
314357
["Frame_B"],
315358
],
316359
)
317-
@pytest.mark.parametrize("batch_size", [1])
318-
@pytest.mark.parametrize("target_names", [["Target_binary"]])
360+
@pytest.mark.parametrize("target_names", [["Target_binary"]], ids=["binary_target"])
361+
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
319362
def test_multi_armed_bandit(campaign, n_iterations, batch_size):
320363
run_iterations(campaign, n_iterations, batch_size, add_noise=False)

‎tests/test_pending_experiments.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,15 @@ def test_pending_points(campaign, batch_size):
125125
)
126126

127127

128-
_non_mc_acqfs = [a() for a in get_subclasses(AcquisitionFunction) if not a.is_mc]
128+
acqfs_non_pending = [
129+
a()
130+
for a in get_subclasses(AcquisitionFunction)
131+
if (not a.supports_pending_experiments)
132+
]
129133

130134

131135
@pytest.mark.parametrize(
132-
"acqf", _non_mc_acqfs, ids=[a.abbreviation for a in _non_mc_acqfs]
136+
"acqf", acqfs_non_pending, ids=[a.abbreviation for a in acqfs_non_pending]
133137
)
134138
@pytest.mark.parametrize(
135139
"parameter_names",
@@ -140,7 +144,7 @@ def test_pending_points(campaign, batch_size):
140144
],
141145
)
142146
@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"])
143-
@pytest.mark.parametrize("batch_size", [3], ids=["b3"])
147+
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
144148
def test_invalid_acqf(searchspace, recommender, objective, batch_size, acqf):
145149
"""Test exception raised for acqfs that don't support pending experiments."""
146150
recommender = TwoPhaseMetaRecommender(

0 commit comments

Comments
 (0)
Please sign in to comment.