Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ax/generators/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def test_get_model(self) -> None:
"sd_prior": GammaPrior(2.0, 0.44),
"eta": 0.6,
}
x[0, 1] = 0
x[1, 1] = 1
model = _get_model(
X=x, Y=y, Yvar=partial_var.clone(), task_feature=1, prior=prior
)
Expand All @@ -117,6 +115,7 @@ def test_get_model(self) -> None:
task_covar_module.IndexKernelPrior.correlation_prior.eta,
0.6,
)

model = _get_model(
X=x,
Y=y,
Expand Down
15 changes: 10 additions & 5 deletions ax/generators/tests/test_botorch_moo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,6 @@ def test_BotorchMOOModel_with_qehvi(
[
[11.0, 2.0],
[9.0, 3.0],
[12.0, 0.0],
[13.0, 0.0],
],
**tkwargs,
)
Expand Down Expand Up @@ -561,8 +559,16 @@ def test_BotorchMOOModel_with_qehvi(
ckwargs = _mock_model_infer_objective_thresholds.call_args[1]
X_observed = ckwargs["X_observed"]
sorted_idcs = X_observed[:, 0].argsort()
sorted_idcs2 = Xs[:, 0].argsort()
self.assertTrue(torch.equal(X_observed[sorted_idcs], Xs[sorted_idcs2]))
expected_X_observed = torch.tensor(
[[1.0, 2.0, 3.0], [0.9, 1.9, 2.9]], **tkwargs
)
sorted_idcs2 = expected_X_observed[:, 0].argsort()
self.assertTrue(
torch.equal(
X_observed[sorted_idcs],
expected_X_observed[sorted_idcs2],
)
)
self.assertTrue(
torch.equal(
ckwargs["objective_weights"],
Expand Down Expand Up @@ -782,7 +788,6 @@ def test_BotorchMOOModel_with_qehvi_and_outcome_constraints(
feature_names,
_,
) = get_torch_test_data(dtype=dtype, cuda=cuda, constant_noise=True)
bounds[0] = (0.0, 1.0) # make one data point out of bounds
training_data = [
SupervisedDataset(
X=Xs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from typing import Any

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.generators.torch.botorch_modular.utils import get_all_task_values_from_ssd

from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.models.transforms.outcome import (
Expand All @@ -22,7 +20,7 @@
)
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from pyre_extensions import assert_is_instance, none_throws
from pyre_extensions import assert_is_instance

outcome_transform_argparse = Dispatcher(
name="outcome_transform_argparse", encoder=_argparse_type_encoder
Expand All @@ -34,7 +32,6 @@ def _outcome_transform_argparse_base(
outcome_transform_class: type[OutcomeTransform],
dataset: SupervisedDataset | None = None,
outcome_transform_options: dict[str, Any] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> dict[str, Any]:
"""
Extract the outcome transform kwargs from the given arguments.
Expand All @@ -61,7 +58,6 @@ def _outcome_transform_argparse_standardize(
outcome_transform_class: type[Standardize],
dataset: SupervisedDataset,
outcome_transform_options: dict[str, Any] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> dict[str, Any]:
"""Extract the outcome transform kwargs form the given arguments.

Expand All @@ -88,7 +84,6 @@ def _outcome_transform_argparse_stratified_standardize(
outcome_transform_class: type[StratifiedStandardize],
dataset: SupervisedDataset,
outcome_transform_options: dict[str, Any] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> dict[str, Any]:
"""Extract the outcome transform kwargs form the given arguments.

Expand All @@ -111,20 +106,7 @@ def _outcome_transform_argparse_stratified_standardize(
else:
task_feature_index = dataset.task_feature_index
task_values = dataset.X[..., dataset.task_feature_index].unique().long()
ssd = none_throws(search_space_digest)
if (ssd.target_values is not None) and (
target_value := ssd.target_values.get(none_throws(task_feature_index))
) is not None:
outcome_transform_options.setdefault("default_task_value", int(target_value))
outcome_transform_options.setdefault("stratification_idx", task_feature_index)
outcome_transform_options.setdefault("observed_task_values", task_values)
outcome_transform_options.setdefault(
"all_task_values",
torch.tensor(
get_all_task_values_from_ssd(search_space_digest=ssd),
dtype=torch.long,
device=next(iter(dataset.datasets.values())).X.device,
),
)
outcome_transform_options.setdefault("task_values", task_values)

return outcome_transform_options
12 changes: 0 additions & 12 deletions ax/generators/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
convert_to_block_design,
copy_model_config_with_default_values,
fit_botorch_model,
get_all_task_values_from_ssd,
get_cv_fold,
ModelConfig,
subset_state_dict,
Expand Down Expand Up @@ -272,7 +271,6 @@ def _make_botorch_outcome_transform(
outcome_transform_classes: list[type[OutcomeTransform]],
outcome_transform_options: dict[str, dict[str, Any]],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
) -> OutcomeTransform | None:
"""
Makes a BoTorch outcome transform from the provided classes and options.
Expand All @@ -292,7 +290,6 @@ def _make_botorch_outcome_transform(
outcome_transform_options.get(transform_class.__name__, {})
),
dataset=dataset,
search_space_digest=search_space_digest,
)
for transform_class in outcome_transform_classes
]
Expand Down Expand Up @@ -376,7 +373,6 @@ def _error_if_arg_not_supported(arg_name: str) -> None:
outcome_transform_classes=outcome_transform_classes,
outcome_transform_options=model_config.outcome_transform_options or {},
dataset=dataset,
search_space_digest=search_space_digest,
)
elif "outcome_transform" in botorch_model_class_args:
# This is a temporary solution until all BoTorch models use
Expand Down Expand Up @@ -1295,14 +1291,6 @@ def _submodel_input_constructor_mtgp(
target_value := search_space_digest.target_values.get(task_feature)
) is not None:
formatted_model_inputs["output_tasks"] = [int(target_value)]
# This enables making predictions for inputs at unobserved task values,
# by making predictions for the target task.
# This is important for MTGP models that are used in ModelListGPs where
# some metrics have only been observed for some tasks and not others.
formatted_model_inputs["validate_task_values"] = False
formatted_model_inputs["all_tasks"] = get_all_task_values_from_ssd(
search_space_digest=search_space_digest
)
else:
raise UserInputError(
"output_tasks or target task value must be provided for MultiTaskGP."
Expand Down
14 changes: 0 additions & 14 deletions ax/generators/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,17 +663,3 @@ def get_cv_fold(
test_X=X[idcs],
test_Y=Y[idcs],
)


def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list[int]:
"""Get all task values from a search space digest.

Args:
search_space_digest: The search space digest.

Returns:
A list of all task values.
"""
task_feature = search_space_digest.task_features[0]
task_bounds = search_space_digest.bounds[task_feature]
return list(range(int(task_bounds[0]), int(task_bounds[1] + 1)))
49 changes: 17 additions & 32 deletions ax/generators/torch/tests/test_outcome_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# pyre-strict

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.generators.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)
Expand All @@ -18,7 +17,6 @@
)
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from pyre_extensions import assert_is_instance
from torch import Tensor


class DummyOutcomeTransform(OutcomeTransform):
Expand Down Expand Up @@ -72,52 +70,39 @@ def test_argparse_stratified_standardize(self) -> None:
X = self.dataset.X
X[:5, 3] = 0
X[5:, 3] = 1
ssd = SearchSpaceDigest(
feature_names=self.dataset.feature_names,
bounds=[(0.0, 1.0)] * 3 + [(0.0, 2.0)],
task_features=[3],
target_values={3: 1},
)
mt_dataset = MultiTaskDataset.from_joint_dataset(
dataset=self.dataset,
task_feature_index=3,
target_task_value=1,
)
outcome_transform_kwargs_a = outcome_transform_argparse(
StratifiedStandardize,
dataset=mt_dataset,
search_space_digest=ssd,
StratifiedStandardize, dataset=mt_dataset
)
options_b = {"stratification_idx": 2, "default_task_value": 4}
options_b = {
"stratification_idx": 2,
"task_values": torch.tensor([0, 3]),
}
outcome_transform_kwargs_b = outcome_transform_argparse(
StratifiedStandardize,
dataset=mt_dataset,
outcome_transform_options=options_b,
search_space_digest=ssd,
)
expected_options_a = {
"stratification_idx": 3,
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
"default_task_value": 1,
}
expected_options_b = {
"stratification_idx": 2,
"observed_task_values": torch.tensor([0, 1], dtype=torch.long),
"all_task_values": torch.tensor([0, 1, 2], dtype=torch.long),
"default_task_value": 4,
"task_values": torch.tensor([0, 1]),
}
for expected_options, actual_options in zip(
(expected_options_a, expected_options_b),
(expected_options_a, options_b),
(outcome_transform_kwargs_a, outcome_transform_kwargs_b),
):
self.assertEqual(len(actual_options), 4)
for k in ("stratification_idx", "stratification_idx"):
self.assertEqual(actual_options[k], expected_options[k])
for k in ("observed_task_values", "all_task_values"):
self.assertTrue(
torch.equal(
actual_options[k],
assert_is_instance(expected_options[k], Tensor),
)
self.assertEqual(len(actual_options), 2)
self.assertEqual(
actual_options["stratification_idx"],
expected_options["stratification_idx"],
)
self.assertTrue(
torch.equal(
actual_options["task_values"],
assert_is_instance(expected_options["task_values"], torch.Tensor),
)
)
3 changes: 2 additions & 1 deletion ax/generators/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,8 @@ def test_fit(self) -> None:
),
}

Xs, Ys, Yvars, _, _, _, _ = get_torch_test_data(dtype=self.dtype)
# offset makes task feature point to valid outcome indices
Xs, Ys, Yvars, _, _, _, _ = get_torch_test_data(dtype=self.dtype, offset=-1)
ds1 = SupervisedDataset(
X=Xs,
Y=Ys,
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/testing/torch_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_torch_test_data(
Yvar = torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)

bounds = [
(0.0 + offset, 2.0 + offset),
(0.0 + offset, 1.0 + offset),
(1.0 + offset, 4.0 + offset),
(2.0 + offset, 5.0 + offset),
]
Expand Down
Loading