|
12 | 12 | from ax.core.observation import Observation, ObservationFeatures
|
13 | 13 | from ax.core.parameter import Parameter, ParameterType, RangeParameter
|
14 | 14 | from ax.core.search_space import SearchSpace
|
| 15 | +from ax.exceptions.core import UserInputError |
15 | 16 | from ax.modelbridge.transforms.base import Transform
|
16 | 17 | from ax.modelbridge.transforms.rounding import (
|
17 | 18 | contains_constrained_integer,
|
@@ -65,18 +66,22 @@ def __init__(
|
65 | 66 | self.min_choices: int = checked_cast(int, config.get("min_choices", 0))
|
66 | 67 |
|
67 | 68 | # Identify parameters that should be transformed
|
68 |
| - self.transform_parameters: set[str] = { |
| 69 | + self.transform_parameters: set[str] = self._get_transform_parameters() |
| 70 | + if contains_constrained_integer(self.search_space, self.transform_parameters): |
| 71 | + self.rounding = "randomized" |
| 72 | + self.contains_constrained_integer: bool = True |
| 73 | + else: |
| 74 | + self.contains_constrained_integer: bool = False |
| 75 | + |
| 76 | + def _get_transform_parameters(self) -> set[str]: |
| 77 | + """Identify parameters that should be transformed.""" |
| 78 | + return { |
69 | 79 | p_name
|
70 | 80 | for p_name, p in self.search_space.parameters.items()
|
71 | 81 | if isinstance(p, RangeParameter)
|
72 | 82 | and p.parameter_type == ParameterType.INT
|
73 | 83 | and ((p.cardinality() >= self.min_choices) or p.log_scale)
|
74 | 84 | }
|
75 |
| - if contains_constrained_integer(self.search_space, self.transform_parameters): |
76 |
| - self.rounding = "randomized" |
77 |
| - self.contains_constrained_integer: bool = True |
78 |
| - else: |
79 |
| - self.contains_constrained_integer: bool = False |
80 | 85 |
|
81 | 86 | def transform_observation_features(
|
82 | 87 | self, observation_features: list[ObservationFeatures]
|
@@ -183,3 +188,41 @@ def untransform_observation_features(
|
183 | 188 | obsf.parameters[p_name] = rounded_parameters[p_name]
|
184 | 189 |
|
185 | 190 | return observation_features
|
| 191 | + |
| 192 | + |
| 193 | +class LogIntToFloat(IntToFloat): |
| 194 | + """Convert a log-scale RangeParameter of type int to type float. |
| 195 | +
|
| 196 | + The behavior of this transform mirrors ``IntToFloat`` with the key difference |
| 197 | + being that it only operates on log-scale parameters. |
| 198 | + """ |
| 199 | + |
| 200 | + def __init__( |
| 201 | + self, |
| 202 | + search_space: SearchSpace | None = None, |
| 203 | + observations: list[Observation] | None = None, |
| 204 | + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, |
| 205 | + config: TConfig | None = None, |
| 206 | + ) -> None: |
| 207 | + if config is not None and "min_choices" in config: |
| 208 | + raise UserInputError( |
| 209 | + "`min_choices` cannot be specified for `LogIntToFloat` transform. " |
| 210 | + ) |
| 211 | + super().__init__( |
| 212 | + search_space=search_space, |
| 213 | + observations=observations, |
| 214 | + modelbridge=modelbridge, |
| 215 | + config=config, |
| 216 | + ) |
| 217 | + # Delete the attribute to avoid it presenting a misleading value. |
| 218 | + del self.min_choices |
| 219 | + |
| 220 | + def _get_transform_parameters(self) -> set[str]: |
| 221 | + """Identify parameters that should be transformed.""" |
| 222 | + return { |
| 223 | + p_name |
| 224 | + for p_name, p in self.search_space.parameters.items() |
| 225 | + if isinstance(p, RangeParameter) |
| 226 | + and p.parameter_type == ParameterType.INT |
| 227 | + and p.log_scale |
| 228 | + } |
0 commit comments