Skip to content

Commit 296dd0f

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Support specifying agent_framework in Agent Engine creation and update.
PiperOrigin-RevId: 825734997
1 parent 05834cb commit 296dd0f

File tree

4 files changed

+231
-10
lines changed

4 files changed

+231
-10
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import pytest
4141

4242

43-
_TEST_AGENT_FRAMEWORK = "test-agent-framework"
43+
_TEST_AGENT_FRAMEWORK = "google-adk"
4444
GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
4545
"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
4646
)
@@ -976,9 +976,11 @@ def test_create_agent_engine_config_with_source_packages(
976976
entrypoint_object="app",
977977
requirements_file=requirements_file_path,
978978
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
979+
agent_framework=_TEST_AGENT_FRAMEWORK,
979980
)
980981
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
981982
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
983+
assert config["spec"]["agent_framework"] == _TEST_AGENT_FRAMEWORK
982984
assert config["spec"]["source_code_spec"] == {
983985
"inline_source": {"source_archive": "test_tarball"},
984986
"python_spec": {
@@ -1500,6 +1502,7 @@ def test_create_agent_engine_with_env_vars_dict(
15001502
entrypoint_module=None,
15011503
entrypoint_object=None,
15021504
requirements_file=None,
1505+
agent_framework=None,
15031506
)
15041507
request_mock.assert_called_with(
15051508
"post",
@@ -1513,7 +1516,9 @@ def test_create_agent_engine_with_env_vars_dict(
15131516
"package_spec": {
15141517
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
15151518
"python_version": _TEST_PYTHON_VERSION,
1516-
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1519+
"requirements_gcs_uri": (
1520+
_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI
1521+
),
15171522
},
15181523
},
15191524
},
@@ -1586,6 +1591,7 @@ def test_create_agent_engine_with_custom_service_account(
15861591
entrypoint_module=None,
15871592
entrypoint_object=None,
15881593
requirements_file=None,
1594+
agent_framework=None,
15891595
)
15901596
request_mock.assert_called_with(
15911597
"post",
@@ -1674,6 +1680,7 @@ def test_create_agent_engine_with_experimental_mode(
16741680
entrypoint_module=None,
16751681
entrypoint_object=None,
16761682
requirements_file=None,
1683+
agent_framework=None,
16771684
)
16781685
request_mock.assert_called_with(
16791686
"post",
@@ -1826,6 +1833,7 @@ def test_create_agent_engine_with_class_methods(
18261833
entrypoint_module=None,
18271834
entrypoint_object=None,
18281835
requirements_file=None,
1836+
agent_framework=None,
18291837
)
18301838
request_mock.assert_called_with(
18311839
"post",
@@ -1845,6 +1853,92 @@ def test_create_agent_engine_with_class_methods(
18451853
None,
18461854
)
18471855

1856+
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
1857+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1858+
def test_create_agent_engine_with_agent_framework(
1859+
self,
1860+
mock_await_operation,
1861+
mock_create_config,
1862+
):
1863+
mock_create_config.return_value = {
1864+
"display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1865+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1866+
"spec": {
1867+
"package_spec": {
1868+
"python_version": _TEST_PYTHON_VERSION,
1869+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1870+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1871+
},
1872+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1873+
"agent_framework": _TEST_AGENT_FRAMEWORK,
1874+
},
1875+
}
1876+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1877+
response=_genai_types.ReasoningEngine(
1878+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1879+
spec=_TEST_AGENT_ENGINE_SPEC,
1880+
)
1881+
)
1882+
with mock.patch.object(
1883+
self.client.agent_engines._api_client, "request"
1884+
) as request_mock:
1885+
request_mock.return_value = genai_types.HttpResponse(body="")
1886+
self.client.agent_engines.create(
1887+
agent=self.test_agent,
1888+
config=_genai_types.AgentEngineConfig(
1889+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1890+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1891+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1892+
staging_bucket=_TEST_STAGING_BUCKET,
1893+
agent_framework=_TEST_AGENT_FRAMEWORK,
1894+
),
1895+
)
1896+
mock_create_config.assert_called_with(
1897+
mode="create",
1898+
agent=self.test_agent,
1899+
staging_bucket=_TEST_STAGING_BUCKET,
1900+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1901+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1902+
description=None,
1903+
gcs_dir_name=None,
1904+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1905+
env_vars=None,
1906+
service_account=None,
1907+
context_spec=None,
1908+
psc_interface_config=None,
1909+
min_instances=None,
1910+
max_instances=None,
1911+
resource_limits=None,
1912+
container_concurrency=None,
1913+
encryption_spec=None,
1914+
labels=None,
1915+
agent_server_mode=None,
1916+
class_methods=None,
1917+
source_packages=None,
1918+
entrypoint_module=None,
1919+
entrypoint_object=None,
1920+
requirements_file=None,
1921+
agent_framework=_TEST_AGENT_FRAMEWORK,
1922+
)
1923+
request_mock.assert_called_with(
1924+
"post",
1925+
"reasoningEngines",
1926+
{
1927+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1928+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1929+
"spec": {
1930+
"agent_framework": _TEST_AGENT_FRAMEWORK,
1931+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1932+
"package_spec": {
1933+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1934+
"python_version": _TEST_PYTHON_VERSION,
1935+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1936+
},
1937+
},
1938+
},
1939+
None,
1940+
)
1941+
18481942
@pytest.mark.usefixtures("caplog")
18491943
@mock.patch.object(_agent_engines_utils, "_prepare")
18501944
@mock.patch.object(_agent_engines_utils, "_await_operation")

vertexai/_genai/_agent_engines_utils.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@
128128
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
129129
_BLOB_FILENAME = "agent_engine.pkl"
130130
_DEFAULT_AGENT_FRAMEWORK = "custom"
131+
_SUPPORTED_AGENT_FRAMEWORKS = frozenset(
132+
[
133+
"google-adk",
134+
"langchain",
135+
"langgraph",
136+
"ag2",
137+
"llama-index",
138+
"custom",
139+
]
140+
)
131141
_DEFAULT_ASYNC_METHOD_NAME = "async_query"
132142
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
133143
_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
@@ -705,13 +715,43 @@ def _generate_schema(
705715
return schema
706716

707717

708-
def _get_agent_framework(*, agent: _AgentEngineInterface) -> str:
709-
if (
710-
hasattr(agent, _AGENT_FRAMEWORK_ATTR)
711-
and getattr(agent, _AGENT_FRAMEWORK_ATTR) is not None
712-
and isinstance(getattr(agent, _AGENT_FRAMEWORK_ATTR), str)
713-
):
714-
return getattr(agent, _AGENT_FRAMEWORK_ATTR)
718+
def _get_agent_framework(
719+
*,
720+
agent_framework: Optional[str],
721+
agent: _AgentEngineInterface,
722+
) -> str:
723+
"""Gets the agent framework to use.
724+
725+
The agent framework is determined in the following order of priority:
726+
1. The `agent_framework` passed to this function.
727+
2. The `agent_framework` attribute on the `agent` object.
728+
3. The default framework, "custom".
729+
730+
Args:
731+
agent_framework (str):
732+
The agent framework provided by the user.
733+
agent (_AgentEngineInterface):
734+
The agent engine instance.
735+
736+
Returns:
737+
str: The name of the agent framework to use.
738+
"""
739+
if agent_framework is not None and agent_framework in _SUPPORTED_AGENT_FRAMEWORKS:
740+
logger.info(f"Using agent framework: {agent_framework}")
741+
return agent_framework
742+
if hasattr(agent, _AGENT_FRAMEWORK_ATTR):
743+
agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR)
744+
if (
745+
agent_framework_attr is not None
746+
and isinstance(agent_framework_attr, str)
747+
and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS
748+
):
749+
logger.info(f"Using agent framework: {agent_framework_attr}")
750+
return agent_framework_attr
751+
logger.info(
752+
f"The provided agent framework {agent_framework} is not supported."
753+
f" Defaulting to {_DEFAULT_AGENT_FRAMEWORK}."
754+
)
715755
return _DEFAULT_AGENT_FRAMEWORK
716756

717757

vertexai/_genai/agent_engines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def _CreateAgentEngineConfig_to_vertex(
9494
getv(from_object, ["requirements_file"]),
9595
)
9696

97+
if getv(from_object, ["agent_framework"]) is not None:
98+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
99+
97100
return to_object
98101

99102

@@ -285,6 +288,9 @@ def _UpdateAgentEngineConfig_to_vertex(
285288
getv(from_object, ["requirements_file"]),
286289
)
287290

291+
if getv(from_object, ["agent_framework"]) is not None:
292+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
293+
288294
if getv(from_object, ["update_mask"]) is not None:
289295
setv(
290296
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
@@ -923,6 +929,7 @@ def create(
923929
entrypoint_module=config.entrypoint_module,
924930
entrypoint_object=config.entrypoint_object,
925931
requirements_file=config.requirements_file,
932+
agent_framework=config.agent_framework,
926933
)
927934
operation = self._create(config=api_config)
928935
# TODO: Use a more specific link.
@@ -986,6 +993,7 @@ def _create_config(
986993
entrypoint_module: Optional[str] = None,
987994
entrypoint_object: Optional[str] = None,
988995
requirements_file: Optional[str] = None,
996+
agent_framework: Optional[str] = None,
989997
) -> types.UpdateAgentEngineConfigDict:
990998
import sys
991999

@@ -1195,7 +1203,10 @@ def _create_config(
11951203
] = agent_server_mode
11961204

11971205
agent_engine_spec["agent_framework"] = (
1198-
_agent_engines_utils._get_agent_framework(agent=agent)
1206+
_agent_engines_utils._get_agent_framework(
1207+
agent_framework=agent_framework,
1208+
agent=agent,
1209+
)
11991210
)
12001211
update_masks.append("spec.agent_framework")
12011212
config["spec"] = agent_engine_spec
@@ -1423,6 +1434,7 @@ def update(
14231434
entrypoint_module=config.entrypoint_module,
14241435
entrypoint_object=config.entrypoint_object,
14251436
requirements_file=config.requirements_file,
1437+
agent_framework=config.agent_framework,
14261438
)
14271439
operation = self._update(name=name, config=api_config)
14281440
logger.info(

vertexai/_genai/types/common.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5366,6 +5366,19 @@ class CreateAgentEngineConfig(_common.BaseModel):
53665366
the source package.
53675367
""",
53685368
)
5369+
agent_framework: Optional[
5370+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
5371+
] = Field(
5372+
default=None,
5373+
description="""The agent framework to be used for the Agent Engine.
5374+
The OSS agent framework used to develop the agent.
5375+
Currently supported values: "google-adk", "langchain", "langgraph",
5376+
"ag2", "llama-index", "custom".
5377+
If not specified:
5378+
- If `agent` is specified, the agent framework will be auto-detected.
5379+
- If `source_packages` is specified, the agent framework will
5380+
default to "custom".""",
5381+
)
53695382

53705383

53715384
class CreateAgentEngineConfigDict(TypedDict, total=False):
@@ -5464,6 +5477,18 @@ class CreateAgentEngineConfigDict(TypedDict, total=False):
54645477
the source package.
54655478
"""
54665479

5480+
agent_framework: Optional[
5481+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
5482+
]
5483+
"""The agent framework to be used for the Agent Engine.
5484+
The OSS agent framework used to develop the agent.
5485+
Currently supported values: "google-adk", "langchain", "langgraph",
5486+
"ag2", "llama-index", "custom".
5487+
If not specified:
5488+
- If `agent` is specified, the agent framework will be auto-detected.
5489+
- If `source_packages` is specified, the agent framework will
5490+
default to "custom"."""
5491+
54675492

54685493
CreateAgentEngineConfigOrDict = Union[
54695494
CreateAgentEngineConfig, CreateAgentEngineConfigDict
@@ -6067,6 +6092,19 @@ class UpdateAgentEngineConfig(_common.BaseModel):
60676092
the source package.
60686093
""",
60696094
)
6095+
agent_framework: Optional[
6096+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
6097+
] = Field(
6098+
default=None,
6099+
description="""The agent framework to be used for the Agent Engine.
6100+
The OSS agent framework used to develop the agent.
6101+
Currently supported values: "google-adk", "langchain", "langgraph",
6102+
"ag2", "llama-index", "custom".
6103+
If not specified:
6104+
- If `agent` is specified, the agent framework will be auto-detected.
6105+
- If `source_packages` is specified, the agent framework will
6106+
default to "custom".""",
6107+
)
60706108
update_mask: Optional[str] = Field(
60716109
default=None,
60726110
description="""The update mask to apply. For the `FieldMask` definition, see
@@ -6170,6 +6208,18 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False):
61706208
the source package.
61716209
"""
61726210

6211+
agent_framework: Optional[
6212+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
6213+
]
6214+
"""The agent framework to be used for the Agent Engine.
6215+
The OSS agent framework used to develop the agent.
6216+
Currently supported values: "google-adk", "langchain", "langgraph",
6217+
"ag2", "llama-index", "custom".
6218+
If not specified:
6219+
- If `agent` is specified, the agent framework will be auto-detected.
6220+
- If `source_packages` is specified, the agent framework will
6221+
default to "custom"."""
6222+
61736223
update_mask: Optional[str]
61746224
"""The update mask to apply. For the `FieldMask` definition, see
61756225
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
@@ -12907,6 +12957,19 @@ class AgentEngineConfig(_common.BaseModel):
1290712957
the source package.
1290812958
""",
1290912959
)
12960+
agent_framework: Optional[
12961+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
12962+
] = Field(
12963+
default=None,
12964+
description="""The agent framework to be used for the Agent Engine.
12965+
The OSS agent framework used to develop the agent.
12966+
Currently supported values: "google-adk", "langchain", "langgraph",
12967+
"ag2", "llama-index", "custom".
12968+
If not specified:
12969+
- If `agent` is specified, the agent framework will be auto-detected.
12970+
- If `source_packages` is specified, the agent framework will
12971+
default to "custom".""",
12972+
)
1291012973

1291112974

1291212975
class AgentEngineConfigDict(TypedDict, total=False):
@@ -13034,6 +13097,18 @@ class AgentEngineConfigDict(TypedDict, total=False):
1303413097
the source package.
1303513098
"""
1303613099

13100+
agent_framework: Optional[
13101+
Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
13102+
]
13103+
"""The agent framework to be used for the Agent Engine.
13104+
The OSS agent framework used to develop the agent.
13105+
Currently supported values: "google-adk", "langchain", "langgraph",
13106+
"ag2", "llama-index", "custom".
13107+
If not specified:
13108+
- If `agent` is specified, the agent framework will be auto-detected.
13109+
- If `source_packages` is specified, the agent framework will
13110+
default to "custom"."""
13111+
1303713112

1303813113
AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]
1303913114

0 commit comments

Comments
 (0)