From a80737ee8cdb638e7906d87b0bce2efa81f56f68 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 16 Jun 2025 09:58:51 -0700 Subject: [PATCH] Support ExperimentData in MetadataToFloat & MapKeyToFloat transforms Summary: Supports transforming `ExperimentData` with `MetadataToFloat` & `MapKeyToFloat` transforms. The transform constructor is also updated to support extracting the observations for the relevant parameters from `experiment_data`. For `MapKeyToFloat`, the actual transform is a no-op, since we have the map keys in `observation_data` and we can simply extract them from there in `Adapter`. Background: As part of the larger refactor, we will be using `ExperimentData` in place of `list[Observation]` within the `Adapter`. - The transforms will be initialized using `ExperimentData`. The `observations` input to the constructors may be deprecated once the use cases are updated. - The training data for `Adapter` will be represented with `ExperimentData` and will be transformed using `transform_experiment_data`. - For misc input / output to various `Adapter` and other methods, the `Observation / ObservationFeatures / ObservationData` objects will remain. To support these, we will retain the existing transform methods that service these objects. - Since `ExperimentData` is not planned to be used as an output of user-facing methods, we do not need to untransform it. We are not planning to implement `untransform_experiment_data`. 
Reviewed By: ltiao Differential Revision: D76627256 --- ax/adapter/transforms/map_key_to_float.py | 34 ++++++++++ ax/adapter/transforms/metadata_to_float.py | 43 +++++++++--- .../tests/test_map_key_to_float_transform.py | 36 +++++++--- .../tests/test_metadata_to_float_transform.py | 67 ++++++++++++++----- 4 files changed, 148 insertions(+), 32 deletions(-) diff --git a/ax/adapter/transforms/map_key_to_float.py b/ax/adapter/transforms/map_key_to_float.py index 4ddf79e0377..5fdbf4b9add 100644 --- a/ax/adapter/transforms/map_key_to_float.py +++ b/ax/adapter/transforms/map_key_to_float.py @@ -76,6 +76,28 @@ def __init__( config=config, ) + def _get_values_for_parameter( + self, + name: str, + observations: list[Observation] | None, + experiment_data: ExperimentData | None, + ) -> set[float]: + if experiment_data is not None: + obs_data = experiment_data.observation_data + if name not in obs_data.index.names: + raise ValueError( + f"Parameter {name} is not in the index of the observation data." + ) + return set( + obs_data.index.unique(level=name).dropna().astype(float).tolist() + ) + # For Observations, the logic is identical to the parent class. + return super()._get_values_for_parameter( + name=name, + observations=observations, + experiment_data=experiment_data, + ) + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: if len(obsf.parameters) == 0: obsf.parameters = {p.name: p.upper for p in self._parameter_list} @@ -87,3 +109,15 @@ def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: if isnan(metadata[p.name]): metadata[p.name] = p.upper super()._transform_observation_feature(obsf) + + def transform_experiment_data( + self, experiment_data: ExperimentData + ) -> ExperimentData: + """No-op transform for experiment data. 
+ + This operates based on the assumption that the relevant map keys already + exist on the index of the observation data (verified in __init__), + and the downstream code will extract the map keys from there directly. + We do not need to duplicate the map keys in the arm data. + """ + return experiment_data diff --git a/ax/adapter/transforms/metadata_to_float.py b/ax/adapter/transforms/metadata_to_float.py index 965f936e3a6..a5ed39998b9 100644 --- a/ax/adapter/transforms/metadata_to_float.py +++ b/ax/adapter/transforms/metadata_to_float.py @@ -60,9 +60,9 @@ def __init__( adapter: adapter_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: - if observations is None or not observations: + if (observations is None or not observations) and experiment_data is None: raise DataRequiredError( - "`MetadataToRange` transform requires non-empty data." + f"`{self.__class__.__name__}` transform requires non-empty data." ) super().__init__( search_space=search_space, @@ -79,12 +79,9 @@ def __init__( self._parameter_list: list[RangeParameter] = [] for name in self.parameters: - values: set[float] = set() - for obs in observations: - obsf_metadata = none_throws(obs.features.metadata) - value = float(assert_is_instance(obsf_metadata[name], SupportsFloat)) - if not isnan(value): - values.add(value) + values: set[float] = self._get_values_for_parameter( + name=name, observations=observations, experiment_data=experiment_data + ) if len(values) == 0: logger.debug( @@ -127,6 +124,24 @@ def __init__( ) self._parameter_list.append(parameter) + def _get_values_for_parameter( + self, + name: str, + observations: list[Observation] | None, + experiment_data: ExperimentData | None, + ) -> set[float]: + if experiment_data is not None: + all_metadata = experiment_data.arm_data["metadata"] + return set(all_metadata.str.get(name).dropna().astype(float).tolist()) + + values: set[float] = set() + for obs in none_throws(observations): + obsf_metadata = 
none_throws(obs.features.metadata) + value = float(assert_is_instance(obsf_metadata[name], SupportsFloat)) + if not isnan(value): + values.add(value) + return values + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: for parameter in self._parameter_list: search_space.add_parameter(parameter.clone()) @@ -158,6 +173,18 @@ def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: keys=[p.name for p in self._parameter_list], ) + def transform_experiment_data( + self, experiment_data: ExperimentData + ) -> ExperimentData: + arm_data = experiment_data.arm_data + metadata = arm_data["metadata"] + for name in self.parameters: + arm_data[name] = metadata.apply(lambda x: x.pop(name)) # noqa: B023 + return ExperimentData( + arm_data=arm_data, + observation_data=experiment_data.observation_data, + ) + def _transfer( src: dict[str, Any], diff --git a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py index 2ed86f36af9..5bc065dd9ba 100644 --- a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py +++ b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py @@ -12,6 +12,7 @@ import numpy as np from ax.adapter import Adapter from ax.adapter.base import DataLoaderConfig +from ax.adapter.data_utils import extract_experiment_data from ax.adapter.registry import Generators, MBM_X_trans, Y_trans from ax.adapter.torch import TorchAdapter from ax.adapter.transforms.map_key_to_float import MapKeyToFloat @@ -263,14 +264,20 @@ def test_Init(self) -> None: with self.assertRaisesRegex(UserInputError, "optimization config"): MapKeyToFloat(observations=self.observations) - # Check for default initialization - self.assertEqual(len(self.t._parameter_list), 1) - (p,) = self.t._parameter_list - self.assertEqual(p.name, self.map_key) - self.assertEqual(p.parameter_type, ParameterType.FLOAT) - self.assertEqual(p.lower, 0.0) - self.assertEqual(p.upper, 4.0) - 
self.assertFalse(p.log_scale) # False since lower is 0.0. + experiment_data = extract_experiment_data( + experiment=self.experiment, + data_loader_config=DataLoaderConfig(fit_only_completed_map_metrics=False), + ) + t2 = MapKeyToFloat(experiment_data=experiment_data, adapter=self.adapter) + for t in (self.t, t2): + # Check for default initialization + self.assertEqual(len(t._parameter_list), 1) + (p,) = t._parameter_list + self.assertEqual(p.name, self.map_key) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 0.0) + self.assertEqual(p.upper, 4.0) + self.assertFalse(p.log_scale) # False since lower is 0.0. # specifying a parameter name that is not in the observation features' metadata with self.assertRaisesRegex(KeyError, "'baz'"): @@ -500,3 +507,16 @@ def test_with_different_map_key(self) -> None: self.assertEqual( tf_obs_ft[1].parameters, {"x1": 0.1, "x2": 0.9, "map_key": 12346.0} ) + + def test_transform_experiment_data(self) -> None: + experiment_data = extract_experiment_data( + experiment=self.experiment, + data_loader_config=DataLoaderConfig(fit_only_completed_map_metrics=False), + ) + copy_experiment_data = deepcopy(experiment_data) + transformed_data = self.t.transform_experiment_data( + experiment_data=copy_experiment_data + ) + # Check that it is returned unmodified. 
+ self.assertIs(copy_experiment_data, transformed_data) + self.assertEqual(experiment_data, transformed_data) diff --git a/ax/adapter/transforms/tests/test_metadata_to_float_transform.py b/ax/adapter/transforms/tests/test_metadata_to_float_transform.py index a006a198180..2c53b67f992 100644 --- a/ax/adapter/transforms/tests/test_metadata_to_float_transform.py +++ b/ax/adapter/transforms/tests/test_metadata_to_float_transform.py @@ -10,6 +10,8 @@ from copy import deepcopy from unittest.mock import ANY +from ax.adapter.base import DataLoaderConfig +from ax.adapter.data_utils import extract_experiment_data from ax.adapter.transforms.metadata_to_float import MetadataToFloat from ax.core.observation import ObservationFeatures, observations_from_data from ax.core.parameter import ParameterType, RangeParameter @@ -18,6 +20,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment_with_observations +from pandas.testing import assert_frame_equal from pyre_extensions import assert_is_instance @@ -74,6 +77,10 @@ def setUp(self) -> None: self.observations = observations_from_data( experiment=self.experiment, data=self.experiment.lookup_data() ) + self.experiment_data = extract_experiment_data( + experiment=self.experiment, + data_loader_config=DataLoaderConfig(), + ) self.t = MetadataToFloat( observations=self.observations, @@ -81,28 +88,33 @@ def setUp(self) -> None: "parameters": {"bar": {"log_scale": True}}, }, ) + self.t2 = MetadataToFloat( + experiment_data=self.experiment_data, + config={ + "parameters": {"bar": {"log_scale": True}}, + }, + ) def test_Init(self) -> None: - self.assertEqual(len(self.t._parameter_list), 1) - - p = self.t._parameter_list[0] - - # check that the parameter options are specified in a sensible manner - # by default if the user does not specify them explicitly - self.assertEqual(p.name, "bar") - self.assertEqual(p.parameter_type, ParameterType.FLOAT) - 
self.assertEqual(p.lower, 3.0) - self.assertEqual(p.upper, 15.0) - self.assertTrue(p.log_scale) - self.assertFalse(p.logit_scale) - self.assertIsNone(p.digits) - self.assertFalse(p.is_fidelity) - self.assertIsNone(p.target_value) + for t in (self.t, self.t2): + self.assertEqual(len(t._parameter_list), 1) + p = t._parameter_list[0] + # check that the parameter options are specified in a sensible manner + # by default if the user does not specify them explicitly + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): MetadataToFloat(search_space=None, observations=None) with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): - MetadataToFloat(search_space=None, observations=[]) + MetadataToFloat(search_space=None) def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) @@ -147,3 +159,26 @@ def test_TransformObservationFeatures(self) -> None: ) obs_ft2 = self.t.untransform_observation_features(obs_ft2) self.assertEqual(obs_ft2, observation_features) + + def test_transform_experiment_data(self) -> None: + transformed_data = self.t.transform_experiment_data( + experiment_data=deepcopy(self.experiment_data) + ) + # Check that arm data now has a new column for the transformed parameter. + expected_bar_values = [ + 3.0 * s for steps in STEPS_ENDS for s in range(1, steps + 1) + ] + self.assertEqual(transformed_data.arm_data["bar"].tolist(), expected_bar_values) + # Metadata has been updated to remove the transform parameter. + for m in transformed_data.arm_data["metadata"]: + self.assertNotIn("bar", m) + # Remaining columns are unchanged. 
+ assert_frame_equal( + transformed_data.arm_data.drop(columns=["bar", "metadata"]), + self.experiment_data.arm_data.drop(columns=["metadata"]), + ) + # Observation data is not changed. + assert_frame_equal( + transformed_data.observation_data, + self.experiment_data.observation_data, + )