Skip to content

Commit 968efe5

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Support specifying agent_framework in Agent Engine creation and update.
PiperOrigin-RevId: 825734997
1 parent e600277 commit 968efe5

File tree

5 files changed

+245
-9
lines changed

5 files changed

+245
-9
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
import pytest
4141

4242

43-
_TEST_AGENT_FRAMEWORK = "test-agent-framework"
43+
_TEST_AGENT_FRAMEWORK = _genai_types.ReasoningEngineAgentFramework.GOOGLE_ADK
44+
_TEST_AGENT_FRAMEWORK_STR = "google-adk"
4445
GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
4546
"GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
4647
)
@@ -976,9 +977,11 @@ def test_create_agent_engine_config_with_source_packages(
976977
entrypoint_object="app",
977978
requirements_file=requirements_file_path,
978979
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
980+
agent_framework=_TEST_AGENT_FRAMEWORK,
979981
)
980982
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
981983
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
984+
assert config["spec"]["agent_framework"] == _TEST_AGENT_FRAMEWORK_STR
982985
assert config["spec"]["source_code_spec"] == {
983986
"inline_source": {"source_archive": "test_tarball"},
984987
"python_spec": {
@@ -1500,6 +1503,7 @@ def test_create_agent_engine_with_env_vars_dict(
15001503
entrypoint_module=None,
15011504
entrypoint_object=None,
15021505
requirements_file=None,
1506+
agent_framework=None,
15031507
)
15041508
request_mock.assert_called_with(
15051509
"post",
@@ -1586,6 +1590,7 @@ def test_create_agent_engine_with_custom_service_account(
15861590
entrypoint_module=None,
15871591
entrypoint_object=None,
15881592
requirements_file=None,
1593+
agent_framework=None,
15891594
)
15901595
request_mock.assert_called_with(
15911596
"post",
@@ -1674,6 +1679,7 @@ def test_create_agent_engine_with_experimental_mode(
16741679
entrypoint_module=None,
16751680
entrypoint_object=None,
16761681
requirements_file=None,
1682+
agent_framework=None,
16771683
)
16781684
request_mock.assert_called_with(
16791685
"post",
@@ -1826,6 +1832,7 @@ def test_create_agent_engine_with_class_methods(
18261832
entrypoint_module=None,
18271833
entrypoint_object=None,
18281834
requirements_file=None,
1835+
agent_framework=None,
18291836
)
18301837
request_mock.assert_called_with(
18311838
"post",
@@ -1845,6 +1852,92 @@ def test_create_agent_engine_with_class_methods(
18451852
None,
18461853
)
18471854

1855+
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
1856+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1857+
def test_create_agent_engine_with_agent_framework(
1858+
self,
1859+
mock_await_operation,
1860+
mock_create_config,
1861+
):
1862+
mock_create_config.return_value = {
1863+
"display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1864+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1865+
"spec": {
1866+
"package_spec": {
1867+
"python_version": _TEST_PYTHON_VERSION,
1868+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1869+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1870+
},
1871+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1872+
"agent_framework": _TEST_AGENT_FRAMEWORK_STR,
1873+
},
1874+
}
1875+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1876+
response=_genai_types.ReasoningEngine(
1877+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1878+
spec=_TEST_AGENT_ENGINE_SPEC,
1879+
)
1880+
)
1881+
with mock.patch.object(
1882+
self.client.agent_engines._api_client, "request"
1883+
) as request_mock:
1884+
request_mock.return_value = genai_types.HttpResponse(body="")
1885+
self.client.agent_engines.create(
1886+
agent=self.test_agent,
1887+
config=_genai_types.AgentEngineConfig(
1888+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1889+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1890+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1891+
staging_bucket=_TEST_STAGING_BUCKET,
1892+
agent_framework=_TEST_AGENT_FRAMEWORK,
1893+
),
1894+
)
1895+
mock_create_config.assert_called_with(
1896+
mode="create",
1897+
agent=self.test_agent,
1898+
staging_bucket=_TEST_STAGING_BUCKET,
1899+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1900+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1901+
description=None,
1902+
gcs_dir_name=None,
1903+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1904+
env_vars=None,
1905+
service_account=None,
1906+
context_spec=None,
1907+
psc_interface_config=None,
1908+
min_instances=None,
1909+
max_instances=None,
1910+
resource_limits=None,
1911+
container_concurrency=None,
1912+
encryption_spec=None,
1913+
labels=None,
1914+
agent_server_mode=None,
1915+
class_methods=None,
1916+
source_packages=None,
1917+
entrypoint_module=None,
1918+
entrypoint_object=None,
1919+
requirements_file=None,
1920+
agent_framework=_TEST_AGENT_FRAMEWORK,
1921+
)
1922+
request_mock.assert_called_with(
1923+
"post",
1924+
"reasoningEngines",
1925+
{
1926+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1927+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1928+
"spec": {
1929+
"agent_framework": _TEST_AGENT_FRAMEWORK_STR,
1930+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1931+
"package_spec": {
1932+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1933+
"python_version": _TEST_PYTHON_VERSION,
1934+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1935+
},
1936+
},
1937+
},
1938+
None,
1939+
)
1940+
18481941
@pytest.mark.usefixtures("caplog")
18491942
@mock.patch.object(_agent_engines_utils, "_prepare")
18501943
@mock.patch.object(_agent_engines_utils, "_await_operation")

vertexai/_genai/_agent_engines_utils.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,26 @@
128128
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
129129
_BLOB_FILENAME = "agent_engine.pkl"
130130
_DEFAULT_AGENT_FRAMEWORK = "custom"
131+
_SUPPORTED_AGENT_FRAMEWORKS = frozenset(
132+
[
133+
"google-adk",
134+
"langchain",
135+
"langgraph",
136+
"ag2",
137+
"llama-index",
138+
"custom",
139+
]
140+
)
141+
_AGENT_FRAMEWORK_TO_STR = types.MappingProxyType(
142+
{
143+
genai_types.ReasoningEngineAgentFramework.GOOGLE_ADK: "google-adk",
144+
genai_types.ReasoningEngineAgentFramework.LANGCHAIN: "langchain",
145+
genai_types.ReasoningEngineAgentFramework.LANGGRAPH: "langgraph",
146+
genai_types.ReasoningEngineAgentFramework.AG2: "ag2",
147+
genai_types.ReasoningEngineAgentFramework.LLAMA_INDEX: "llama-index",
148+
genai_types.ReasoningEngineAgentFramework.CUSTOM: "custom",
149+
}
150+
)
131151
_DEFAULT_ASYNC_METHOD_NAME = "async_query"
132152
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
133153
_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
@@ -705,13 +725,42 @@ def _generate_schema(
705725
return schema
706726

707727

708-
def _get_agent_framework(*, agent: _AgentEngineInterface) -> str:
709-
if (
710-
hasattr(agent, _AGENT_FRAMEWORK_ATTR)
711-
and getattr(agent, _AGENT_FRAMEWORK_ATTR) is not None
712-
and isinstance(getattr(agent, _AGENT_FRAMEWORK_ATTR), str)
713-
):
714-
return getattr(agent, _AGENT_FRAMEWORK_ATTR)
728+
def _get_agent_framework_str(
729+
*,
730+
agent_framework: genai_types.ReasoningEngineAgentFramework,
731+
agent: _AgentEngineInterface,
732+
) -> str:
733+
"""Gets the agent framework to use.
734+
735+
It prioritizes the provided `agent_framework`. If not provided or not
736+
supported, it checks the `_AGENT_FRAMEWORK_ATTR` attribute on the agent.
737+
If neither is found, it defaults to "_DEFAULT_AGENT_FRAMEWORK".
738+
739+
Args:
740+
agent_framework (genai_types.ReasoningEngineAgentFramework):
741+
The agent framework provided by the user.
742+
agent (_AgentEngineInterface):
743+
The agent engine instance.
744+
745+
Returns:
746+
str: The name of the agent framework to use.
747+
"""
748+
if agent_framework is not None and agent_framework in _AGENT_FRAMEWORK_TO_STR:
749+
logger.info(f"Using agent framework: {agent_framework}")
750+
return _AGENT_FRAMEWORK_TO_STR[agent_framework]
751+
if hasattr(agent, _AGENT_FRAMEWORK_ATTR):
752+
agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR)
753+
if (
754+
agent_framework_attr is not None
755+
and isinstance(agent_framework_attr, str)
756+
and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS
757+
):
758+
logger.info(f"Using agent framework: {agent_framework_attr}")
759+
return agent_framework_attr
760+
logger.info(
761+
f"The provided agent framework {agent_framework} is not supported."
762+
f" Defaulting to {_DEFAULT_AGENT_FRAMEWORK}."
763+
)
715764
return _DEFAULT_AGENT_FRAMEWORK
716765

717766

vertexai/_genai/agent_engines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def _CreateAgentEngineConfig_to_vertex(
9494
getv(from_object, ["requirements_file"]),
9595
)
9696

97+
if getv(from_object, ["agent_framework"]) is not None:
98+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
99+
97100
return to_object
98101

99102

@@ -285,6 +288,9 @@ def _UpdateAgentEngineConfig_to_vertex(
285288
getv(from_object, ["requirements_file"]),
286289
)
287290

291+
if getv(from_object, ["agent_framework"]) is not None:
292+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
293+
288294
if getv(from_object, ["update_mask"]) is not None:
289295
setv(
290296
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
@@ -923,6 +929,7 @@ def create(
923929
entrypoint_module=config.entrypoint_module,
924930
entrypoint_object=config.entrypoint_object,
925931
requirements_file=config.requirements_file,
932+
agent_framework=config.agent_framework,
926933
)
927934
operation = self._create(config=api_config)
928935
# TODO: Use a more specific link.
@@ -986,6 +993,7 @@ def _create_config(
986993
entrypoint_module: Optional[str] = None,
987994
entrypoint_object: Optional[str] = None,
988995
requirements_file: Optional[str] = None,
996+
agent_framework: Optional[types.ReasoningEngineAgentFramework] = None,
989997
) -> types.UpdateAgentEngineConfigDict:
990998
import sys
991999

@@ -1195,7 +1203,10 @@ def _create_config(
11951203
] = agent_server_mode
11961204

11971205
agent_engine_spec["agent_framework"] = (
1198-
_agent_engines_utils._get_agent_framework(agent=agent)
1206+
_agent_engines_utils._get_agent_framework_str(
1207+
agent_framework=agent_framework,
1208+
agent=agent,
1209+
)
11991210
)
12001211
update_masks.append("spec.agent_framework")
12011212
config["spec"] = agent_engine_spec
@@ -1423,6 +1434,7 @@ def update(
14231434
entrypoint_module=config.entrypoint_module,
14241435
entrypoint_object=config.entrypoint_object,
14251436
requirements_file=config.requirements_file,
1437+
agent_framework=config.agent_framework,
14261438
)
14271439
operation = self._update(name=name, config=api_config)
14281440
logger.info(

vertexai/_genai/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@
648648
from .common import RawOutputDict
649649
from .common import RawOutputOrDict
650650
from .common import ReasoningEngine
651+
from .common import ReasoningEngineAgentFramework
651652
from .common import ReasoningEngineContextSpec
652653
from .common import ReasoningEngineContextSpecDict
653654
from .common import ReasoningEngineContextSpecMemoryBankConfig
@@ -1802,6 +1803,7 @@
18021803
"RubricContentType",
18031804
"EvaluationRunState",
18041805
"OptimizeTarget",
1806+
"ReasoningEngineAgentFramework",
18051807
"GenerateMemoriesResponseGeneratedMemoryAction",
18061808
"PromptOptimizerMethod",
18071809
"PromptData",

0 commit comments

Comments
 (0)