Skip to content

Commit e1a7593

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add header_provider to OpenAPIToolset and RestApiTool
Fixes: #3782 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 843352147
1 parent cb3244b commit e1a7593

File tree

4 files changed

+224
-12
lines changed

4 files changed

+224
-12
lines changed

src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
import ssl
2020
from typing import Any
21+
from typing import Callable
2122
from typing import Dict
2223
from typing import Final
2324
from typing import List
@@ -71,6 +72,9 @@ def __init__(
7172
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
7273
tool_name_prefix: Optional[str] = None,
7374
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
75+
header_provider: Optional[
76+
Callable[[ReadonlyContext], Dict[str, str]]
77+
] = None,
7478
):
7579
"""Initializes the OpenAPIToolset.
7680
@@ -116,8 +120,14 @@ def __init__(
116120
- ssl.SSLContext: Custom SSL context for advanced configuration
117121
This is useful for enterprise environments where requests go through
118122
a TLS-intercepting proxy with a custom CA certificate.
123+
header_provider: A callable that returns a dictionary of headers to be
124+
included in API requests. The callable receives the ReadonlyContext as
125+
an argument, allowing dynamic header generation based on the current
126+
context. Useful for adding custom headers like correlation IDs,
127+
authentication tokens, or other request metadata.
119128
"""
120129
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
130+
self._header_provider = header_provider
121131
if not spec_dict:
122132
spec_dict = self._load_spec(spec_str, spec_str_type)
123133
self._ssl_verify = ssl_verify
@@ -189,7 +199,11 @@ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
189199

190200
tools = []
191201
for o in operations:
192-
tool = RestApiTool.from_parsed_operation(o, ssl_verify=self._ssl_verify)
202+
tool = RestApiTool.from_parsed_operation(
203+
o,
204+
ssl_verify=self._ssl_verify,
205+
header_provider=self._header_provider,
206+
)
193207
logger.info("Parsed tool: %s", tool.name)
194208
tools.append(tool)
195209
return tools

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import ssl
1818
from typing import Any
19+
from typing import Callable
1920
from typing import Dict
2021
from typing import List
2122
from typing import Literal
@@ -29,6 +30,7 @@
2930
import requests
3031
from typing_extensions import override
3132

