Commit 50ee506

Carl Hvarfner authored and facebook-github-bot committed
Inference trace and Best Point Recommendation (BPR) bugfix (facebook#4128)
Summary: This diff addresses two issues in the computation of the inference trace:

1. The generation strategy is copied inside `run_optimization_with_orchestrator`, so traces are retrieved from an unused generation strategy, and `get_best_point` defaults to the best raw observation over ALL observations.
2. Relevant data is not filtered in the fallback path of `get_best_parameters_from_model_predictions_with_trial_index`.

Each of these on its own makes the inference trace incorrect: the first yields the best raw value across ALL trials, the second the best predicted value across ALL trials. (A toy sketch of the first issue is included after the file summary below.)

Changes:
- Moved the copying of the generation strategy up to the level of `benchmark_replication`, since results need to be computed on the `generation_strategy` that was actually used, not an empty copy. This means that `run_optimization_with_orchestrator` no longer `clone_reset`s the GS.
- Clearer sequencing in `get_best_parameters_from_model_predictions_with_trial_index`.
- Removed the model fit quality check as part of BPR.

Previous, redacted changes:
- Added argument `use_model_only_if_good` to force model-based BPR even if the model fit is bad.

Differential Revision: D80019803
1 parent fb3af11 commit 50ee506

4 files changed: +268, -181 lines changed
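
To make the first issue concrete, here is a toy, runnable sketch (not Ax code; `ToyStrategy`, `run_loop`, and the data are invented for illustration) of why running the optimization loop on a `clone_reset()` copy leaves the original strategy empty, so that anything computed from it afterwards can only fall back to raw observations:

# Toy illustration (not Ax code): a "strategy" that accumulates the data it has seen.
from copy import deepcopy


class ToyStrategy:
    def __init__(self) -> None:
        self.observed: list[int] = []  # stands in for the data the strategy was fit on

    def clone_reset(self) -> "ToyStrategy":
        fresh = deepcopy(self)
        fresh.observed = []
        return fresh


def run_loop(strategy: ToyStrategy, data: list[int]) -> None:
    for x in data:
        strategy.observed.append(x)  # the loop mutates the strategy in place


# Old behavior: the loop runs on a clone, so the original strategy stays empty and
# best-point computation on it must fall back to raw observations over ALL trials.
original = ToyStrategy()
run_loop(original.clone_reset(), data=[3, 1, 2])
assert original.observed == []

# New behavior: reset first (in benchmark_replication), then run the loop and the
# best-point computation on the same object.
strategy = ToyStrategy().clone_reset()
run_loop(strategy, data=[3, 1, 2])
assert strategy.observed == [3, 1, 2]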

ax/benchmark/benchmark.py

Lines changed: 9 additions & 5 deletions

@@ -51,8 +51,10 @@
 from ax.core.utils import get_model_times
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.service.orchestrator import Orchestrator
-from ax.service.utils.best_point import get_trace
-from ax.service.utils.best_point_mixin import BestPointMixin
+from ax.service.utils.best_point import (
+    get_best_parameters_from_model_predictions_with_trial_index,
+    get_trace,
+)
 from ax.service.utils.orchestrator_options import OrchestratorOptions, TrialType
 from ax.utils.common.logger import DEFAULT_LOG_LEVEL, get_logger
 from ax.utils.common.random import with_rng_seed

@@ -338,9 +340,9 @@ def get_best_parameters(
             best point.
         trial_indices: Use data from only these trials. If None, use all data.
     """
-    result = BestPointMixin._get_best_trial(
+    result = get_best_parameters_from_model_predictions_with_trial_index(
         experiment=experiment,
-        generation_strategy=generation_strategy,
+        adapter=generation_strategy.adapter,
         trial_indices=trial_indices,
     )
     if result is None:

@@ -507,7 +509,7 @@ def run_optimization_with_orchestrator(
 
     orchestrator = Orchestrator(
         experiment=experiment,
-        generation_strategy=method.generation_strategy.clone_reset(),
+        generation_strategy=method.generation_strategy,
         options=orchestrator_options,
     )
 
@@ -562,6 +564,8 @@ def benchmark_replication(
     Return:
         ``BenchmarkResult`` object.
     """
+    # Reset the generation strategy to ensure that it is in an unused state.
+    method.generation_strategy = method.generation_strategy.clone_reset()
     experiment = run_optimization_with_orchestrator(
         problem=problem,
         method=method,
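
The second issue is independent of the adapter change above: when no model-based recommendation is available and the code falls back to raw observations, the data has to be filtered to the relevant trials first. A toy illustration (not Ax code; the trial indices and objective values are made up, minimization assumed):

# Toy illustration (not Ax code) of the second issue: the fallback must restrict
# itself to the relevant trial_indices, otherwise the "best" value is taken over
# ALL trials.
observations = {0: 5.0, 1: 2.0, 2: 7.0, 3: 1.0}  # trial_index -> objective value
trial_indices = {0, 1}                            # only these trials are relevant

best_unfiltered = min(observations, key=observations.get)   # trial 3: wrong candidate pool
best_filtered = min(trial_indices, key=observations.get)    # trial 1: correct candidate pool
assert (best_unfiltered, best_filtered) == (3, 1)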

ax/benchmark/tests/test_benchmark.py

Lines changed: 3 additions & 1 deletion

@@ -491,6 +491,7 @@ def test_run_optimization_with_orchestrator(self) -> None:
             none_throws(runner.simulated_backend_runner).simulator._verbose_logging
         )
 
+        method.generation_strategy = method.generation_strategy.clone_reset()
         with self.subTest("Logs not produced by default"), self.assertNoLogs(
             level=logging.INFO, logger=logger
         ), self.assertNoLogs(logger=logger):

@@ -618,9 +619,9 @@ def test_early_stopping(self) -> None:
         self.assertEqual(max_run, {0: 4, 1: 2, 2: 2, 3: 2})
 
     def test_replication_variable_runtime(self) -> None:
-        method = get_async_benchmark_method(max_pending_trials=1)
         for map_data in [False, True]:
             with self.subTest(map_data=map_data):
+                method = get_async_benchmark_method(max_pending_trials=1)
                 problem = get_async_benchmark_problem(
                     map_data=map_data,
                     step_runtime_fn=lambda params: params["x0"] + 1,

@@ -1196,6 +1197,7 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None:
         ):
             get_opt_trace_by_steps(experiment=experiment)
 
+        method.generation_strategy = method.generation_strategy.clone_reset()
         with self.subTest("Constrained"):
             problem = get_benchmark_problem("constrained_gramacy_observed_noise")
             experiment = self.run_optimization_with_orchestrator(
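
The test changes follow from the orchestrator no longer cloning the strategy: any test that drives the same `method` through more than one optimization run must now reset the strategy itself, and `test_replication_variable_runtime` rebuilds the method inside the loop for the same reason. A toy sketch of the reuse pattern (not the real test code; `FakeStrategy`, `FakeMethod`, and `fake_run` are invented):

# Toy sketch (not the real test code): once the optimization run mutates the
# strategy in place, a second run on the same method needs an explicit reset.
class FakeStrategy:
    def __init__(self, trials: int = 0) -> None:
        self.trials = trials

    def clone_reset(self) -> "FakeStrategy":
        return FakeStrategy(trials=0)


class FakeMethod:
    def __init__(self) -> None:
        self.generation_strategy = FakeStrategy()


def fake_run(method: FakeMethod) -> int:
    method.generation_strategy.trials += 3  # the run mutates the strategy in place
    return method.generation_strategy.trials


method = FakeMethod()
assert fake_run(method) == 3
method.generation_strategy = method.generation_strategy.clone_reset()  # what the tests now do
assert fake_run(method) == 3  # the second run starts from a fresh strategy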
