diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index f537b32f3..d22d32c8b 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -287,7 +287,7 @@ class AquaMultiModelRef(Serializable): description="Environment variables to override during container startup.", ) params: Optional[dict] = Field( - default_factory=dict, + default=None, description=( "Framework-specific startup parameters required by the container runtime. " "For example, vLLM models may use flags like `--tensor-parallel-size`, `--enforce-eager`, etc." diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 6f4ac2070..eeb330090 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -28,7 +28,6 @@ build_params_string, build_pydantic_error_message, find_restricted_params, - get_combined_params, get_container_env_type, get_container_params_type, get_ocid_substring, @@ -918,10 +917,31 @@ def _create( # The values provided by user will override the ones provided by default config env_var = {**config_env, **env_var} - # validate user provided params - user_params = env_var.get("PARAMS", UNKNOWN) + # SMM Parameter Resolution Logic + # Check the raw user input from create_deployment_details to determine intent. + # We cannot use the merged 'env_var' here because it may already contain defaults. + user_input_env = create_deployment_details.env_var or {} + user_input_params = user_input_env.get("PARAMS") + + deployment_params = "" + + if user_input_params is None: + # Case 1: None (CLI default) -> Load full defaults from config + logger.info("No PARAMS provided (None). Loading default SMM parameters.") + deployment_params = config_params + elif str(user_input_params).strip() == "": + # Case 2: Empty String (UI Clear) -> Explicitly use no parameters + logger.info("Empty PARAMS provided. Clearing all parameters.") + deployment_params = "" + else: + # Case 3: Value Provided -> Use exact user value (No merging) + logger.info( + f"User provided PARAMS. Using exact user values: {user_input_params}" + ) + deployment_params = user_input_params - if user_params: + # Validate the resolved parameters + if deployment_params: # todo: remove this check in the future version, logic to be moved to container_index if ( container_type_key.lower() @@ -935,7 +955,7 @@ def _create( ) restricted_params = find_restricted_params( - params, user_params, container_type_key + params, deployment_params, container_type_key ) if restricted_params: raise AquaValueError( @@ -943,8 +963,6 @@ def _create( f"and cannot be overridden or are invalid." ) - deployment_params = get_combined_params(config_params, user_params) - params = f"{params} {deployment_params}".strip() if isinstance(aqua_model, DataScienceModelGroup): @@ -1212,7 +1230,7 @@ def _create_deployment( # we arbitrarily choose last 8 characters of OCID to identify MD in telemetry deployment_short_ocid = get_ocid_substring(deployment_id, key_len=8) - + # Prepare telemetry kwargs telemetry_kwargs = {"ocid": deployment_short_ocid} @@ -2048,9 +2066,11 @@ def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]: self.telemetry.record_event_async( category="aqua/deployment", action="recommend_shape", - detail=get_ocid_substring(model_id, key_len=8) - if is_valid_ocid(ocid=model_id) - else model_id, + detail=( + get_ocid_substring(model_id, key_len=8) + if is_valid_ocid(ocid=model_id) + else model_id + ), **kwargs, ) diff --git a/ads/aqua/modeldeployment/model_group_config.py b/ads/aqua/modeldeployment/model_group_config.py index ca8961620..5d38dde3f 100644 --- a/ads/aqua/modeldeployment/model_group_config.py +++ b/ads/aqua/modeldeployment/model_group_config.py @@ -13,7 +13,6 @@ from ads.aqua.common.utils import ( build_params_string, find_restricted_params, - get_combined_params, get_container_params_type, get_params_dict, ) @@ -177,26 +176,36 @@ def _merge_gpu_count_params( model.model_id, AquaDeploymentConfig() ).configuration.get(deployment_details.instance_shape, ConfigurationItem()) + final_model_params = user_params params_found = False - for item in deployment_config.multi_model_deployment: - if model.gpu_count and item.gpu_count and item.gpu_count == model.gpu_count: - config_parameters = item.parameters.get( + user_explicitly_cleared = model.params is not None and not model.params + + # Only load defaults if user didn't provide params AND didn't explicitly clear them + if not user_params and not user_explicitly_cleared: + for item in deployment_config.multi_model_deployment: + if ( + model.gpu_count + and item.gpu_count + and item.gpu_count == model.gpu_count + ): + config_parameters = item.parameters.get( + get_container_params_type(container_type_key), UNKNOWN + ) + if config_parameters: + final_model_params = config_parameters + params_found = True + break + + if not params_found and deployment_config.parameters: + config_parameters = deployment_config.parameters.get( get_container_params_type(container_type_key), UNKNOWN ) - params = f"{params} {get_combined_params(config_parameters, user_params)}".strip() + if config_parameters: + final_model_params = config_parameters params_found = True - break - if not params_found and deployment_config.parameters: - config_parameters = deployment_config.parameters.get( - get_container_params_type(container_type_key), UNKNOWN - ) - params = f"{params} {get_combined_params(config_parameters, user_params)}".strip() - params_found = True - - # if no config parameters found, append user parameters directly. - if not params_found: - params = f"{params} {user_params}".strip() + # Combine Container System Defaults (params) + Model Params (final_model_params) + params = f"{params} {final_model_params}".strip() return params diff --git a/tests/unitary/with_extras/aqua/test_common_entities.py b/tests/unitary/with_extras/aqua/test_common_entities.py index 0c2b293b4..a686a3c11 100644 --- a/tests/unitary/with_extras/aqua/test_common_entities.py +++ b/tests/unitary/with_extras/aqua/test_common_entities.py @@ -196,7 +196,7 @@ def test_extract_params_from_env_var_missing_env(self): } result = AquaMultiModelRef.model_validate(values) assert result.env_var == {} - assert result.params == {} + assert result.params is None def test_all_model_ids_no_finetunes(self): model = AquaMultiModelRef(model_id="ocid1.model.oc1..base") diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 2a22a7f42..caf82d7b8 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -556,7 +556,7 @@ class TestDataset: "models": [ { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_1", "model_name": "test_model_1", @@ -566,7 +566,7 @@ class TestDataset: }, { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_2", "model_name": "test_model_2", @@ -576,7 +576,7 @@ class TestDataset: }, { "env_var": {}, - "params": {}, + "params": None, "gpu_count": 2, "model_id": "test_model_id_3", "model_name": "test_model_3", @@ -1058,7 +1058,7 @@ class TestDataset: multi_model_deployment_model_attributes = [ { "env_var": {"--test_key_one": "test_value_one"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_one", @@ -1068,7 +1068,7 @@ class TestDataset: }, { "env_var": {"--test_key_two": "test_value_two"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_two", @@ -1078,7 +1078,7 @@ class TestDataset: }, { "env_var": {"--test_key_three": "test_value_three"}, - "params": {}, + "params": None, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_three", @@ -1258,9 +1258,7 @@ def test_get_deployment(self, mock_get_resource_name): mock_get_resource_name.side_effect = lambda param: ( "log-group-name" if param.startswith("ocid1.loggroup") - else "log-name" - if param.startswith("ocid1.log") - else "" + else "log-name" if param.startswith("ocid1.log") else "" ) result = self.app.get(model_deployment_id=TestDataset.MODEL_DEPLOYMENT_ID) @@ -1301,9 +1299,7 @@ def test_get_multi_model_deployment( mock_get_resource_name.side_effect = lambda param: ( "log-group-name" if param.startswith("ocid1.loggroup") - else "log-name" - if param.startswith("ocid1.log") - else "" + else "log-name" if param.startswith("ocid1.log") else "" ) aqua_multi_model = os.path.join( @@ -2956,3 +2952,170 @@ def test_from_create_model_deployment_details(self): model_group_config_no_ft.model_dump() == TestDataset.multi_model_deployment_group_config_no_ft ) + + +class TestSingleModelParamResolution(TestAquaDeployment): + def setUp(self): + super().setUp() + + self.app.get_container_config = MagicMock() + self.app.get_container_image = MagicMock(return_value="docker/image:latest") + + mock_shape = MagicMock() + mock_shape.name = "VM.GPU.A10.1" + self.app.list_shapes = MagicMock(return_value=[mock_shape]) + + self.mock_config = MagicMock() + self.mock_config.configuration.get.return_value.parameters.get.return_value = ( + "--default-param 100" + ) + self.app.get_deployment_config = MagicMock(return_value=self.mock_config) + + self.mock_container_item = MagicMock() + self.mock_container_item.spec.cli_param = "--mandatory-param 1" + self.mock_container_item.spec.restricted_params = [] + self.app.get_container_config_item = MagicMock( + return_value=self.mock_container_item + ) + + @patch("ads.model.deployment.model_deployment.ModelDeployment") + @patch("ads.aqua.model.AquaModelApp") + def test_case_1_none_loads_defaults(self, mock_model_app, mock_deploy): + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + env_var={}, + ) + + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + self.assertIn("--mandatory-param 1", final_params) + self.assertIn("--default-param 100", final_params) + + @patch("ads.model.deployment.model_deployment.ModelDeployment") + @patch("ads.aqua.model.AquaModelApp") + def test_case_2_empty_clears_defaults(self, mock_model_app, mock_deploy): + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + env_var={"PARAMS": ""}, + ) + + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + self.assertIn("--mandatory-param 1", final_params) + self.assertNotIn("--default-param 100", final_params) + + @patch("ads.model.deployment.model_deployment.ModelDeployment") + @patch("ads.aqua.model.AquaModelApp") + def test_case_3_value_overrides_defaults(self, mock_model_app, mock_deploy): + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + env_var={"PARAMS": "--user-override 99"}, + ) + + with patch.object(self.app, "_create_deployment") as mock_create_internal: + self.app.create(create_deployment_details=details) + + call_args = mock_create_internal.call_args[1] + final_params = call_args["env_var"]["PARAMS"] + + self.assertIn("--mandatory-param 1", final_params) + self.assertIn("--user-override 99", final_params) + self.assertNotIn("--default-param 100", final_params) + + @patch("ads.model.deployment.model_deployment.ModelDeployment") + @patch("ads.aqua.model.AquaModelApp") + def test_validation_blocks_restricted_params(self, mock_model_app, mock_deploy): + restricted_mock_item = MagicMock() + restricted_mock_item.spec.cli_param = "--mandatory 1" + restricted_mock_item.spec.restricted_params = ["--seed"] + + self.app.get_container_config_item = MagicMock( + return_value=restricted_mock_item + ) + + details = CreateModelDeploymentDetails( + model_id="ocid1.model...", + instance_shape="VM.GPU.A10.1", + env_var={"PARAMS": "--seed 999"}, + ) + + with self.assertRaises(AquaValueError) as context: + self.app.create(create_deployment_details=details) + + self.assertIn("Parameters ['--seed'] are set by Aqua", str(context.exception)) + + +class TestMultiModelParamResolution(unittest.TestCase): + def setUp(self): + self.mock_config_summary = MagicMock() + self.mock_deploy_config = MagicMock() + + self.mock_deploy_config.configuration.get.return_value.parameters.get.return_value = ( + "--smm-default 500" + ) + self.mock_config_summary.deployment_config.get.return_value = ( + self.mock_deploy_config + ) + + self.mock_details = MagicMock() + self.mock_details.instance_shape = "VM.GPU.A10.2" + + self.container_params = "--mandatory 1" + + def test_case_1_none_loads_defaults(self): + model = AquaMultiModelRef(model_id="ocid1...", gpu_count=1, params=None) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + self.assertIn("--smm-default 500", result) + + def test_case_2_empty_clears_defaults(self): + model = AquaMultiModelRef(model_id="ocid1...", gpu_count=1, params={}) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + self.assertNotIn("--smm-default 500", result) + + def test_case_3_value_overrides_defaults(self): + model = AquaMultiModelRef( + model_id="ocid1...", + gpu_count=1, + params={"--custom": "99"}, + ) + + result = ModelGroupConfig._merge_gpu_count_params( + model, + self.mock_config_summary, + self.mock_details, + "container_key", + self.container_params, + ) + + self.assertIn("--mandatory 1", result) + self.assertIn("--custom 99", result) + self.assertNotIn("--smm-default 500", result) diff --git a/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py b/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py index f2b036f1c..a69ebd614 100644 --- a/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py +++ b/tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py @@ -160,7 +160,7 @@ def test_stream_vllm(*args: Any) -> None: else: output += chunk count += 1 - assert count == 6 + assert count == 5 assert output is not None if output is not None: assert str(output.content).strip() == CONST_COMPLETION