33+
from ....agents.readonly_context import ReadonlyContext
3234
from ....auth.auth_credential import AuthCredential
3335
from ....auth.auth_schemes import AuthScheme
3436
from ..._gemini_schema_util import _to_gemini_schema
@@ -90,6 +92,9 @@ def __init__(
9092
auth_credential: Optional[Union[AuthCredential, str]] = None,
9193
should_parse_operation=True,
9294
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
95+
header_provider: Optional[
96+
Callable[[ReadonlyContext], Dict[str, str]]
97+
] = None,
9398
):
9499
"""Initializes the RestApiTool with the given parameters.
95100
@@ -122,6 +127,11 @@ def __init__(
122127
- False: Disable SSL verification (insecure, not recommended)
123128
- str: Path to a CA bundle file or directory for custom CA
124129
- ssl.SSLContext: Custom SSL context for advanced configuration
130+
header_provider: A callable that returns a dictionary of headers to be
131+
included in API requests. The callable receives the ReadonlyContext as
132+
an argument, allowing dynamic header generation based on the current
133+
context. Useful for adding custom headers like correlation IDs,
134+
authentication tokens, or other request metadata.
125135
"""
126136
# Gemini restrict the length of function name to be less than 64 characters
127137
self.name = name[:60]
@@ -145,6 +155,7 @@ def __init__(
145155
self.credential_exchanger = AutoAuthCredentialExchanger()
146156
self._default_headers: Dict[str, str] = {}
147157
self._ssl_verify = ssl_verify
158+
self._header_provider = header_provider
148159
if should_parse_operation:
149160
self._operation_parser = OperationParser(self.operation)
150161

@@ -153,12 +164,20 @@ def from_parsed_operation(
153164
cls,
154165
parsed: ParsedOperation,
155166
ssl_verify: Optional[Union[bool, str, ssl.SSLContext]] = None,
167+
header_provider: Optional[
168+
Callable[[ReadonlyContext], Dict[str, str]]
169+
] = None,
156170
) -> "RestApiTool":
157171
"""Initializes the RestApiTool from a ParsedOperation object.
158172
159173
Args:
160174
parsed: A ParsedOperation object.
161175
ssl_verify: SSL certificate verification option.
176+
header_provider: A callable that returns a dictionary of headers to be
177+
included in API requests. The callable receives the ReadonlyContext as
178+
an argument, allowing dynamic header generation based on the current
179+
context. Useful for adding custom headers like correlation IDs,
180+
authentication tokens, or other request metadata.
162181
163182
Returns:
164183
A RestApiTool object.
@@ -178,6 +197,7 @@ def from_parsed_operation(
178197
auth_scheme=parsed.auth_scheme,
179198
auth_credential=parsed.auth_credential,
180199
ssl_verify=ssl_verify,
200+
header_provider=header_provider,
181201
)
182202
generated._operation_parser = operation_parser
183203
return generated
@@ -450,6 +470,13 @@ async def call(
450470
request_params = self._prepare_request_params(api_params, api_args)
451471
if self._ssl_verify is not None:
452472
request_params["verify"] = self._ssl_verify
473+
474+
# Add headers from header_provider if configured
475+
if self._header_provider is not None and tool_context is not None:
476+
provider_headers = self._header_provider(tool_context)
477+
if provider_headers:
478+
request_params.setdefault("headers", {}).update(provider_headers)
479+
453480
response = requests.request(**request_params)
454481

455482
# Parse API response

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,8 @@ def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
135135
auth_scheme=auth_scheme,
136136
auth_credential=auth_credential,
137137
)
138-
for tool in toolset._tools:
139-
assert tool.auth_scheme == auth_scheme
140-
assert tool.auth_credential == auth_credential
138+
assert all(tool.auth_scheme == auth_scheme for tool in toolset._tools)
139+
assert all(tool.auth_credential == auth_credential for tool in toolset._tools)
141140

142141

143142
@pytest.mark.parametrize(
@@ -151,24 +150,21 @@ def test_openapi_toolset_verify_on_init(
151150
spec_dict=openapi_spec,
152151
ssl_verify=verify_value,
153152
)
154-
for tool in toolset._tools:
155-
assert tool._ssl_verify == verify_value
153+
assert all(tool._ssl_verify == verify_value for tool in toolset._tools)
156154

157155

158156
def test_openapi_toolset_configure_verify_all(openapi_spec: Dict[str, Any]):
159157
"""Test configure_verify_all method."""
160158
toolset = OpenAPIToolset(spec_dict=openapi_spec)
161159

162160
# Initially verify should be None
163-
for tool in toolset._tools:
164-
assert tool._ssl_verify is None
161+
assert all(tool._ssl_verify is None for tool in toolset._tools)
165162

166163
# Configure verify for all tools
167164
ca_bundle_path = "/path/to/custom-ca.crt"
168165
toolset.configure_ssl_verify_all(ca_bundle_path)
169166

170-
for tool in toolset._tools:
171-
assert tool._ssl_verify == ca_bundle_path
167+
assert all(tool._ssl_verify == ca_bundle_path for tool in toolset._tools)
172168

173169

174170
async def test_openapi_toolset_tool_name_prefix(openapi_spec: Dict[str, Any]):
@@ -183,10 +179,42 @@ async def test_openapi_toolset_tool_name_prefix(openapi_spec: Dict[str, Any]):
183179
assert len(prefixed_tools) == 5
184180

185181
# Verify all tool names are prefixed
186-
for tool in prefixed_tools:
187-
assert tool.name.startswith(f"{prefix}_")
182+
assert all(tool.name.startswith(f"{prefix}_") for tool in prefixed_tools)
188183

189184
# Verify specific tool name is prefixed
190185
expected_prefixed_name = "my_api_calendar_calendars_insert"
191186
prefixed_tool_names = [t.name for t in prefixed_tools]
192187
assert expected_prefixed_name in prefixed_tool_names
188+
189+
190+
def test_openapi_toolset_header_provider(openapi_spec: Dict[str, Any]):
191+
"""Test header_provider parameter is passed to tools."""
192+
193+
def my_header_provider(context):
194+
return {"X-Custom-Header": "custom-value", "X-Request-ID": "12345"}
195+
196+
toolset = OpenAPIToolset(
197+
spec_dict=openapi_spec,
198+
header_provider=my_header_provider,
199+
)
200+
201+
# Verify the toolset has the header_provider set
202+
assert toolset._header_provider is my_header_provider
203+
204+
# Verify all tools have the header_provider
205+
assert all(
206+
tool._header_provider is my_header_provider for tool in toolset._tools
207+
)
208+
209+
210+
def test_openapi_toolset_header_provider_none_by_default(
211+
openapi_spec: Dict[str, Any],
212+
):
213+
"""Test that header_provider is None by default."""
214+
toolset = OpenAPIToolset(spec_dict=openapi_spec)
215+
216+
# Verify the toolset has no header_provider by default
217+
assert toolset._header_provider is None
218+
219+
# Verify all tools have no header_provider
220+
assert all(tool._header_provider is None for tool in toolset._tools)

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,149 @@ async def test_call_with_configure_verify(
10361036
call_kwargs = mock_request.call_args[1]
10371037
assert call_kwargs["verify"] == ca_bundle_path
10381038

1039+
def test_init_with_header_provider(
1040+
self,
1041+
sample_endpoint,
1042+
sample_operation,
1043+
):
1044+
"""Test that header_provider is stored correctly."""
1045+
1046+
def my_header_provider(context):
1047+
return {"X-Custom": "value"}
1048+
1049+
tool = RestApiTool(
1050+
name="test_tool",
1051+
description="Test Tool",
1052+
endpoint=sample_endpoint,
1053+
operation=sample_operation,
1054+
header_provider=my_header_provider,
1055+
)
1056+
assert tool._header_provider is my_header_provider
1057+
1058+
def test_init_header_provider_none_by_default(
1059+
self,
1060+
sample_endpoint,
1061+
sample_operation,
1062+
):
1063+
"""Test that header_provider is None by default."""
1064+
tool = RestApiTool(
1065+
name="test_tool",
1066+
description="Test Tool",
1067+
endpoint=sample_endpoint,
1068+
operation=sample_operation,
1069+
)
1070+
assert tool._header_provider is None
1071+
1072+
@pytest.mark.asyncio
1073+
async def test_call_with_header_provider(
1074+
self,
1075+
mock_tool_context,
1076+
sample_endpoint,
1077+
sample_operation,
1078+
sample_auth_scheme,
1079+
sample_auth_credential,
1080+
):
1081+
"""Test that header_provider adds headers to the request."""
1082+
mock_response = mock.create_autospec(
1083+
requests.Response, instance=True, spec_set=True
1084+
)
1085+
mock_response.json.return_value = {"result": "success"}
1086+
1087+
def my_header_provider(context):
1088+
return {"X-Custom-Header": "custom-value", "X-Request-ID": "12345"}
1089+
1090+
tool = RestApiTool(
1091+
name="test_tool",
1092+
description="Test Tool",
1093+
endpoint=sample_endpoint,
1094+
operation=sample_operation,
1095+
auth_scheme=sample_auth_scheme,
1096+
auth_credential=sample_auth_credential,
1097+
header_provider=my_header_provider,
1098+
)
1099+
1100+
with patch.object(
1101+
requests, "request", return_value=mock_response, autospec=True
1102+
) as mock_request:
1103+
await tool.call(args={}, tool_context=mock_tool_context)
1104+
1105+
# Verify the headers were added to the request
1106+
assert mock_request.called
1107+
_, call_kwargs = mock_request.call_args
1108+
assert call_kwargs["headers"]["X-Custom-Header"] == "custom-value"
1109+
assert call_kwargs["headers"]["X-Request-ID"] == "12345"
1110+
1111+
@pytest.mark.asyncio
1112+
async def test_call_header_provider_receives_tool_context(
1113+
self,
1114+
mock_tool_context,
1115+
sample_endpoint,
1116+
sample_operation,
1117+
sample_auth_scheme,
1118+
sample_auth_credential,
1119+
):
1120+
"""Test that header_provider receives the tool_context."""
1121+
mock_response = mock.create_autospec(
1122+
requests.Response, instance=True, spec_set=True
1123+
)
1124+
mock_response.json.return_value = {"result": "success"}
1125+
1126+
received_context = []
1127+
1128+
def my_header_provider(context):
1129+
received_context.append(context)
1130+
return {"X-Test": "test"}
1131+
1132+
tool = RestApiTool(
1133+
name="test_tool",
1134+
description="Test Tool",
1135+
endpoint=sample_endpoint,
1136+
operation=sample_operation,
1137+
auth_scheme=sample_auth_scheme,
1138+
auth_credential=sample_auth_credential,
1139+
header_provider=my_header_provider,
1140+
)
1141+
1142+
with patch.object(
1143+
requests, "request", return_value=mock_response, autospec=True
1144+
):
1145+
await tool.call(args={}, tool_context=mock_tool_context)
1146+
1147+
# Verify header_provider was called with the tool_context
1148+
assert len(received_context) == 1
1149+
assert received_context[0] is mock_tool_context
1150+
1151+
@pytest.mark.asyncio
1152+
async def test_call_without_header_provider(
1153+
self,
1154+
mock_tool_context,
1155+
sample_endpoint,
1156+
sample_operation,
1157+
sample_auth_scheme,
1158+
sample_auth_credential,
1159+
):
1160+
"""Test that call works without header_provider."""
1161+
mock_response = mock.create_autospec(
1162+
requests.Response, instance=True, spec_set=True
1163+
)
1164+
mock_response.json.return_value = {"result": "success"}
1165+
1166+
tool = RestApiTool(
1167+
name="test_tool",
1168+
description="Test Tool",
1169+
endpoint=sample_endpoint,
1170+
operation=sample_operation,
1171+
auth_scheme=sample_auth_scheme,
1172+
auth_credential=sample_auth_credential,
1173+
)
1174+
1175+
with patch.object(
1176+
requests, "request", return_value=mock_response, autospec=True
1177+
):
1178+
result = await tool.call(args={}, tool_context=mock_tool_context)
1179+
1180+
assert result == {"result": "success"}
1181+
10391182

10401183
def test_snake_to_lower_camel():
10411184
assert snake_to_lower_camel("single") == "single"

0 commit comments

Comments
 (0)