1
1
# TODO: This file needs to be refactored.
2
2
"""Tests various configurations for a small number of iterations."""
3
3
4
+ from contextlib import nullcontext
5
+
4
6
import pytest
7
+ from botorch .exceptions import UnsupportedError
5
8
from pytest import param
6
9
7
10
from baybe .acquisition import qKG , qNIPV , qTS , qUCB
8
11
from baybe .acquisition .base import AcquisitionFunction
9
- from baybe .exceptions import UnusedObjectWarning
12
+ from baybe .exceptions import InvalidSurrogateModelError , UnusedObjectWarning
10
13
from baybe .kernels .base import Kernel
11
14
from baybe .kernels .basic import (
12
15
LinearKernel ,
77
80
in [SearchSpaceType .CONTINUOUS , SearchSpaceType .HYBRID , SearchSpaceType .EITHER ]
78
81
]
79
82
80
# qNIPV variants exercised with explicit, non-default sampling configurations.
acqfs_extra = [  # Additionally tested acqfs with extra configurations
    qNIPV(sampling_fraction=0.2, sampling_method="Random"),
    qNIPV(sampling_fraction=0.2, sampling_method="FPS"),
    qNIPV(sampling_fraction=1.0, sampling_method="FPS"),
    qNIPV(sampling_n_points=1, sampling_method="Random"),
    qNIPV(sampling_n_points=1, sampling_method="FPS"),
]
# Default instances of all acquisition functions that support batching,
# plus the extra configurations above.
acqfs_batching = [
    a() for a in get_subclasses(AcquisitionFunction) if a.supports_batching
] + acqfs_extra
# Default instances of all acquisition functions that do not support batching.
acqfs_non_batching = [
    a() for a in get_subclasses(AcquisitionFunction) if not a.supports_batching
]
91
96
92
97
# List of all hybrid recommenders with default attributes. Is extended with other lists
93
98
# of hybrid recommenders like naive ones or recommenders not using default arguments
202
207
]
203
208
204
209
# Target configurations under test, wrapped in ``param`` for readable test ids.
test_targets = [
    param(["Target_max"], id="Tmax"),
    param(["Target_min"], id="Tmin"),
    param(["Target_match_bell"], id="Tmatch_bell"),
    param(["Target_match_triangular"], id="Tmatch_triang"),
    param(["Target_max_bounded", "Target_min_bounded"], id="Tmax_bounded_Tmin_bounded"),
]
211
216
212
217
213
218
@pytest.mark.slow
@pytest.mark.parametrize(
    "acqf", acqfs_batching, ids=[a.abbreviation for a in acqfs_batching]
)
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
@pytest.mark.parametrize("n_grid_points", [5], ids=["g5"])
def test_batching_acqfs(campaign, n_iterations, batch_size, acqf):
    """Run a few iterations with each batching-capable acquisition function.

    For qKG on a search space that is neither continuous nor hybrid, the run
    is expected to fail with an ``UnsupportedError`` instead of completing.
    """
    context = nullcontext()
    if campaign.searchspace.type not in [
        SearchSpaceType.CONTINUOUS,
        SearchSpaceType.HYBRID,
    ] and isinstance(acqf, qKG):
        # qKG does not work with purely discrete spaces
        context = pytest.raises(UnsupportedError)

    with context:
        run_iterations(campaign, n_iterations, batch_size)
225
235
226
236
227
237
@pytest.mark.slow
@pytest.mark.parametrize(
    "acqf", acqfs_non_batching, ids=[a.abbreviation for a in acqfs_non_batching]
)
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
def test_non_batching_acqfs(campaign, n_iterations, batch_size):
    """Run a few iterations with each non-batching acquisition function.

    The batch size is pinned to 1 since these acquisition functions do not
    support batching.
    """
    run_iterations(campaign, n_iterations, batch_size)
235
245
236
246
@@ -256,13 +266,20 @@ def test_kernel_factories(campaign, n_iterations, batch_size):
256
266
ids = [c .__class__ for c in valid_surrogate_models ],
257
267
)
258
268
def test_surrogate_models(campaign, n_iterations, batch_size, surrogate_model):
    """Run a few iterations with each surrogate model.

    For ``IndependentGaussianSurrogate`` with batch sizes greater than one,
    the run is expected to fail with an ``InvalidSurrogateModelError``
    instead of completing.
    """
    context = nullcontext()
    if batch_size > 1 and isinstance(surrogate_model, IndependentGaussianSurrogate):
        context = pytest.raises(InvalidSurrogateModelError)

    with context:
        run_iterations(campaign, n_iterations, batch_size)
262
275
263
276
264
277
@pytest.mark.slow
@pytest.mark.parametrize(
    "recommender",
    valid_initial_recommenders,
    # pytest ids must be strings/numbers, so use the class *name* rather than
    # the class object itself (``c.__class__`` would be rejected by pytest).
    ids=[c.__class__.__name__ for c in valid_initial_recommenders],
)
def test_initial_recommenders(campaign, n_iterations, batch_size):
    """Run a few iterations with each initial recommender.

    The runs are expected to emit an ``UnusedObjectWarning``.
    """
    with pytest.warns(UnusedObjectWarning):
        run_iterations(campaign, n_iterations, batch_size)
@@ -275,35 +292,61 @@ def test_targets(campaign, n_iterations, batch_size):
275
292
276
293
277
294
@pytest.mark.slow
@pytest.mark.parametrize(
    "recommender",
    valid_discrete_recommenders,
    # pytest ids must be strings/numbers, so use the class *name* rather than
    # the class object itself (``c.__class__`` would be rejected by pytest).
    ids=[c.__class__.__name__ for c in valid_discrete_recommenders],
)
def test_recommenders_discrete(campaign, n_iterations, batch_size):
    """Run a few iterations with each discrete-space recommender."""
    run_iterations(campaign, n_iterations, batch_size)
281
302
282
303
283
304
@pytest.mark.slow
@pytest.mark.parametrize(
    "recommender",
    valid_continuous_recommenders,
    # pytest ids must be strings/numbers, so use the class *name* rather than
    # the class object itself (``c.__class__`` would be rejected by pytest).
    ids=[c.__class__.__name__ for c in valid_continuous_recommenders],
)
@pytest.mark.parametrize(
    "parameter_names", [["Conti_finite1", "Conti_finite2"]], ids=["conti_params"]
)
def test_recommenders_continuous(campaign, n_iterations, batch_size):
    """Run a few iterations with each continuous-space recommender."""
    run_iterations(campaign, n_iterations, batch_size)
288
315
289
316
290
317
@pytest.mark.slow
@pytest.mark.parametrize(
    "recommender",
    valid_hybrid_recommenders,
    # pytest ids must be strings/numbers, so use the class *name* rather than
    # the class object itself (``c.__class__`` would be rejected by pytest).
    ids=[c.__class__.__name__ for c in valid_hybrid_recommenders],
)
@pytest.mark.parametrize(
    "parameter_names",
    [["Categorical_1", "SomeSetting", "Num_disc_1", "Conti_finite1", "Conti_finite2"]],
    ids=["hybrid_params"],
)
def test_recommenders_hybrid(campaign, n_iterations, batch_size):
    """Run a few iterations with each hybrid-space recommender."""
    run_iterations(campaign, n_iterations, batch_size)
298
330
299
331
300
@pytest.mark.parametrize(
    "recommender",
    valid_meta_recommenders,
    # pytest ids must be strings/numbers, so use the class *name* rather than
    # the class object itself (``c.__class__`` would be rejected by pytest).
    ids=[c.__class__.__name__ for c in valid_meta_recommenders],
    indirect=True,
)
def test_meta_recommenders(campaign, n_iterations, batch_size):
    """Run a few iterations with each meta recommender."""
    run_iterations(campaign, n_iterations, batch_size)
303
340
304
341
305
- @pytest .mark .parametrize ("acqf" , [qTS (), qUCB ()])
306
- @pytest .mark .parametrize ("surrogate_model" , [BetaBernoulliMultiArmedBanditSurrogate ()])
342
+ @pytest .mark .parametrize (
343
+ "acqf" , [qTS (), qUCB ()], ids = [qTS .abbreviation , qUCB .abbreviation ]
344
+ )
345
+ @pytest .mark .parametrize (
346
+ "surrogate_model" ,
347
+ [BetaBernoulliMultiArmedBanditSurrogate ()],
348
+ ids = ["bernoulli_bandit_surrogate" ],
349
+ )
307
350
@pytest .mark .parametrize (
308
351
"parameter_names" ,
309
352
[
@@ -314,7 +357,7 @@ def test_meta_recommenders(campaign, n_iterations, batch_size):
314
357
["Frame_B" ],
315
358
],
316
359
)
317
@pytest.mark.parametrize("target_names", [["Target_binary"]], ids=["binary_target"])
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
def test_multi_armed_bandit(campaign, n_iterations, batch_size):
    """Run a few iterations of the multi-armed bandit setup.

    Noise is disabled since the bandit operates on a binary target.
    """
    run_iterations(campaign, n_iterations, batch_size, add_noise=False)
0 commit comments