|
10 | 10 | from itertools import product |
11 | 11 | from typing import Any |
12 | 12 | from unittest import mock |
13 | | -from unittest.mock import patch, PropertyMock |
| 13 | +from unittest.mock import Mock, patch, PropertyMock |
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 |
|
@@ -919,6 +919,209 @@ def test_get_best_point_with_model_prediction( |
919 | 919 | self.assertEqual(best_params, params) |
920 | 920 | self.assertEqual(predictions, ({"y": mock.ANY}, {"y": {"y": mock.ANY}})) |
921 | 921 |
|
| 922 | + @mock_botorch_optimize |
| 923 | + def test_get_best_parameters_from_model_predictions_with_trial_index( |
| 924 | + self, |
| 925 | + ) -> None: |
| 926 | + # Setup experiment |
| 927 | + exp = get_branin_experiment() |
| 928 | + gs = choose_generation_strategy_legacy( |
| 929 | + search_space=exp.search_space, |
| 930 | + num_initialization_trials=3, |
| 931 | + suggested_model_override=Generators.BOTORCH_MODULAR, |
| 932 | + ) |
| 933 | + |
| 934 | + # Add some trials with data |
| 935 | + for _ in range(4): |
| 936 | + generator_run = gs.gen_single_trial(experiment=exp, n=1) |
| 937 | + trial = exp.new_trial(generator_run=generator_run) |
| 938 | + trial.run().mark_completed() |
| 939 | + exp.attach_data(exp.fetch_data()) |
| 940 | + |
| 941 | + # Test 1: No adapter (None) - should fall back to generator run |
| 942 | + with self.subTest("No adapter - fallback to generator run"): |
| 943 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 944 | + experiment=exp, adapter=None |
| 945 | + ) |
| 946 | + self.assertIsNotNone(result) |
| 947 | + trial_index, params, _ = none_throws(result) |
| 948 | + self.assertIsInstance(trial_index, int) |
| 949 | + self.assertIsInstance(params, dict) |
| 950 | + |
| 951 | + # Test 2: Non-TorchAdapter - should fall back to generator run |
| 952 | + # Then, the recommendation should be in-sample |
| 953 | + with self.subTest("Non-TorchAdapter - fallback to generator run"): |
| 954 | + non_torch_adapter = Mock() # Not a TorchAdapter |
| 955 | + result = none_throws( |
| 956 | + get_best_parameters_from_model_predictions_with_trial_index( |
| 957 | + experiment=exp, adapter=non_torch_adapter |
| 958 | + ) |
| 959 | + ) |
| 960 | + arm_params = result[1] |
| 961 | + self.assertTrue( |
| 962 | + arm_params in [v.parameters for v in exp.arms_by_name.values()] |
| 963 | + ) |
| 964 | + |
| 965 | + # Test 3: TorchAdapter with use_model_only_if_good=False -> |
| 966 | + # skip model fit check |
| 967 | + with self.subTest("TorchAdapter with use_model_only_if_good=False"): |
| 968 | + with patch.object( |
| 969 | + TorchAdapter, |
| 970 | + "model_best_point", |
| 971 | + return_value=( |
| 972 | + exp.trials[0].arms[0], |
| 973 | + ({"branin": 1.0}, {"branin": {"branin": 0.1}}), |
| 974 | + ), |
| 975 | + ) as mock_model_best_point, patch( |
| 976 | + "ax.service.utils.best_point.cross_validate" |
| 977 | + ) as mock_cv: |
| 978 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 979 | + experiment=exp, adapter=gs.adapter, use_model_only_if_good=False |
| 980 | + ) |
| 981 | + |
| 982 | + # Should not call cross_validate when consider_model_fit=False |
| 983 | + mock_cv.assert_not_called() |
| 984 | + mock_model_best_point.assert_called_once() |
| 985 | + self.assertIsNotNone(result) |
| 986 | + |
| 987 | + # Test 4: TorchAdapter with good model fit - should use adapter |
| 988 | + with self.subTest("TorchAdapter with good model fit"): |
| 989 | + with patch.object( |
| 990 | + TorchAdapter, |
| 991 | + "model_best_point", |
| 992 | + return_value=( |
| 993 | + exp.trials[0].arms[0], |
| 994 | + ({"branin": 1.0}, {"branin": {"branin": 0.1}}), |
| 995 | + ), |
| 996 | + ) as mock_model_best_point, patch( |
| 997 | + "ax.service.utils.best_point.assess_model_fit", |
| 998 | + return_value=AssessModelFitResult( |
| 999 | + good_fit_metrics_to_fisher_score={"branin": 1.0}, |
| 1000 | + bad_fit_metrics_to_fisher_score={}, |
| 1001 | + ), |
| 1002 | + ), self.assertLogs(logger=best_point_logger, level="INFO") as lg: |
| 1003 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1004 | + experiment=exp, adapter=gs.adapter, use_model_only_if_good=True |
| 1005 | + ) |
| 1006 | + |
| 1007 | + mock_model_best_point.assert_called_once() |
| 1008 | + |
| 1009 | + # Should log that model fit is acceptable |
| 1010 | + self.assertTrue( |
| 1011 | + any("Model fit is acceptable" in log for log in lg.output), |
| 1012 | + msg=lg.output, |
| 1013 | + ) |
| 1014 | + |
| 1015 | + self.assertIsNotNone(result) |
| 1016 | + |
| 1017 | + # Test 5: TorchAdapter with bad model fit - should fall back to raw data |
| 1018 | + with self.subTest("TorchAdapter with bad model fit"): |
| 1019 | + with patch.object( |
| 1020 | + TorchAdapter, |
| 1021 | + "model_best_point", |
| 1022 | + return_value=( |
| 1023 | + exp.trials[0].arms[0], |
| 1024 | + ({"branin": 1.0}, {"branin": {"branin": 0.1}}), |
| 1025 | + ), |
| 1026 | + ) as mock_model_best_point, patch( |
| 1027 | + "ax.service.utils.best_point.assess_model_fit", |
| 1028 | + return_value=AssessModelFitResult( |
| 1029 | + good_fit_metrics_to_fisher_score={}, |
| 1030 | + bad_fit_metrics_to_fisher_score={"branin": 0.1}, |
| 1031 | + ), |
| 1032 | + ), patch( |
| 1033 | + "ax.service.utils.best_point.get_best_by_raw_objective_with_trial_index", |
| 1034 | + return_value=(0, {"x1": 1.0, "x2": 2.0}, ({"branin": 5.0}, {})), |
| 1035 | + ) as mock_raw_best, self.assertLogs( |
| 1036 | + logger=best_point_logger, level="WARN" |
| 1037 | + ) as lg: |
| 1038 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1039 | + experiment=exp, adapter=gs.adapter, use_model_only_if_good=True |
| 1040 | + ) |
| 1041 | + |
| 1042 | + # Should not call model_best_point when model fit is bad |
| 1043 | + mock_model_best_point.assert_not_called() |
| 1044 | + # Should call raw objective fallbacak |
| 1045 | + mock_raw_best.assert_called_once() |
| 1046 | + |
| 1047 | + # Should log warning about poor model fit |
| 1048 | + self.assertTrue( |
| 1049 | + any("Model fit is poor" in log for log in lg.output), msg=lg.output |
| 1050 | + ) |
| 1051 | + |
| 1052 | + self.assertIsNotNone(result) |
| 1053 | + |
| 1054 | + # Test 6: TorchAdapter with model_best_point returning None -> fall back to GR |
| 1055 | + with self.subTest("TorchAdapter with model_best_point returning None"): |
| 1056 | + with patch.object( |
| 1057 | + TorchAdapter, "model_best_point", return_value=None |
| 1058 | + ) as mock_model_best_point, patch( |
| 1059 | + "ax.service.utils.best_point.assess_model_fit", |
| 1060 | + return_value=AssessModelFitResult( |
| 1061 | + good_fit_metrics_to_fisher_score={"branin": 1.0}, |
| 1062 | + bad_fit_metrics_to_fisher_score={}, |
| 1063 | + ), |
| 1064 | + ): |
| 1065 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1066 | + experiment=exp, adapter=gs.adapter, use_model_only_if_good=True |
| 1067 | + ) |
| 1068 | + |
| 1069 | + mock_model_best_point.assert_called_once() |
| 1070 | + # Should still return a result from generator run fallback |
| 1071 | + self.assertIsNotNone(result) |
| 1072 | + |
| 1073 | + # Test 7: No generator run available - should return None |
| 1074 | + with self.subTest("No generator run available"): |
| 1075 | + # Create experiment with no generator runs |
| 1076 | + empty_exp = get_branin_experiment() |
| 1077 | + empty_exp.new_trial().run().mark_completed() |
| 1078 | + |
| 1079 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1080 | + experiment=empty_exp, adapter=None |
| 1081 | + ) |
| 1082 | + self.assertIsNone(result) |
| 1083 | + |
| 1084 | + # Test 10: Trial indices subset - should work with subset of data |
| 1085 | + with self.subTest("Trial indices subset"): |
| 1086 | + trial_indices = [0, 1] # Only use first two trials |
| 1087 | + |
| 1088 | + with patch.object( |
| 1089 | + TorchAdapter, |
| 1090 | + "model_best_point", |
| 1091 | + return_value=( |
| 1092 | + exp.trials[0].arms[0], |
| 1093 | + ({"branin": 1.0}, {"branin": {"branin": 0.1}}), |
| 1094 | + ), |
| 1095 | + ): |
| 1096 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1097 | + experiment=exp, |
| 1098 | + adapter=gs.adapter, |
| 1099 | + trial_indices=trial_indices, |
| 1100 | + use_model_only_if_good=False, |
| 1101 | + ) |
| 1102 | + |
| 1103 | + self.assertIsNotNone(result) |
| 1104 | + |
| 1105 | + # Test 11: Noisy data with poor model fit - should log additional warning |
| 1106 | + with self.subTest("Noisy data with poor model fit"): |
| 1107 | + with patch( |
| 1108 | + "ax.service.utils.best_point.assess_model_fit", |
| 1109 | + return_value=AssessModelFitResult( |
| 1110 | + good_fit_metrics_to_fisher_score={}, |
| 1111 | + bad_fit_metrics_to_fisher_score={"branin": 0.1}, |
| 1112 | + ), |
| 1113 | + ), patch( |
| 1114 | + "ax.service.utils.best_point._is_all_noiseless", |
| 1115 | + return_value=False, # Simulate noisy data |
| 1116 | + ), patch( |
| 1117 | + "ax.service.utils.best_point" |
| 1118 | + ".get_best_by_raw_objective_with_trial_index", |
| 1119 | + return_value=(0, {"x1": 1.0, "x2": 2.0}, ({"branin": 5.0}, {})), |
| 1120 | + ): |
| 1121 | + result = get_best_parameters_from_model_predictions_with_trial_index( |
| 1122 | + experiment=exp, adapter=gs.adapter, use_model_only_if_good=True |
| 1123 | + ) |
| 1124 | + |
922 | 1125 |
|
923 | 1126 | def _repeat_elements(list_to_replicate: list[Any], n_repeats: int) -> pd.Series: |
924 | 1127 | return pd.Series([item for item in list_to_replicate for _ in range(n_repeats)]) |
0 commit comments