diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py index 47c9b9c8d..646cc8a58 100644 --- a/truss-chains/truss_chains/deployment/deployment_client.py +++ b/truss-chains/truss_chains/deployment/deployment_client.py @@ -531,6 +531,7 @@ def _create_baseten_chain( environment=baseten_options.environment, progress_bar=progress_bar, disable_chain_download=baseten_options.disable_chain_download, + deployment_name=baseten_options.deployment_name, ) return BasetenChainService( baseten_options.chain_name, diff --git a/truss-chains/truss_chains/private_types.py b/truss-chains/truss_chains/private_types.py index 579502eaa..64ca00602 100644 --- a/truss-chains/truss_chains/private_types.py +++ b/truss-chains/truss_chains/private_types.py @@ -266,6 +266,7 @@ class PushOptionsBaseten(PushOptions): include_git_info: bool working_dir: pathlib.Path disable_chain_download: bool = False + deployment_name: Optional[str] = None @classmethod def create( @@ -279,6 +280,7 @@ def create( working_dir: pathlib.Path, environment: Optional[str] = None, disable_chain_download: bool = False, + deployment_name: Optional[str] = None, ) -> "PushOptionsBaseten": if promote and not environment: environment = PRODUCTION_ENVIRONMENT_NAME @@ -293,6 +295,7 @@ def create( include_git_info=include_git_info, working_dir=working_dir, disable_chain_download=disable_chain_download, + deployment_name=deployment_name, ) diff --git a/truss/cli/chains_commands.py b/truss/cli/chains_commands.py index bcecbf3d0..dc2b0db71 100644 --- a/truss/cli/chains_commands.py +++ b/truss/cli/chains_commands.py @@ -219,6 +219,15 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]: default=False, help="Disable downloading of pushed chain source code from the UI.", ) +@click.option( + "--deployment-name", + type=str, + required=False, + help=( + "Name of the deployment created by the publish. Can only be used " + "in combination with '--publish' or '--promote'." + ), +) @click.pass_context @common.common_options() def push_chain( @@ -236,6 +245,7 @@ def push_chain( experimental_watch_chainlet_names: Optional[str], include_git_info: bool = False, disable_chain_download: bool = False, + deployment_name: Optional[str] = None, ) -> None: """ Deploys a chain remotely. @@ -294,6 +304,7 @@ def push_chain( include_git_info=include_git_info, working_dir=source.parent if source.is_file() else source.resolve(), disable_chain_download=disable_chain_download, + deployment_name=deployment_name, ) service = deployment_client.push( entrypoint_cls, options, progress_bar=progress.Progress diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 7b25a1038..9f8b95c4c 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -333,6 +333,7 @@ def deploy_chain_atomic( is_draft: bool = False, original_source_artifact_s3_key: Optional[str] = None, allow_truss_download: Optional[bool] = True, + deployment_name: Optional[str] = None, ): if allow_truss_download is None: allow_truss_download = True @@ -360,6 +361,8 @@ def deploy_chain_atomic( params.append(f"is_draft: {str(is_draft).lower()}") if allow_truss_download is False: params.append("allow_truss_download: false") + if deployment_name: + params.append(f'deployment_name: "{deployment_name}"') params_str = PARAMS_INDENT.join(params) diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 686275c48..57666c39b 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -132,6 +132,7 @@ def create_chain_atomic( environment: Optional[str], original_source_artifact_s3_key: Optional[str] = None, allow_truss_download: bool = True, + deployment_name: Optional[str] = None, ) -> ChainDeploymentHandleAtomic: if environment and is_draft: logging.info( @@ -156,6 +157,7 @@ def create_chain_atomic( truss_user_env=truss_user_env, original_source_artifact_s3_key=original_source_artifact_s3_key, allow_truss_download=allow_truss_download, + deployment_name=deployment_name, ) elif chain_id: # This is the only case where promote has relevance, since @@ -171,6 +173,7 @@ def create_chain_atomic( truss_user_env=truss_user_env, original_source_artifact_s3_key=original_source_artifact_s3_key, allow_truss_download=allow_truss_download, + deployment_name=deployment_name, ) except ApiError as e: if ( @@ -193,6 +196,7 @@ def create_chain_atomic( truss_user_env=truss_user_env, original_source_artifact_s3_key=original_source_artifact_s3_key, allow_truss_download=allow_truss_download, + deployment_name=deployment_name, ) return ChainDeploymentHandleAtomic( diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 54ab77aac..f83332a2e 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -280,6 +280,7 @@ def push_chain_atomic( environment: Optional[str] = None, progress_bar: Optional[Type["progress.Progress"]] = None, disable_chain_download: bool = False, + deployment_name: Optional[str] = None, ) -> ChainDeploymentHandleAtomic: # If we are promoting a model to an environment after deploy, it must be published. # Draft models cannot be promoted. @@ -300,6 +301,7 @@ def push_chain_atomic( origin=custom_types.ModelOrigin.CHAINS, progress_bar=progress_bar, disable_truss_download=disable_chain_download, + deployment_name=deployment_name, ) oracle_data = custom_types.OracleData( model_name=push_data.model_name, @@ -337,6 +339,7 @@ def push_chain_atomic( environment=environment, original_source_artifact_s3_key=raw_chain_s3_key, allow_truss_download=not disable_chain_download, + deployment_name=deployment_name, ) logging.info("Successfully pushed to baseten. Chain is building and deploying.") return chain_deployment_handle diff --git a/truss/tests/cli/test_chains_cli.py b/truss/tests/cli/test_chains_cli.py index dcd1834ad..8ec505a28 100644 --- a/truss/tests/cli/test_chains_cli.py +++ b/truss/tests/cli/test_chains_cli.py @@ -98,3 +98,47 @@ def test_chains_push_help_includes_disable_chain_download(): assert result.exit_code == 0 assert "--disable-chain-download" in result.output + + +def test_chains_push_with_deployment_name_flag(): + """Test that --deployment-name flag is properly parsed and passed through.""" + runner = CliRunner() + + mock_entrypoint_cls = Mock() + mock_entrypoint_cls.meta_data.chain_name = "test_chain" + mock_entrypoint_cls.display_name = "TestChain" + + mock_service = Mock() + mock_service.run_remote_url = "http://test.com/run_remote" + mock_service.is_websocket = False + + with patch( + "truss_chains.framework.ChainletImporter.import_target" + ) as mock_importer: + with patch("truss_chains.deployment.deployment_client.push") as mock_push: + mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls + mock_push.return_value = mock_service + + result = runner.invoke( + truss_cli, + [ + "chains", + "push", + "test_chain.py", + "--deployment-name", + "custom_deployment", + "--remote", + "test_remote", + "--publish", + "--dryrun", + ], + ) + + assert result.exit_code == 0 + + mock_push.assert_called_once() + call_args = mock_push.call_args + options = call_args[0][1] + + assert hasattr(options, "deployment_name") + assert options.deployment_name == "custom_deployment" diff --git a/truss/tests/remote/baseten/test_api.py b/truss/tests/remote/baseten/test_api.py index 940e0228d..9777ff41a 100644 --- a/truss/tests/remote/baseten/test_api.py +++ b/truss/tests/remote/baseten/test_api.py @@ -452,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api): assert 'chain_id: "chain_id"' in gql_mutation assert "dependencies:" in gql_mutation assert "entrypoint:" in gql_mutation + assert "deployment_name" not in gql_mutation + + +@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) +def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api): + baseten_api.deploy_chain_atomic( + environment="production", + chain_id="chain_id", + dependencies=[], + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + ), + ), + truss_user_env=b10_types.TrussUserEnv.collect(), + deployment_name="chain-deployment-name", + ) + + gql_mutation = mock_post.call_args[1]["json"]["query"] + + assert 'deployment_name: "chain-deployment-name"' in gql_mutation @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) diff --git a/truss/tests/remote/baseten/test_chain_upload.py b/truss/tests/remote/baseten/test_chain_upload.py index 6954e6903..d4731a200 100644 --- a/truss/tests/remote/baseten/test_chain_upload.py +++ b/truss/tests/remote/baseten/test_chain_upload.py @@ -186,6 +186,7 @@ def test_push_chain_atomic_with_chain_upload( chain_root = context["chain_root"] context["mock_prepare_push"].return_value = mock_push_data + deployment_name = "custom_deployment" result = remote.push_chain_atomic( chain_name=chain_name, @@ -194,13 +195,18 @@ def test_push_chain_atomic_with_chain_upload( truss_user_env=truss_user_env, chain_root=chain_root, publish=True, + deployment_name=deployment_name, ) assert result == mock_create_chain_atomic.return_value mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None) mock_upload_chain_artifact.assert_called_once() - mock_create_chain_atomic.assert_called_once() + create_kwargs = mock_create_chain_atomic.call_args.kwargs + assert create_kwargs["deployment_name"] == deployment_name + + prepare_kwargs = context["mock_prepare_push"].call_args.kwargs + assert prepare_kwargs["deployment_name"] == deployment_name @patch("truss.remote.baseten.remote.create_chain_atomic") @@ -239,6 +245,9 @@ def test_push_chain_atomic_without_chain_upload( mock_upload.assert_not_called() mock_create_chain_atomic.assert_called_once() + create_kwargs = mock_create_chain_atomic.call_args.kwargs + assert "deployment_name" in create_kwargs + assert create_kwargs["deployment_name"] is None @patch("truss.remote.baseten.core.multipart_upload_boto3") diff --git a/truss/tests/remote/baseten/test_remote.py b/truss/tests/remote/baseten/test_remote.py index f7e01c381..c4030bc7a 100644 --- a/truss/tests/remote/baseten/test_remote.py +++ b/truss/tests/remote/baseten/test_remote.py @@ -378,6 +378,57 @@ def test_create_chain_no_existing_chain(remote): assert deployment_handle.chain_deployment_id == "new-chain-deployment-id" +def test_create_chain_with_deployment_name(remote): + with requests_mock.Mocker() as m: + m.post( + _TEST_REMOTE_GRAPHQL_PATH, + [ + {"json": {"data": {"chains": []}}}, + { + "json": { + "data": { + "deploy_chain_atomic": { + "chain_deployment": { + "id": "new-chain-deployment-id", + "chain": { + "id": "new-chain-id", + "hostname": "hostname", + }, + } + } + } + } + }, + ], + ) + + deployment_name = "chain-deployment" + create_chain_atomic( + api=remote.api, + chain_name="new_chain", + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + ), + ), + dependencies=[], + truss_user_env=b10_types.TrussUserEnv.collect(), + is_draft=False, + environment=None, + deployment_name=deployment_name, + ) + + create_chain_graphql_request = m.request_history[1] + + assert ( + 'deployment_name: "chain-deployment"' + in create_chain_graphql_request.json()["query"] + ) + + def test_create_chain_with_existing_chain_promote_to_environment_publish_false(remote): mock_deploy_response = { "chain_deployment": {