From 0f837a18ec2c76ec2e05b338fe5b8cabee284825 Mon Sep 17 00:00:00 2001 From: Jelena Markovic-Voronov Date: Sun, 15 Dec 2024 20:17:43 -0800 Subject: [PATCH] allow TransformToNewSQ to be applied in TL modelbridge (#3179) Summary: Enabling TransformToNewSQ to work for both target and source experiments with their own status quo and training data/observations. Since TL doesn't treat different trials within a single experiment as tasks we need to transform the trial data to be with respect to the same status quo. Reviewed By: saitcakmak Differential Revision: D67156451 --- .../tests/test_transform_to_new_sq.py | 32 +++++++++++++++++-- .../transforms/transform_to_new_sq.py | 15 ++++++++- ax/utils/stats/statstools.py | 2 ++ ax/utils/testing/core_stubs.py | 3 +- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py index 397d1ba08ef..76ed40616ce 100644 --- a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py +++ b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py @@ -9,6 +9,7 @@ import numpy as np import numpy.typing as npt from ax.core.batch_trial import BatchTrial +from ax.core.observation import observations_from_data from ax.modelbridge import ModelBridge from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.tests.test_relativize_transform import RelativizeDataTest @@ -20,6 +21,7 @@ get_branin_data_batch, get_branin_experiment, get_branin_optimization_config, + get_sobol, ) @@ -34,8 +36,8 @@ class TransformToNewSQTest(RelativizeDataTest): TransformToNewSQ, [ ( - np.array([-38.0, 505.0]), - np.array([[1600.0, 0.0], [0.0, 2892.56198347]]), + np.array([1.6, 10.0]), + np.array([[0.16, 0.0], [0.0, 0.2892562]]), ), (np.array([2.0, 5.0]), np.array([[0.1, 0.0], [0.0, 0.2]])), (np.array([1.0, 10.0]), np.array([[0.3, 0.0], [0.0, 0.4]])), @@ -134,3 +136,29 @@ def test_single_trial_is_not_transformed(self) -> None: ) obs2 = tf.transform_observations(obs) self.assertEqual(obs, obs2) + + def test_taget_trial_index(self) -> None: + sobol = get_sobol(search_space=self.exp.search_space) + self.exp.new_batch_trial(generator_run=sobol.gen(2)) + t = self.exp.trials[1] + t = checked_cast(BatchTrial, t) + t.mark_running(no_runner_required=True) + self.exp.attach_data(get_branin_data_batch(batch=checked_cast(BatchTrial, t))) + + observations = observations_from_data( + experiment=self.exp, + data=self.exp.lookup_data(), + ) + trial_indices = { + obs.features.trial_index + for obs in observations + if obs.features.trial_index is not None + } + + t = TransformToNewSQ( + search_space=self.exp.search_space, + observations=observations, + modelbridge=self.modelbridge, + ) + + self.assertEqual(t.default_trial_idx, min(trial_indices)) diff --git a/ax/modelbridge/transforms/transform_to_new_sq.py b/ax/modelbridge/transforms/transform_to_new_sq.py index b349102bf45..ec3526550ab 100644 --- a/ax/modelbridge/transforms/transform_to_new_sq.py +++ b/ax/modelbridge/transforms/transform_to_new_sq.py @@ -64,6 +64,19 @@ def __init__( if target_trial_index is not None: self.default_trial_idx: int = checked_cast(int, target_trial_index) + trial_indices = {} + if observations is not None: + trial_indices = { + obs.features.trial_index + for obs in observations + if obs.features.trial_index is not None + } + # in case no target trial index is provided or the provided target + # trial index is not a part of any trial from the observations, + # use the smallest trial index from the observations + if len(trial_indices) > 0 and (target_trial_index not in trial_indices): + self.default_trial_idx = min(trial_indices) + @property def control_as_constant(self) -> bool: """Whether or not the control is treated as a constant in the model.""" @@ -187,7 +200,7 @@ def _get_rel_mean_sem( sems_t=sems_t, mean_c=mean_c, sem_c=sem_c, - as_percent=True, + as_percent=False, control_as_constant=self.control_as_constant, ) if rel_op == relativize: diff --git a/ax/utils/stats/statstools.py b/ax/utils/stats/statstools.py index 004b6b70b75..5b917ac3800 100644 --- a/ax/utils/stats/statstools.py +++ b/ax/utils/stats/statstools.py @@ -318,6 +318,8 @@ def unrelativize( m_t = mean_c s_t = sem_c else: + m_t = np.array(m_t) + s_t = np.array(s_t) m_t[means_t == 0.0] = mean_c s_t[means_t == 0.0] = sem_c diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 388cea48f8e..33208498251 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -253,6 +253,7 @@ def get_branin_experiment( num_batch_trial: int = 1, with_completed_batch: bool = False, with_completed_trial: bool = False, + num_arms_per_trial: int = 15, ) -> Experiment: search_space = search_space or get_branin_search_space( with_fidelity_parameter=with_fidelity_parameter, @@ -276,7 +277,7 @@ def get_branin_experiment( if with_batch or with_completed_batch: for _ in range(num_batch_trial): sobol_generator = get_sobol(search_space=exp.search_space) - sobol_run = sobol_generator.gen(n=15) + sobol_run = sobol_generator.gen(n=num_arms_per_trial) trial = exp.new_batch_trial(optimize_for_power=with_status_quo) trial.add_generator_run(sobol_run) if with_completed_batch: