Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ax/adapter/transforms/map_key_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ def __init__(
config=config,
)

def _get_values_for_parameter(
    self,
    name: str,
    observations: list[Observation] | None,
    experiment_data: ExperimentData | None,
) -> set[float]:
    """Collect the distinct, non-null values of map key ``name`` as floats.

    Args:
        name: Name of the map key whose values are collected.
        observations: Observations forwarded to the parent implementation
            when no experiment data is supplied.
        experiment_data: If given, the values are read from the index level
            called ``name`` on its observation data.

    Returns:
        The set of unique values found for ``name``, cast to float.

    Raises:
        ValueError: If ``experiment_data`` is given but ``name`` is not one
            of the index level names of its observation data.
    """
    if experiment_data is None:
        # Without experiment data, the parent class logic applies unchanged.
        return super()._get_values_for_parameter(
            name=name,
            observations=observations,
            experiment_data=experiment_data,
        )

    index = experiment_data.observation_data.index
    if name not in index.names:
        raise ValueError(
            f"Parameter {name} is not in the index of the observation data."
        )
    # Null entries on the level are dropped before casting to float.
    level_values = index.unique(level=name).dropna()
    return set(level_values.astype(float).tolist())

def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
if len(obsf.parameters) == 0:
obsf.parameters = {p.name: p.upper for p in self._parameter_list}
Expand All @@ -87,3 +109,15 @@ def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
if isnan(metadata[p.name]):
metadata[p.name] = p.upper
super()._transform_observation_feature(obsf)

def transform_experiment_data(
    self, experiment_data: ExperimentData
) -> ExperimentData:
    """Return ``experiment_data`` untouched (this transform is a no-op here).

    The design assumption is that the relevant map keys are already present
    on the index of the observation data (stated to be verified in
    ``__init__``), and that downstream code reads them from that index
    directly. Copying the map keys into the arm data would therefore only
    duplicate information.

    Args:
        experiment_data: The experiment data to "transform".

    Returns:
        The very same ``experiment_data`` object that was passed in.
    """
    return experiment_data
