Skip to content

Commit

Permalink
allow TransformToNewSQ to be applied in TL modelbridge (#3179)
Browse files Browse the repository at this point in the history
Summary:

Enabling TransformToNewSQ to work for both target and source experiments with their own status quo and training data/observations. Since TL doesn't treat different trials within a single experiment as tasks we need to transform the trial data to be with respect to the same status quo.

Reviewed By: saitcakmak

Differential Revision: D67156451
  • Loading branch information
Jelena Markovic-Voronov authored and facebook-github-bot committed Dec 16, 2024
1 parent f71703f commit 0f837a1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
32 changes: 30 additions & 2 deletions ax/modelbridge/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import numpy.typing as npt
from ax.core.batch_trial import BatchTrial
from ax.core.observation import observations_from_data
from ax.modelbridge import ModelBridge
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.tests.test_relativize_transform import RelativizeDataTest
Expand All @@ -20,6 +21,7 @@
get_branin_data_batch,
get_branin_experiment,
get_branin_optimization_config,
get_sobol,
)


Expand All @@ -34,8 +36,8 @@ class TransformToNewSQTest(RelativizeDataTest):
TransformToNewSQ,
[
(
np.array([-38.0, 505.0]),
np.array([[1600.0, 0.0], [0.0, 2892.56198347]]),
np.array([1.6, 10.0]),
np.array([[0.16, 0.0], [0.0, 0.2892562]]),
),
(np.array([2.0, 5.0]), np.array([[0.1, 0.0], [0.0, 0.2]])),
(np.array([1.0, 10.0]), np.array([[0.3, 0.0], [0.0, 0.4]])),
Expand Down Expand Up @@ -134,3 +136,29 @@ def test_single_trial_is_not_transformed(self) -> None:
)
obs2 = tf.transform_observations(obs)
self.assertEqual(obs, obs2)

def test_taget_trial_index(self) -> None:
sobol = get_sobol(search_space=self.exp.search_space)
self.exp.new_batch_trial(generator_run=sobol.gen(2))
t = self.exp.trials[1]
t = checked_cast(BatchTrial, t)
t.mark_running(no_runner_required=True)
self.exp.attach_data(get_branin_data_batch(batch=checked_cast(BatchTrial, t)))

observations = observations_from_data(
experiment=self.exp,
data=self.exp.lookup_data(),
)
trial_indices = {
obs.features.trial_index
for obs in observations
if obs.features.trial_index is not None
}

t = TransformToNewSQ(
search_space=self.exp.search_space,
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, min(trial_indices))
15 changes: 14 additions & 1 deletion ax/modelbridge/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ def __init__(
if target_trial_index is not None:
self.default_trial_idx: int = checked_cast(int, target_trial_index)

trial_indices = {}
if observations is not None:
trial_indices = {
obs.features.trial_index
for obs in observations
if obs.features.trial_index is not None
}
# in case no target trial index is provided or the provided target
# trial index is not a part of any trial from the observations,
# use the smallest trial index from the observations
if len(trial_indices) > 0 and (target_trial_index not in trial_indices):
self.default_trial_idx = min(trial_indices)

@property
def control_as_constant(self) -> bool:
"""Whether or not the control is treated as a constant in the model."""
Expand Down Expand Up @@ -187,7 +200,7 @@ def _get_rel_mean_sem(
sems_t=sems_t,
mean_c=mean_c,
sem_c=sem_c,
as_percent=True,
as_percent=False,
control_as_constant=self.control_as_constant,
)
if rel_op == relativize:
Expand Down
2 changes: 2 additions & 0 deletions ax/utils/stats/statstools.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def unrelativize(
m_t = mean_c
s_t = sem_c
else:
m_t = np.array(m_t)
s_t = np.array(s_t)
m_t[means_t == 0.0] = mean_c
s_t[means_t == 0.0] = sem_c

Expand Down
3 changes: 2 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def get_branin_experiment(
num_batch_trial: int = 1,
with_completed_batch: bool = False,
with_completed_trial: bool = False,
num_arms_per_trial: int = 15,
) -> Experiment:
search_space = search_space or get_branin_search_space(
with_fidelity_parameter=with_fidelity_parameter,
Expand All @@ -276,7 +277,7 @@ def get_branin_experiment(
if with_batch or with_completed_batch:
for _ in range(num_batch_trial):
sobol_generator = get_sobol(search_space=exp.search_space)
sobol_run = sobol_generator.gen(n=15)
sobol_run = sobol_generator.gen(n=num_arms_per_trial)
trial = exp.new_batch_trial(optimize_for_power=with_status_quo)
trial.add_generator_run(sobol_run)
if with_completed_batch:
Expand Down

0 comments on commit 0f837a1

Please sign in to comment.