Commit ae64ca1

Authored Feb 7, 2025
Refactor benchmark architecture (#477)
This PR refines the benchmarking classes and fixes a few other things along the way:

* Most importantly, the convergence-test-specific `optimal_function_inputs` and `best_possible_result` attributes are moved to a separate `ConvergenceBenchmark` class, with corresponding `ConvergenceBenchmarkSettings`.
* The remaining benchmark attributes / properties are cleaned up.
2 parents ac8c518 + 2322b2a commit ae64ca1
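
For orientation, a minimal usage sketch of the refactored API described above (the names `my_domain` and `my_benchmark` are hypothetical and the callable body is a stub, not a real benchmark domain):

from pandas import DataFrame

from benchmarks.definition import (
    ConvergenceBenchmark,
    ConvergenceBenchmarkSettings,
)


def my_domain(settings: ConvergenceBenchmarkSettings) -> DataFrame:
    """Toy benchmark callable; the docstring doubles as the benchmark description."""
    # A real domain would run the DOE/Monte Carlo loop here and return its DataFrame.
    return DataFrame()


my_benchmark = ConvergenceBenchmark(
    function=my_domain,
    settings=ConvergenceBenchmarkSettings(
        batch_size=5,
        n_doe_iterations=30,
        n_mc_iterations=50,
    ),
    # The per-target optima replace the former best_possible_result /
    # optimal_function_inputs attributes:
    optimal_target_values={"target": 4.09685},
)

result = my_benchmark()  # executes under the settings' random seed and returns a Result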

File tree: 12 files changed (+188 / -150 lines)

‎.lockfiles/py310-dev.lock

+19-19
@@ -8,7 +8,7 @@ anyio==4.4.0
 # via
 #   httpx
 #   jupyter-server
-appnope==0.1.4 ; platform_system == 'Darwin'
+appnope==0.1.4 ; sys_platform == 'darwin'
 # via ipykernel
 argon2-cffi==23.1.0
 # via jupyter-server
@@ -61,7 +61,7 @@ cachetools==5.4.0
 # via
 #   streamlit
 #   tox
-cattrs==23.2.3
+cattrs==24.1.2
 # via baybe (pyproject.toml)
 certifi==2024.7.4
 # via
@@ -240,7 +240,7 @@ importlib-metadata==7.1.0
 #   opentelemetry-api
 iniconfig==2.0.0
 # via pytest
-intel-openmp==2021.4.0 ; platform_system == 'Windows'
+intel-openmp==2021.4.0 ; sys_platform == 'win32'
 # via mkl
 interface-meta==1.3.0
 # via formulaic
@@ -393,7 +393,7 @@ mdurl==0.1.2
 # via markdown-it-py
 mistune==3.0.2
 # via nbconvert
-mkl==2021.4.0 ; platform_system == 'Windows'
+mkl==2021.4.0 ; sys_platform == 'win32'
 # via torch
 mmh3==5.0.1
 # via e3fp
@@ -487,36 +487,36 @@ numpy==1.26.4
 #   types-seaborn
 #   xarray
 #   xyzpy
-nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via
 #   nvidia-cudnn-cu12
 #   nvidia-cusolver-cu12
 #   torch
-nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via
 #   nvidia-cusolver-cu12
 #   torch
-nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
-nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via
 #   nvidia-cusolver-cu12
 #   nvidia-cusparse-cu12
-nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
+nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
 onnx==1.16.1
 # via
@@ -922,7 +922,7 @@ sympy==1.13.1
 # via
 #   onnxruntime
 #   torch
-tbb==2021.13.0 ; platform_system == 'Windows'
+tbb==2021.13.0 ; sys_platform == 'win32'
 # via mkl
 tenacity==8.5.0
 # via
@@ -1007,7 +1007,7 @@ traitlets==5.14.3
 #   nbclient
 #   nbconvert
 #   nbformat
-triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux'
+triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'
 # via torch
 typeguard==2.13.3
 # via
@@ -1050,7 +1050,7 @@ virtualenv==20.26.3
 # via
 #   pre-commit
 #   tox
-watchdog==4.0.1 ; platform_system != 'Darwin'
+watchdog==4.0.1 ; sys_platform != 'darwin'
 # via streamlit
 wcwidth==0.2.13
 # via prompt-toolkit
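
Side note on the lock-file diff above: apart from the cattrs bump, it only swaps the PEP 508 marker variable, not the pinned versions. `platform_system` markers compare against `platform.system()` (values like 'Linux', 'Darwin', 'Windows'), whereas `sys_platform` markers compare against `sys.platform` ('linux', 'darwin', 'win32'). A quick way to inspect both values for a given interpreter:

import platform
import sys

print(platform.system())  # e.g. 'Linux' / 'Darwin' / 'Windows' (platform_system markers)
print(sys.platform)       # e.g. 'linux' / 'darwin' / 'win32' (sys_platform markers)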

‎benchmarks/__init__.py

+10-2
@@ -1,9 +1,17 @@
 """Benchmarking module for performance tracking."""
 
-from benchmarks.definition import Benchmark
+from benchmarks.definition import (
+    Benchmark,
+    BenchmarkSettings,
+    ConvergenceBenchmark,
+    ConvergenceBenchmarkSettings,
+)
 from benchmarks.result import Result
 
 __all__ = [
-    "Result",
     "Benchmark",
+    "BenchmarkSettings",
+    "ConvergenceBenchmark",
+    "ConvergenceBenchmarkSettings",
+    "Result",
 ]

‎benchmarks/definition/__init__.py

+8-4
@@ -1,13 +1,17 @@
-"""Benchmark task definitions."""
+"""Benchmark definitions."""
 
-from benchmarks.definition.config import (
+from benchmarks.definition.base import (
     Benchmark,
     BenchmarkSettings,
-    ConvergenceExperimentSettings,
+)
+from benchmarks.definition.convergence import (
+    ConvergenceBenchmark,
+    ConvergenceBenchmarkSettings,
 )
 
 __all__ = [
-    "ConvergenceExperimentSettings",
     "Benchmark",
     "BenchmarkSettings",
+    "ConvergenceBenchmark",
+    "ConvergenceBenchmarkSettings",
 ]

‎benchmarks/definition/base.py

+86
@@ -0,0 +1,86 @@
+"""Basic benchmark configuration."""
+
+import time
+from abc import ABC
+from collections.abc import Callable
+from datetime import datetime, timedelta, timezone
+from typing import Generic, TypeVar
+
+from attrs import define, field
+from attrs.validators import instance_of
+from cattrs import override
+from cattrs.gen import make_dict_unstructure_fn
+from pandas import DataFrame
+
+from baybe.utils.random import temporary_seed
+from benchmarks.result import Result, ResultMetadata
+from benchmarks.serialization import BenchmarkSerialization, converter
+
+
+@define(frozen=True, kw_only=True)
+class BenchmarkSettings(ABC, BenchmarkSerialization):
+    """The basic benchmark configuration."""
+
+    random_seed: int = field(validator=instance_of(int), default=1337)
+    """The used random seed."""
+
+
+BenchmarkSettingsType = TypeVar("BenchmarkSettingsType", bound=BenchmarkSettings)
+
+
+@define(frozen=True)
+class Benchmark(Generic[BenchmarkSettingsType], BenchmarkSerialization):
+    """The base class for all benchmark definitions."""
+
+    function: Callable[[BenchmarkSettingsType], DataFrame] = field()
+    """The callable containing the benchmarking logic."""
+
+    settings: BenchmarkSettingsType = field()
+    """The benchmark configuration."""
+
+    @function.validator
+    def _validate_function(self, _, function) -> None:
+        if function.__doc__ is None:
+            raise ValueError("The benchmark function must have a docstring.")
+
+    @property
+    def name(self) -> str:
+        """The name of the benchmark function."""
+        return self.function.__name__
+
+    @property
+    def description(self) -> str:
+        """The description of the benchmark function."""
+        assert self.function.__doc__ is not None
+        return self.function.__doc__
+
+    def __call__(self) -> Result:
+        """Execute the benchmark and return the result."""
+        start_datetime = datetime.now(timezone.utc)
+
+        with temporary_seed(self.settings.random_seed):
+            start_sec = time.perf_counter()
+            result = self.function(self.settings)
+            stop_sec = time.perf_counter()
+
+        duration = timedelta(seconds=stop_sec - start_sec)
+
+        metadata = ResultMetadata(
+            start_datetime=start_datetime,
+            duration=duration,
+        )
+
+        return Result(self.name, result, metadata)
+
+
+@converter.register_unstructure_hook
+def unstructure_benchmark(benchmark: Benchmark) -> dict:
+    """Unstructure a benchmark instance."""
+    fn = make_dict_unstructure_fn(
+        type(benchmark), converter, function=override(omit=True)
+    )
+    return {
+        "name": benchmark.name,
+        "description": benchmark.description,
+        **fn(benchmark),
+    }
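
A brief sketch of how the new base class is meant to be used (assuming the `synthetic_2C1D_1C_benchmark` instance defined further down in this PR): calling a benchmark executes its function under the configured random seed, and unstructuring it goes through the hook above, which omits the callable and adds the derived name/description.

from benchmarks.domains import synthetic_2C1D_1C_benchmark
from benchmarks.serialization import converter

result = synthetic_2C1D_1C_benchmark()  # runs the full DOE/MC loop, so this is expensive
payload = converter.unstructure(synthetic_2C1D_1C_benchmark)
# payload holds 'name', 'description', and the unstructured settings, but not the function.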

‎benchmarks/definition/config.py

-103
This file was deleted.

‎benchmarks/definition/convergence.py

+39
@@ -0,0 +1,39 @@
+"""Convergence benchmark configuration."""
+
+from typing import Any
+
+from attrs import define, field
+from attrs.validators import deep_mapping, instance_of, optional
+
+from benchmarks.definition.base import Benchmark, BenchmarkSettings
+
+
+@define(frozen=True, kw_only=True)
+class ConvergenceBenchmarkSettings(BenchmarkSettings):
+    """Benchmark configuration for recommender convergence analyses."""
+
+    batch_size: int = field(validator=instance_of(int))
+    """The recommendation batch size."""
+
+    n_doe_iterations: int = field(validator=instance_of(int))
+    """The number of Design of Experiment iterations."""
+
+    n_mc_iterations: int = field(validator=instance_of(int))
+    """The number of Monte Carlo iterations."""
+
+
+@define(frozen=True)
+class ConvergenceBenchmark(Benchmark[ConvergenceBenchmarkSettings]):
+    """A class for defining convergence benchmarks."""
+
+    optimal_target_values: dict[str, Any] | None = field(
+        default=None,
+        validator=optional(
+            deep_mapping(
+                key_validator=instance_of(str),
+                mapping_validator=instance_of(dict),
+                value_validator=lambda *_: None,
+            )
+        ),
+    )
+    """The optimal values that can be achieved for the targets **individually**."""

‎benchmarks/domains/__init__.py

+1-1
@@ -1,6 +1,6 @@
 """Benchmark domains."""
 
-from benchmarks.definition.config import Benchmark
+from benchmarks.definition.base import Benchmark
 from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark
 
 BENCHMARKS: list[Benchmark] = [

‎benchmarks/domains/synthetic_2C1D_1C.py

+7-11
@@ -15,9 +15,9 @@
 from baybe.searchspace import SearchSpace
 from baybe.simulation import simulate_scenarios
 from baybe.targets import NumericalTarget
-from benchmarks.definition import (
-    Benchmark,
-    ConvergenceExperimentSettings,
+from benchmarks.definition.convergence import (
+    ConvergenceBenchmark,
+    ConvergenceBenchmarkSettings,
 )
 
 if TYPE_CHECKING:
@@ -49,7 +49,7 @@ def lookup(df: pd.DataFrame, /) -> pd.DataFrame:
 )
 
 
-def synthetic_2C1D_1C(settings: ConvergenceExperimentSettings) -> DataFrame:
+def synthetic_2C1D_1C(settings: ConvergenceBenchmarkSettings) -> DataFrame:
     """Hybrid synthetic test function.
 
     Inputs:
@@ -95,20 +95,16 @@ def synthetic_2C1D_1C(settings: ConvergenceExperimentSettings) -> DataFrame:
 )
 
 
-benchmark_config = ConvergenceExperimentSettings(
+benchmark_config = ConvergenceBenchmarkSettings(
     batch_size=5,
     n_doe_iterations=30,
     n_mc_iterations=50,
 )
 
-synthetic_2C1D_1C_benchmark = Benchmark(
+synthetic_2C1D_1C_benchmark = ConvergenceBenchmark(
     function=synthetic_2C1D_1C,
-    best_possible_result=4.09685,
+    optimal_target_values={"target": 4.09685},
     settings=benchmark_config,
-    optimal_function_inputs=[
-        {"x": 1.610, "y": 1.571, "z": 3},
-        {"x": 1.610, "y": -4.712, "z": 3},
-    ],
 )

‎benchmarks/persistence/persistence.py

+7-2
@@ -48,8 +48,12 @@ class PathConstructor:
     benchmark_name: str = field(validator=instance_of(str))
     """The name of the benchmark for which the path should be constructed."""
 
-    branch: str = field(validator=instance_of(str))
-    """The branch checked out at benchmark execution time."""
+    branch: str = field(
+        converter=lambda x: x or "-branchless-",
+        validator=instance_of(str),
+    )
+    """The branch checked out at benchmark execution time.
+    In case of detached head state the branch is set to '-branchless-'."""
 
     latest_baybe_tag: str = field(validator=instance_of(str))
     """The latest BayBE version tag existing at benchmark execution time."""
@@ -108,6 +112,7 @@ def get_path(self, strategy: PathStrategy) -> Path:
         separator = "/" if strategy is PathStrategy.HIERARCHICAL else "_"
 
         file_usable_date = self.execution_date_time.strftime("%Y-%m-%d")
+
         components = [
             self.benchmark_name,
             self.branch,
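
The converter-before-validator ordering of attrs is what makes the new `branch` field work: a `None` branch (detached HEAD, see the metadata change below) is rewritten to '-branchless-' before `instance_of(str)` runs. A small stand-in class to illustrate the pattern (`_Demo` is hypothetical, not part of the PR):

from attrs import define, field
from attrs.validators import instance_of


@define
class _Demo:
    branch: str = field(
        converter=lambda x: x or "-branchless-",  # converter runs first
        validator=instance_of(str),               # then the converted value is validated
    )


assert _Demo(None).branch == "-branchless-"
assert _Demo("main").branch == "main"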

‎benchmarks/result/metadata.py

+9-6
@@ -4,7 +4,7 @@
 
 import git
 from attrs import define, field
-from attrs.validators import instance_of
+from attrs.validators import instance_of, optional
 from cattrs.gen import make_dict_unstructure_fn
 
 from benchmarks.serialization import BenchmarkSerialization, converter
@@ -26,15 +26,18 @@ class ResultMetadata(BenchmarkSerialization):
     latest_baybe_tag: str = field(validator=instance_of(str), init=False)
     """The latest BayBE tag reachable in the ancestor commit history."""
 
-    branch: str = field(validator=instance_of(str), init=False)
-    """The branch currently checked out."""
+    branch: str | None = field(validator=optional(instance_of(str)), init=False)
+    """The branch checked out during benchmark execution."""
 
     @branch.default
-    def _default_branch(self) -> str:
+    def _default_branch(self) -> str | None:
         """Set the current checkout branch."""
         repo = git.Repo(search_parent_directories=True)
-        current_branch = repo.active_branch.name
-        return current_branch
+        try:
+            current_branch = repo.active_branch.name
+            return current_branch
+        except TypeError:
+            return None
 
     @commit_hash.default
     def _default_commit_hash(self) -> str:
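
Background for the `except TypeError` above: GitPython's `Repo.active_branch` raises a `TypeError` when HEAD is detached (e.g. when CI checks out a specific commit or tag), which is exactly the case that now maps to `None`. A minimal check, assuming the working directory lies inside a git repository:

import git

repo = git.Repo(search_parent_directories=True)
try:
    print(repo.active_branch.name)  # e.g. 'main'
except TypeError:
    print("detached HEAD")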

‎benchmarks/result/result.py

+1-1
@@ -16,7 +16,7 @@
 
 @define(frozen=True)
 class Result(BenchmarkSerialization):
-    """A single result of the benchmarking."""
+    """A single benchmarking result."""
 
     benchmark_identifier: str = field(validator=instance_of(str))
     """The identifier of the benchmark that produced the result."""

‎pyproject.toml

+1-1
@@ -30,7 +30,7 @@ dynamic = ['version']
 dependencies = [
     "attrs>=24.1.0",
     "botorch>=0.9.3,<1",
-    "cattrs>=23.2.0",
+    "cattrs>=24.1.0",
     "exceptiongroup",
     "funcy>=1.17,<2",
     "gpytorch>=1.9.1,<2",