43 changes: 35 additions & 8 deletions ax/adapter/transforms/metadata_to_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(
adapter: adapter_module.base.Adapter | None = None,
config: TConfig | None = None,
) -> None:
if observations is None or not observations:
if (observations is None or not observations) and experiment_data is None:
raise DataRequiredError(
"`MetadataToRange` transform requires non-empty data."
f"`{self.__class__.__name__}` transform requires non-empty data."
)
super().__init__(
search_space=search_space,
Expand All @@ -79,12 +79,9 @@ def __init__(

self._parameter_list: list[RangeParameter] = []
for name in self.parameters:
values: set[float] = set()
for obs in observations:
obsf_metadata = none_throws(obs.features.metadata)
value = float(assert_is_instance(obsf_metadata[name], SupportsFloat))
if not isnan(value):
values.add(value)
values: set[float] = self._get_values_for_parameter(
name=name, observations=observations, experiment_data=experiment_data
)

if len(values) == 0:
logger.debug(
Expand Down Expand Up @@ -127,6 +124,24 @@ def __init__(
)
self._parameter_list.append(parameter)

def _get_values_for_parameter(
    self,
    name: str,
    observations: list[Observation] | None,
    experiment_data: ExperimentData | None,
) -> set[float]:
    """Collect the distinct, non-NaN float values of metadata entry ``name``.

    Args:
        name: Metadata key whose values are collected.
        observations: Observations whose feature metadata is scanned when
            ``experiment_data`` is not provided. Must not be None in that case.
        experiment_data: If given, values are read from the ``"metadata"``
            column of its arm data instead of from ``observations``.

    Returns:
        The set of unique values found for ``name``, cast to float. Null / NaN
        entries are excluded in both code paths.
    """
    if experiment_data is not None:
        all_metadata = experiment_data.arm_data["metadata"]
        # `.tolist()` alone yields a list (with duplicates); wrap it in `set`
        # so this branch honors the declared return type and matches the
        # observations branch below.
        return set(all_metadata.str.get(name).dropna().astype(float).tolist())

    values: set[float] = set()
    for obs in none_throws(observations):
        obsf_metadata = none_throws(obs.features.metadata)
        value = float(assert_is_instance(obsf_metadata[name], SupportsFloat))
        # NaN marks a missing value and must not contribute to the value set.
        if not isnan(value):
            values.add(value)
    return values

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
for parameter in self._parameter_list:
search_space.add_parameter(parameter.clone())
Expand Down Expand Up @@ -158,6 +173,18 @@ def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
keys=[p.name for p in self._parameter_list],
)

def transform_experiment_data(
    self, experiment_data: ExperimentData
) -> ExperimentData:
    """Move each configured metadata entry into its own ``arm_data`` column.

    For every name in ``self.parameters``, the value stored under that key is
    popped out of each row's metadata dict and written to a new column of the
    same name on the arm data.

    Note that this works in place: the arm data frame of the input and the
    metadata dicts it holds are modified, so callers that need the original
    should pass a copy.

    Args:
        experiment_data: Experiment data whose arm data has a ``"metadata"``
            column of dicts.

    Returns:
        An ``ExperimentData`` built from the updated arm data and the
        unchanged observation data of the input.
    """
    arm_frame = experiment_data.arm_data
    metadata_column = arm_frame["metadata"]
    for name in self.parameters:

        def _take(entry: dict[str, Any], key: str = name) -> Any:
            # `key` is bound at definition time, so each helper pops the
            # entry belonging to the current loop iteration.
            return entry.pop(key)

        arm_frame[name] = metadata_column.apply(_take)
    return ExperimentData(
        arm_data=arm_frame,
        observation_data=experiment_data.observation_data,
    )


def _transfer(
src: dict[str, Any],
Expand Down
36 changes: 28 additions & 8 deletions ax/adapter/transforms/tests/test_map_key_to_float_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
from ax.adapter import Adapter
from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.registry import Generators, MBM_X_trans, Y_trans
from ax.adapter.torch import TorchAdapter
from ax.adapter.transforms.map_key_to_float import MapKeyToFloat
Expand Down Expand Up @@ -263,14 +264,20 @@ def test_Init(self) -> None:
with self.assertRaisesRegex(UserInputError, "optimization config"):
MapKeyToFloat(observations=self.observations)

# Check for default initialization
self.assertEqual(len(self.t._parameter_list), 1)
(p,) = self.t._parameter_list
self.assertEqual(p.name, self.map_key)
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 0.0)
self.assertEqual(p.upper, 4.0)
self.assertFalse(p.log_scale) # False since lower is 0.0.
experiment_data = extract_experiment_data(
experiment=self.experiment,
data_loader_config=DataLoaderConfig(fit_only_completed_map_metrics=False),
)
t2 = MapKeyToFloat(experiment_data=experiment_data, adapter=self.adapter)
for t in (self.t, t2):
# Check for default initialization
self.assertEqual(len(t._parameter_list), 1)
(p,) = t._parameter_list
self.assertEqual(p.name, self.map_key)
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 0.0)
self.assertEqual(p.upper, 4.0)
self.assertFalse(p.log_scale) # False since lower is 0.0.

# specifying a parameter name that is not in the observation features' metadata
with self.assertRaisesRegex(KeyError, "'baz'"):
Expand Down Expand Up @@ -500,3 +507,16 @@ def test_with_different_map_key(self) -> None:
self.assertEqual(
tf_obs_ft[1].parameters, {"x1": 0.1, "x2": 0.9, "map_key": 12346.0}
)

def test_transform_experiment_data(self) -> None:
    """`transform_experiment_data` hands back its input object unchanged."""
    loader_config = DataLoaderConfig(fit_only_completed_map_metrics=False)
    reference = extract_experiment_data(
        experiment=self.experiment,
        data_loader_config=loader_config,
    )
    # Transform a copy so that `reference` stays pristine for comparison.
    candidate = deepcopy(reference)
    result = self.t.transform_experiment_data(experiment_data=candidate)
    # Same object comes back ...
    self.assertIs(candidate, result)
    # ... and its contents still match the untouched reference.
    self.assertEqual(reference, result)
67 changes: 51 additions & 16 deletions ax/adapter/transforms/tests/test_metadata_to_float_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from copy import deepcopy
from unittest.mock import ANY

from ax.adapter.base import DataLoaderConfig
from ax.adapter.data_utils import extract_experiment_data
from ax.adapter.transforms.metadata_to_float import MetadataToFloat
from ax.core.observation import ObservationFeatures, observations_from_data
from ax.core.parameter import ParameterType, RangeParameter
Expand All @@ -18,6 +20,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment_with_observations
from pandas.testing import assert_frame_equal
from pyre_extensions import assert_is_instance


Expand Down Expand Up @@ -74,35 +77,44 @@ def setUp(self) -> None:
self.observations = observations_from_data(
experiment=self.experiment, data=self.experiment.lookup_data()
)
self.experiment_data = extract_experiment_data(
experiment=self.experiment,
data_loader_config=DataLoaderConfig(),
)

self.t = MetadataToFloat(
observations=self.observations,
config={
"parameters": {"bar": {"log_scale": True}},
},
)
self.t2 = MetadataToFloat(
experiment_data=self.experiment_data,
config={
"parameters": {"bar": {"log_scale": True}},
},
)

def test_Init(self) -> None:
self.assertEqual(len(self.t._parameter_list), 1)

p = self.t._parameter_list[0]

# check that the parameter options are specified in a sensible manner
# by default if the user does not specify them explicitly
self.assertEqual(p.name, "bar")
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 3.0)
self.assertEqual(p.upper, 15.0)
self.assertTrue(p.log_scale)
self.assertFalse(p.logit_scale)
self.assertIsNone(p.digits)
self.assertFalse(p.is_fidelity)
self.assertIsNone(p.target_value)
for t in (self.t, self.t2):
self.assertEqual(len(t._parameter_list), 1)
p = t._parameter_list[0]
# check that the parameter options are specified in a sensible manner
# by default if the user does not specify them explicitly
self.assertEqual(p.name, "bar")
self.assertEqual(p.parameter_type, ParameterType.FLOAT)
self.assertEqual(p.lower, 3.0)
self.assertEqual(p.upper, 15.0)
self.assertTrue(p.log_scale)
self.assertFalse(p.logit_scale)
self.assertIsNone(p.digits)
self.assertFalse(p.is_fidelity)
self.assertIsNone(p.target_value)

