Skip to content

Commit 8743553

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Inference trace and Best Point Recommendation (BPR) bugfix (facebook#4128)
Summary: This diff addresses two issues in the computation of inference trace: 1. The generation strategy is copied inside run_optimization_with_orchestrator --> we retrieve the traces on an unused generation strategy --> get_best_point defaults to the best raw observation on ALL observations 2. Relevant data not filtered in the fallback option for get_best_parameters_from_model_predictions_with_trial_index Both of these individually lead to the inference trace being incorrect - the first to the best raw value of ALL trials, the second to the best predicted across ALL trials. Changes: - (re?)-Moved copying of generation strategy - Added argument use_model_only_if_good to force model-based BPR even if model fit is bad - Clearer sequencing in get_best_parameters_from_model_predictions_with_trial_index Differential Revision: D80019803
1 parent 8f2b8ca commit 8743553

File tree

4 files changed

+304
-50
lines changed

4 files changed

+304
-50
lines changed

ax/benchmark/benchmark.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ def get_best_parameters(
342342
experiment=experiment,
343343
generation_strategy=generation_strategy,
344344
trial_indices=trial_indices,
345+
# disables the model quality check which determines whether to evaluate
346+
# inference trace or raw observations when retrieving the best point
347+
use_model_only_if_good=False,
345348
)
346349
if result is None:
347350
# This can happen if no points are predicted to satisfy all outcome
@@ -507,7 +510,7 @@ def run_optimization_with_orchestrator(
507510

508511
orchestrator = Orchestrator(
509512
experiment=experiment,
510-
generation_strategy=method.generation_strategy.clone_reset(),
513+
generation_strategy=method.generation_strategy,
511514
options=orchestrator_options,
512515
)
513516

@@ -562,6 +565,8 @@ def benchmark_replication(
562565
Return:
563566
``BenchmarkResult`` object.
564567
"""
568+
# Reset the generation strategy to ensure that it is in an unused state.
569+
method.generation_strategy = method.generation_strategy.clone_reset()
565570
experiment = run_optimization_with_orchestrator(
566571
problem=problem,
567572
method=method,

ax/service/tests/test_best_point_utils.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from itertools import product
1111
from typing import Any
1212
from unittest import mock
13-
from unittest.mock import patch, PropertyMock
13+
from unittest.mock import Mock, patch, PropertyMock
1414

1515
import numpy as np
1616

@@ -919,6 +919,209 @@ def test_get_best_point_with_model_prediction(
919919
self.assertEqual(best_params, params)
920920
self.assertEqual(predictions, ({"y": mock.ANY}, {"y": {"y": mock.ANY}}))
921921

922+
@mock_botorch_optimize
923+
def test_get_best_parameters_from_model_predictions_with_trial_index(
924+
self,
925+
) -> None:
926+
# Setup experiment
927+
exp = get_branin_experiment()
928+
gs = choose_generation_strategy_legacy(
929+
search_space=exp.search_space,
930+
num_initialization_trials=3,
931+
suggested_model_override=Generators.BOTORCH_MODULAR,
932+
)
933+
934+
# Add some trials with data
935+
for _ in range(4):
936+
generator_run = gs.gen_single_trial(experiment=exp, n=1)
937+
trial = exp.new_trial(generator_run=generator_run)
938+
trial.run().mark_completed()
939+
exp.attach_data(exp.fetch_data())
940+
941+
# Test 1: No adapter (None) - should fall back to generator run
942+
with self.subTest("No adapter - fallback to generator run"):
943+
result = get_best_parameters_from_model_predictions_with_trial_index(
944+
experiment=exp, adapter=None
945+
)
946+
self.assertIsNotNone(result)
947+
trial_index, params, _ = none_throws(result)
948+
self.assertIsInstance(trial_index, int)
949+
self.assertIsInstance(params, dict)
950+
951+
# Test 2: Non-TorchAdapter - should fall back to generator run
952+
# Then, the recommendation should be in-sample
953+
with self.subTest("Non-TorchAdapter - fallback to generator run"):
954+
non_torch_adapter = Mock() # Not a TorchAdapter
955+
result = none_throws(
956+
get_best_parameters_from_model_predictions_with_trial_index(
957+
experiment=exp, adapter=non_torch_adapter
958+
)
959+
)
960+
arm_params = result[1]
961+
self.assertTrue(
962+
arm_params in [v.parameters for v in exp.arms_by_name.values()]
963+
)
964+
965+
# Test 3: TorchAdapter with use_model_only_if_good=False ->
966+
# skip model fit check
967+
with self.subTest("TorchAdapter with use_model_only_if_good=False"):
968+
with patch.object(
969+
TorchAdapter,
970+
"model_best_point",
971+
return_value=(
972+
exp.trials[0].arms[0],
973+
({"branin": 1.0}, {"branin": {"branin": 0.1}}),
974+
),
975+
) as mock_model_best_point, patch(
976+
"ax.service.utils.best_point.cross_validate"
977+
) as mock_cv:
978+
result = get_best_parameters_from_model_predictions_with_trial_index(
979+
experiment=exp, adapter=gs.adapter, use_model_only_if_good=False
980+
)
981+
982+
# Should not call cross_validate when consider_model_fit=False
983+
mock_cv.assert_not_called()
984+
mock_model_best_point.assert_called_once()
985+
self.assertIsNotNone(result)
986+
987+
# Test 4: TorchAdapter with good model fit - should use adapter
988+
with self.subTest("TorchAdapter with good model fit"):
989+
with patch.object(
990+
TorchAdapter,
991+
"model_best_point",
992+
return_value=(
993+
exp.trials[0].arms[0],
994+
({"branin": 1.0}, {"branin": {"branin": 0.1}}),
995+
),
996+
) as mock_model_best_point, patch(
997+
"ax.service.utils.best_point.assess_model_fit",
998+
return_value=AssessModelFitResult(
999+
good_fit_metrics_to_fisher_score={"branin": 1.0},
1000+
bad_fit_metrics_to_fisher_score={},
1001+
),
1002+
), self.assertLogs(logger=best_point_logger, level="INFO") as lg:
1003+
result = get_best_parameters_from_model_predictions_with_trial_index(
1004+
experiment=exp, adapter=gs.adapter, use_model_only_if_good=True
1005+
)
1006+
1007+
mock_model_best_point.assert_called_once()
1008+
1009+
# Should log that model fit is acceptable
1010+
self.assertTrue(
1011+
any("Model fit is acceptable" in log for log in lg.output),
1012+
msg=lg.output,
1013+
)
1014+
1015+
self.assertIsNotNone(result)
1016+
1017+
# Test 5: TorchAdapter with bad model fit - should fall back to raw data
1018+
with self.subTest("TorchAdapter with bad model fit"):
1019+
with patch.object(
1020+
TorchAdapter,
1021+
"model_best_point",
1022+
return_value=(
1023+
exp.trials[0].arms[0],
1024+
({"branin": 1.0}, {"branin": {"branin": 0.1}}),
1025+
),
1026+
) as mock_model_best_point, patch(
1027+
"ax.service.utils.best_point.assess_model_fit",
1028+
return_value=AssessModelFitResult(
1029+
good_fit_metrics_to_fisher_score={},
1030+
bad_fit_metrics_to_fisher_score={"branin": 0.1},
1031+
),
1032+
), patch(
1033+
"ax.service.utils.best_point.get_best_by_raw_objective_with_trial_index",
1034+
return_value=(0, {"x1": 1.0, "x2": 2.0}, ({"branin": 5.0}, {})),
1035+
) as mock_raw_best, self.assertLogs(
1036+
logger=best_point_logger, level="WARN"
1037+
) as lg:
1038+
result = get_best_parameters_from_model_predictions_with_trial_index(
1039+
experiment=exp, adapter=gs.adapter, use_model_only_if_good=True
1040+
)
1041+
1042+
# Should not call model_best_point when model fit is bad
1043+
mock_model_best_point.assert_not_called()
1044+
# Should call raw objective fallback
1045+
mock_raw_best.assert_called_once()
1046+
1047+
# Should log warning about poor model fit
1048+
self.assertTrue(
1049+
any("Model fit is poor" in log for log in lg.output), msg=lg.output
1050+
)
1051+
1052+
self.assertIsNotNone(result)
1053+
1054+
# Test 6: TorchAdapter with model_best_point returning None -> fall back to GR
1055+
with self.subTest("TorchAdapter with model_best_point returning None"):
1056+
with patch.object(
1057+
TorchAdapter, "model_best_point", return_value=None
1058+
) as mock_model_best_point, patch(
1059+
"ax.service.utils.best_point.assess_model_fit",
1060+
return_value=AssessModelFitResult(
1061+
good_fit_metrics_to_fisher_score={"branin": 1.0},
1062+
bad_fit_metrics_to_fisher_score={},
1063+
),
1064+
):
1065+
result = get_best_parameters_from_model_predictions_with_trial_index(
1066+
experiment=exp, adapter=gs.adapter, use_model_only_if_good=True
1067+
)
1068+
1069+
mock_model_best_point.assert_called_once()
1070+
# Should still return a result from generator run fallback
1071+
self.assertIsNotNone(result)
1072+
1073+
# Test 7: No generator run available - should return None
1074+
with self.subTest("No generator run available"):
1075+
# Create experiment with no generator runs
1076+
empty_exp = get_branin_experiment()
1077+
empty_exp.new_trial().run().mark_completed()
1078+
1079+
result = get_best_parameters_from_model_predictions_with_trial_index(
1080+
experiment=empty_exp, adapter=None
1081+
)
1082+
self.assertIsNone(result)
1083+
1084+
# Test 10: Trial indices subset - should work with subset of data
1085+
with self.subTest("Trial indices subset"):
1086+
trial_indices = [0, 1] # Only use first two trials
1087+
1088+
with patch.object(
1089+
TorchAdapter,
1090+
"model_best_point",
1091+
return_value=(
1092+
exp.trials[0].arms[0],
1093+
({"branin": 1.0}, {"branin": {"branin": 0.1}}),
1094+
),
1095+
):
1096+
result = get_best_parameters_from_model_predictions_with_trial_index(
1097+
experiment=exp,
1098+
adapter=gs.adapter,
1099+
trial_indices=trial_indices,
1100+
use_model_only_if_good=False,
1101+
)
1102+
1103+
self.assertIsNotNone(result)
1104+
1105+
# Test 11: Noisy data with poor model fit - should log additional warning
1106+
with self.subTest("Noisy data with poor model fit"):
1107+
with patch(
1108+
"ax.service.utils.best_point.assess_model_fit",
1109+
return_value=AssessModelFitResult(
1110+
good_fit_metrics_to_fisher_score={},
1111+
bad_fit_metrics_to_fisher_score={"branin": 0.1},
1112+
),
1113+
), patch(
1114+
"ax.service.utils.best_point._is_all_noiseless",
1115+
return_value=False, # Simulate noisy data
1116+
), patch(
1117+
"ax.service.utils.best_point"
1118+
".get_best_by_raw_objective_with_trial_index",
1119+
return_value=(0, {"x1": 1.0, "x2": 2.0}, ({"branin": 5.0}, {})),
1120+
):
1121+
result = get_best_parameters_from_model_predictions_with_trial_index(
1122+
experiment=exp, adapter=gs.adapter, use_model_only_if_good=True
1123+
)
1124+
9221125

9231126
def _repeat_elements(list_to_replicate: list[Any], n_repeats: int) -> pd.Series:
9241127
return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)])

0 commit comments

Comments
 (0)