Skip to content

Commit 8ea2d0d

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add LogIntToFloat transform (facebook#3091)
Summary: This is a simple subclass of `IntToFloat` that only transforms log-scale parameters. Replacing `IntToFloat` with `LogIntToFloat` will avoid unnecessary use of continuous relaxation across the board, and allow us to utilize the various optimizers available in `Acquisition.optimize`. Additional context: With log-scale parameters, we have two options: transform them in Ax or transform them in BoTorch. Transforming them in Ax leads to both modeling and optimizing the parameter in the log-scale (good), but transforming in BoTorch leads to modeling in log-scale but optimizing in the raw scale (not ideal) and also introduces `TransformedPosterior` and some incompatibilities it brings. So, we want to transform log-scale parameters in Ax. Since log of an int parameter is no longer int, we have to relax them. But we don't want to relax any other int parameters, so we don't want to use `IntToFloat`. `LogIntToFloat` makes it possible to use continuous relaxation only for the log-scale parameters, which is a good step in the right direction. Differential Revision: D66244582
1 parent 454cb9a commit 8ea2d0d

File tree

4 files changed

+84
-10
lines changed

4 files changed

+84
-10
lines changed

ax/core/search_space.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class SearchSpace(Base):
6767

6868
def __init__(
6969
self,
70-
parameters: list[Parameter],
70+
parameters: Sequence[Parameter],
7171
parameter_constraints: list[ParameterConstraint] | None = None,
7272
) -> None:
7373
"""Initialize SearchSpace

ax/modelbridge/transforms/int_to_float.py

+49-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ax.core.observation import Observation, ObservationFeatures
1313
from ax.core.parameter import Parameter, ParameterType, RangeParameter
1414
from ax.core.search_space import SearchSpace
15+
from ax.exceptions.core import UserInputError
1516
from ax.modelbridge.transforms.base import Transform
1617
from ax.modelbridge.transforms.rounding import (
1718
contains_constrained_integer,
@@ -65,18 +66,22 @@ def __init__(
6566
self.min_choices: int = checked_cast(int, config.get("min_choices", 0))
6667

6768
# Identify parameters that should be transformed
68-
self.transform_parameters: set[str] = {
69+
self.transform_parameters: set[str] = self._get_transform_parameters()
70+
if contains_constrained := contains_constrained_integer(
71+
self.search_space, self.transform_parameters
72+
):
73+
self.rounding = "randomized"
74+
self.contains_constrained_integer: bool = contains_constrained
75+
76+
def _get_transform_parameters(self) -> set[str]:
77+
"""Identify parameters that should be transformed."""
78+
return {
6979
p_name
7080
for p_name, p in self.search_space.parameters.items()
7181
if isinstance(p, RangeParameter)
7282
and p.parameter_type == ParameterType.INT
7383
and ((p.cardinality() >= self.min_choices) or p.log_scale)
7484
}
75-
if contains_constrained_integer(self.search_space, self.transform_parameters):
76-
self.rounding = "randomized"
77-
self.contains_constrained_integer: bool = True
78-
else:
79-
self.contains_constrained_integer: bool = False
8085

8186
def transform_observation_features(
8287
self, observation_features: list[ObservationFeatures]
@@ -183,3 +188,41 @@ def untransform_observation_features(
183188
obsf.parameters[p_name] = rounded_parameters[p_name]
184189

185190
return observation_features
191+
192+
193+
class LogIntToFloat(IntToFloat):
194+
"""Convert a log-scale RangeParameter of type int to type float.
195+
196+
The behavior of this transform mirrors ``IntToFloat`` with the key difference
197+
being that it only operates on log-scale parameters.
198+
"""
199+
200+
def __init__(
201+
self,
202+
search_space: SearchSpace | None = None,
203+
observations: list[Observation] | None = None,
204+
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
205+
config: TConfig | None = None,
206+
) -> None:
207+
if config is not None and "min_choices" in config:
208+
raise UserInputError(
209+
"`min_choices` cannot be specified for `LogIntToFloat` transform. "
210+
)
211+
super().__init__(
212+
search_space=search_space,
213+
observations=observations,
214+
modelbridge=modelbridge,
215+
config=config,
216+
)
217+
# Delete the attribute to avoid it presenting a misleading value.
218+
del self.min_choices
219+
220+
def _get_transform_parameters(self) -> set[str]:
221+
"""Identify parameters that should be transformed."""
222+
return {
223+
p_name
224+
for p_name, p in self.search_space.parameters.items()
225+
if isinstance(p, RangeParameter)
226+
and p.parameter_type == ParameterType.INT
227+
and p.log_scale
228+
}

ax/modelbridge/transforms/tests/test_int_to_float_transform.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
1414
from ax.core.parameter_constraint import OrderConstraint, SumConstraint
1515
from ax.core.search_space import RobustSearchSpace, SearchSpace
16-
from ax.exceptions.core import UnsupportedError
17-
from ax.modelbridge.transforms.int_to_float import IntToFloat
16+
from ax.exceptions.core import UnsupportedError, UserInputError
17+
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
1818
from ax.utils.common.testutils import TestCase
1919
from ax.utils.common.typeutils import checked_cast
2020
from ax.utils.testing.core_stubs import get_robust_search_space
@@ -324,3 +324,33 @@ def test_w_parameter_distributions(self) -> None:
324324
)
325325
with self.assertRaisesRegex(UnsupportedError, "transform is not supported"):
326326
t.transform_search_space(rss)
327+
328+
329+
class LogIntToFloatTransformTest(TestCase):
330+
def test_log_int_to_float(self) -> None:
331+
parameters = [
332+
RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.INT),
333+
RangeParameter("y", lower=1, upper=50, parameter_type=ParameterType.INT),
334+
RangeParameter(
335+
"z", lower=1, upper=50, parameter_type=ParameterType.INT, log_scale=True
336+
),
337+
]
338+
search_space = SearchSpace(parameters=parameters)
339+
with self.assertRaisesRegex(UserInputError, "min_choices"):
340+
LogIntToFloat(search_space=search_space, config={"min_choices": 5})
341+
t = LogIntToFloat(search_space=search_space)
342+
self.assertFalse(hasattr(t, "min_choices"))
343+
self.assertEqual(t.transform_parameters, {"z"})
344+
t_ss = t.transform_search_space(search_space)
345+
self.assertEqual(t_ss.parameters["x"], parameters[0])
346+
self.assertEqual(t_ss.parameters["y"], parameters[1])
347+
self.assertEqual(
348+
t_ss.parameters["z"],
349+
RangeParameter(
350+
name="z",
351+
lower=0.50001,
352+
upper=50.49999,
353+
parameter_type=ParameterType.FLOAT,
354+
log_scale=True,
355+
),
356+
)

ax/storage/transform_registry.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ax.modelbridge.transforms.derelativize import Derelativize
1919
from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters
2020
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
21-
from ax.modelbridge.transforms.int_to_float import IntToFloat
21+
from ax.modelbridge.transforms.int_to_float import IntToFloat, LogIntToFloat
2222
from ax.modelbridge.transforms.ivw import IVW
2323
from ax.modelbridge.transforms.log import Log
2424
from ax.modelbridge.transforms.log_y import LogY
@@ -95,6 +95,7 @@
9595
TimeAsFeature: 27,
9696
TransformToNewSQ: 28,
9797
FillMissingParameters: 29,
98+
LogIntToFloat: 30,
9899
}
99100

100101
"""

0 commit comments

Comments
 (0)