with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
MetadataToFloat(search_space=None, observations=None)
with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"):
MetadataToFloat(search_space=None, observations=[])
MetadataToFloat(search_space=None)

def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
Expand Down Expand Up @@ -147,3 +159,26 @@ def test_TransformObservationFeatures(self) -> None:
)
obs_ft2 = self.t.untransform_observation_features(obs_ft2)
self.assertEqual(obs_ft2, observation_features)

def test_transform_experiment_data(self) -> None:
    """Transforming moves ``bar`` from the metadata into an arm data column."""
    original = self.experiment_data
    # Work on a copy so that `original` can serve as the baseline below.
    result = self.t.transform_experiment_data(experiment_data=deepcopy(original))

    # Arm data gains a "bar" column holding the values taken from metadata.
    expected_bar: list[float] = []
    for n_steps in STEPS_ENDS:
        expected_bar.extend(3.0 * step for step in range(1, n_steps + 1))
    self.assertEqual(result.arm_data["bar"].tolist(), expected_bar)

    # The transformed key no longer appears in any metadata dict.
    for arm_metadata in result.arm_data["metadata"]:
        self.assertNotIn("bar", arm_metadata)

    # All other arm data columns are left as they were.
    assert_frame_equal(
        result.arm_data.drop(columns=["bar", "metadata"]),
        original.arm_data.drop(columns=["metadata"]),
    )
    # Observation data passes through unchanged.
    assert_frame_equal(result.observation_data, original.observation_data)