Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"tenacity>=8.0.1",
"watchfiles>=0.19.0,<0.20",
"truss-transfer>=0.0.37,<0.0.40",
"gql-query-builder (>=0.1.7,<0.2.0)",
]

[project.urls]
Expand Down Expand Up @@ -102,6 +103,9 @@ default-groups = [
"dev-server",
]

[tool.uv.extra-build-dependencies]
gql-query-builder = ["pip"]

# Simplified Hatchling configuration - let it auto-discover truss, manually include others
[tool.hatch.build.targets.sdist]
include = [
Expand Down
65 changes: 24 additions & 41 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Mapping, Optional

import requests
from gql_query_builder import GqlQuery
from pydantic import BaseModel, Field

from truss.base.custom_types import SafeModel
Expand Down Expand Up @@ -328,52 +329,34 @@ def deploy_chain_atomic(
):
if allow_truss_download is None:
allow_truss_download = True
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)

dependencies_str = ", ".join(
[
mutation_params = {
"chain_id": chain_id,
"chain_name": chain_name,
"environment": environment,
"original_source_artifact_s3_key": original_source_artifact_s3_key,
"allow_truss_download": "false" if allow_truss_download is False else None,
"is_draft": is_draft,
"entrypoint": _chainlet_data_atomic_to_graphql_mutation(entrypoint),
"dependencies": [
_chainlet_data_atomic_to_graphql_mutation(dependency)
for dependency in dependencies
]
)

params = []
if chain_id:
params.append(f'chain_id: "{chain_id}"')
if chain_name:
params.append(f'chain_name: "{chain_name}"')
if environment:
params.append(f'environment: "{environment}"')
if original_source_artifact_s3_key:
params.append(
f'original_source_artifact_s3_key: "{original_source_artifact_s3_key}"'
)

params.append(f"is_draft: {str(is_draft).lower()}")
if allow_truss_download is False:
params.append("allow_truss_download: false")

params_str = PARAMS_INDENT.join(params)
],
"truss_user_env": "$trussUserEnv",
}
mutation_params = {
str(k): v for k, v in mutation_params.items() if v is not None
}

query_string = f"""
mutation ($trussUserEnv: String) {{
deploy_chain_atomic(
{params_str}
entrypoint: {entrypoint_str}
dependencies: [{dependencies_str}]
truss_user_env: $trussUserEnv
) {{
chain_deployment {{
id
chain {{
id
hostname
}}
}}
}}
}}
"""

gql = GqlQuery()
gql.operation(
"mutation",
"deploy_chain_atomic",
mutation_params,
["chain_deployment { id chain { id hostname } }"],
)
query_string = gql.generate()
resp = self._post_graphql_query(
query_string, variables={"trussUserEnv": truss_user_env.json()}
)
Expand Down
68 changes: 63 additions & 5 deletions truss/tests/remote/baseten/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def test_deploy_chain_deployment(mock_post, baseten_api):

gql_mutation = mock_post.call_args[1]["json"]["query"]

assert 'environment: "production"' in gql_mutation
assert 'chain_id: "chain_id"' in gql_mutation
assert "environment: production" in gql_mutation
assert "chain_id: chain_id" in gql_mutation
assert "dependencies:" in gql_mutation
assert "entrypoint:" in gql_mutation

Expand All @@ -378,8 +378,8 @@ def test_deploy_chain_deployment_with_gitinfo(mock_post, baseten_api):

gql_mutation = mock_post.call_args[1]["json"]["query"]

assert 'environment: "production"' in gql_mutation
assert 'chain_id: "chain_id"' in gql_mutation
assert "environment: production" in gql_mutation
assert "chain_id: chain_id" in gql_mutation
assert "dependencies:" in gql_mutation
assert "entrypoint:" in gql_mutation

Expand All @@ -402,12 +402,70 @@ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):

gql_mutation = mock_post.call_args[1]["json"]["query"]

assert 'chain_id: "chain_id"' in gql_mutation
assert "chain_id: chain_id" in gql_mutation
assert "environment" not in gql_mutation
assert "dependencies:" in gql_mutation
assert "entrypoint:" in gql_mutation


@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
def test_deploy_chain_deployment_with_dependencies(mock_post, baseten_api):
dependencies = [
ChainletDataAtomic(
name="dependency-1",
oracle=OracleData(
model_name="dep-model-1",
s3_key="dep-s3-key-1",
encoded_config_str="dep-encoded-config-str-1",
),
),
ChainletDataAtomic(
name="dependency-2",
oracle=OracleData(
model_name="dep-model-2",
s3_key="dep-s3-key-2",
encoded_config_str="dep-encoded-config-str-2",
),
),
]

baseten_api.deploy_chain_atomic(
environment="production",
chain_id="chain_id",
dependencies=dependencies,
entrypoint=ChainletDataAtomic(
name="chainlet-1",
oracle=OracleData(
model_name="model-1",
s3_key="s3-key-1",
encoded_config_str="encoded-config-str-1",
),
),
truss_user_env=b10_types.TrussUserEnv.collect(),
)

gql_mutation = mock_post.call_args[1]["json"]["query"]

# Single regex to check all assertions
import re

pattern = (
r"(?=.*environment: production)"
r"(?=.*chain_id: chain_id)"
r"(?=.*dependencies:)"
r"(?=.*entrypoint:)"
r'(?=.*name: "dependency-1")'
r'(?=.*name: "dependency-2")'
r'(?=.*model_name: "dep-model-1")'
r'(?=.*model_name: "dep-model-2")'
r'(?=.*s3_key: "dep-s3-key-1")'
r'(?=.*s3_key: "dep-s3-key-2")'
)
assert re.search(pattern, gql_mutation), (
f"GraphQL mutation does not contain all expected elements: {gql_mutation}"
)


@mock.patch("requests.post", return_value=mock_upsert_training_project_response())
def test_upsert_training_project(mock_post, baseten_api):
baseten_api.upsert_training_project(
Expand Down
10 changes: 9 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading