Skip to content

Commit 8cfe4e6

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add LogIntToFloat transform
Summary: This is a simple subclass of `IntToFloat` that only transforms log-scale parameters. Replacing `IntToFloat` with `LogIntToFloat` will avoid unnecessary use of continuous relaxation across the board, and allow us to utilize the various optimizers available in `Acquisition.optimize`. Differential Revision: D66244582
1 parent 4754c61 commit 8cfe4e6

File tree

3 files changed

+82
-9
lines changed

3 files changed

+82
-9
lines changed

ax/core/search_space.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class SearchSpace(Base):
6767

6868
def __init__(
6969
self,
70-
parameters: list[Parameter],
70+
parameters: Sequence[Parameter],
7171
parameter_constraints: list[ParameterConstraint] | None = None,
7272
) -> None:
7373
"""Initialize SearchSpace

ax/modelbridge/transforms/int_to_float.py

+49-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ax.core.observation import Observation, ObservationFeatures
1313
from ax.core.parameter import Parameter, ParameterType, RangeParameter
1414
from ax.core.search_space import SearchSpace
15+
from ax.exceptions.core import UserInputError
1516
from ax.modelbridge.transforms.base import Transform
1617
from ax.modelbridge.transforms.rounding import (
1718
contains_constrained_integer,
@@ -65,18 +66,22 @@ def __init__(
6566
self.min_choices: int = checked_cast(int, config.get("min_choices", 0))
6667

6768
# Identify parameters that should be transformed
68-
self.transform_parameters: set[str] = {
69+
self.transform_parameters: set[str] = self._get_transform_parameters()
70+
if contains_constrained_integer(self.search_space, self.transform_parameters):
71+
self.rounding = "randomized"
72+
self.contains_constrained_integer: bool = True
73+
else:
74+
self.contains_constrained_integer: bool = False
75+
76+
def _get_transform_parameters(self) -> set[str]:
77+
"""Identify parameters that should be transformed."""
78+
return {
6979
p_name
7080
for p_name, p in self.search_space.parameters.items()
7181
if isinstance(p, RangeParameter)
7282
and p.parameter_type == ParameterType.INT
7383
and ((p.cardinality() >= self.min_choices) or p.log_scale)
7484
}
75-
if contains_constrained_integer(self.search_space, self.transform_parameters):
76-
self.rounding = "randomized"
77-
self.contains_constrained_integer: bool = True
78-
else:
79-
self.contains_constrained_integer: bool = False
8085

8186
def transform_observation_features(
8287
self, observation_features: list[ObservationFeatures]
@@ -183,3 +188,41 @@ def untransform_observation_features(
183188
obsf.parameters[p_name] = rounded_parameters[p_name]
184189

185190
return observation_features
191+
192+
193+
class LogIntToFloat(IntToFloat):
194+
"""Convert a log-scale RangeParameter of type int to type float.
195+
196+
The behavior of this transform mirrors ``IntToFloat`` with the key difference
197+
being that it only operates on log-scale parameters.
198+
"""
199+
200+
def __init__(
201+
self,
202+
search_space: SearchSpace | None = None,
203+
observations: list[Observation] | None = None,
204+
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
205+
config: TConfig | None = None,
206+
) -> None:
207+
if config is not None and "min_choices" in config:
208+
raise UserInputError(
209+
"`min_choices` cannot be specified for `LogIntToFloat` transform. "
210+
)
211+
super().__init__(
212+
search_space=search_space,
213+
observations=observations,
214+
modelbridge=modelbridge,
215+
config=config,
216+
)
217+
# Delete the attribute to avoid it presenting a misleading value.
218+
del self.min_choices
219+
220+
def _get_transform_parameters(self) -> set[str]:
221+
"""Identify parameters that should be transformed."""
222+
return {
223+
p_name
224+
for p_name, p in self.search_space.parameters.items()
225+
if isinstance(p, RangeParameter)
226+
and p.parameter_type == ParameterType.INT
227+
and p.log_scale
228+
}

ax/modelbridge/transforms/tests/test_int_to_float_transform.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
1414
from ax.core.parameter_constraint import OrderConstraint, SumConstraint
1515
from ax.core.search_space import RobustSearchSpace, SearchSpace
16-
from ax.exceptions.core import UnsupportedError
17-
from ax.modelbridge.transforms.int_to_float import IntToFloat
16+
from ax.exceptions.core import UnsupportedError, UserInputError
17+
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
1818
from ax.utils.common.testutils import TestCase
1919
from ax.utils.common.typeutils import checked_cast
2020
from ax.utils.testing.core_stubs import get_robust_search_space
@@ -324,3 +324,33 @@ def test_w_parameter_distributions(self) -> None:
324324
)
325325
with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
326326
t.transform_search_space(rss)
327+
328+
329+
class LogIntToFloatTransformTest(TestCase):
330+
def test_log_int_to_float(self) -> None:
331+
parameters = [
332+
RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.INT),
333+
RangeParameter("y", lower=1, upper=50, parameter_type=ParameterType.INT),
334+
RangeParameter(
335+
"z", lower=1, upper=50, parameter_type=ParameterType.INT, log_scale=True
336+
),
337+
]
338+
search_space = SearchSpace(parameters=parameters)
339+
with self.assertRaisesRegex(UserInputError, "min_choices"):
340+
LogIntToFloat(search_space=search_space, config={"min_choices": 5})
341+
t = LogIntToFloat(search_space=search_space)
342+
self.assertFalse(hasattr(t, "min_choices"))
343+
self.assertEqual(t.transform_parameters, {"z"})
344+
t_ss = t.transform_search_space(search_space)
345+
self.assertEqual(t_ss.parameters["x"], parameters[0])
346+
self.assertEqual(t_ss.parameters["y"], parameters[1])
347+
self.assertEqual(
348+
t_ss.parameters["z"],
349+
RangeParameter(
350+
name="z",
351+
lower=0.50001,
352+
upper=50.49999,
353+
parameter_type=ParameterType.FLOAT,
354+
log_scale=True,
355+
),
356+
)

0 commit comments

Comments
 (0)