diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 24fdce9d59..b0d187502e 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -168,6 +168,14 @@ class InvocationContext(BaseModel): agent_states: dict[str, dict[str, Any]] = Field(default_factory=dict) """The state of the agent for this invocation.""" + request_state: dict[str, Any] = Field(default_factory=dict) + """The ephemeral state of the request. + + This state is not persisted to the session and is only available for the + current invocation. It is used to pass sensitive information like tokens + that should not be stored in the session state. + """ + end_of_agents: dict[str, bool] = Field(default_factory=dict) """The end of agent status for each agent in this invocation.""" diff --git a/src/google/adk/agents/readonly_context.py b/src/google/adk/agents/readonly_context.py index 21cefa9a56..4915203105 100644 --- a/src/google/adk/agents/readonly_context.py +++ b/src/google/adk/agents/readonly_context.py @@ -14,8 +14,10 @@ from __future__ import annotations +from collections import ChainMap from types import MappingProxyType from typing import Any +from typing import Mapping from typing import Optional from typing import TYPE_CHECKING @@ -51,9 +53,14 @@ def agent_name(self) -> str: return self._invocation_context.agent.name @property - def state(self) -> MappingProxyType[str, Any]: + def state(self) -> Mapping[str, Any]: """The state of the current session. READONLY field.""" - return MappingProxyType(self._invocation_context.session.state) + return MappingProxyType( + ChainMap( + self._invocation_context.request_state, + self._invocation_context.session.state, + ) + ) @property def session(self) -> Session: diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 3a471c1157..96cd19b73e 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -175,6 +175,7 @@ class RunAgentRequest(common.BaseModel): new_message: types.Content streaming: bool = False state_delta: Optional[dict[str, Any]] = None + request_state: Optional[dict[str, Any]] = None # for resume long running functions invocation_id: Optional[str] = None @@ -1463,6 +1464,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, + request_state=req.request_state, ) ) as agen: events = [event async for event in agen] @@ -1492,6 +1494,7 @@ async def event_generator(): session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, + request_state=req.request_state, run_config=RunConfig(streaming_mode=stream_mode), invocation_id=req.invocation_id, ) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 7e1a2e5d02..8f536cb1fb 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -398,6 +398,7 @@ async def run_async( invocation_id: Optional[str] = None, new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, + request_state: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -415,6 +416,7 @@ async def run_async( interrupted invocation. new_message: A new message to append to the session. state_delta: Optional state changes to apply to the session. + request_state: Optional ephemeral state for the request. run_config: The run config for the agent. Yields: @@ -462,6 +464,7 @@ async def _run_with_trace( invocation_id=invocation_id, run_config=run_config, state_delta=state_delta, + request_state=request_state, ) if invocation_context.end_of_agents.get( invocation_context.agent.name @@ -475,6 +478,7 @@ async def _run_with_trace( new_message=new_message, # new_message is not None. run_config=run_config, state_delta=state_delta, + request_state=request_state, ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: @@ -1185,6 +1189,7 @@ async def _setup_context_for_new_invocation( new_message: types.Content, run_config: RunConfig, state_delta: Optional[dict[str, Any]], + request_state: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a new invocation. @@ -1193,6 +1198,7 @@ async def _setup_context_for_new_invocation( new_message: The new message to process and append to the session. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + request_state: Optional ephemeral state for the request. Returns: The invocation context for the new invocation. @@ -1202,6 +1208,7 @@ async def _setup_context_for_new_invocation( session, new_message=new_message, run_config=run_config, + request_state=request_state, ) # Step 2: Handle new message, by running callbacks and appending to # session. @@ -1224,6 +1231,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: Optional[str], run_config: RunConfig, state_delta: Optional[dict[str, Any]], + request_state: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a resumed invocation. @@ -1233,6 +1241,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: The invocation id to resume. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + request_state: Optional ephemeral state for the request. Returns: The invocation context for the resumed invocation. @@ -1258,6 +1267,7 @@ async def _setup_context_for_resumed_invocation( new_message=user_message, run_config=run_config, invocation_id=invocation_id, + request_state=request_state, ) # Step 3: Maybe handle new message. if new_message: @@ -1302,6 +1312,7 @@ def _new_invocation_context( new_message: Optional[types.Content] = None, live_request_queue: Optional[LiveRequestQueue] = None, run_config: Optional[RunConfig] = None, + request_state: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Creates a new invocation context. @@ -1311,6 +1322,7 @@ def _new_invocation_context( new_message: The new message for the context. live_request_queue: The live request queue for the context. run_config: The run config for the context. + request_state: The ephemeral state for the request. Returns: The new invocation context. @@ -1342,6 +1354,7 @@ def _new_invocation_context( live_request_queue=live_request_queue, run_config=run_config, resumability_config=self.resumability_config, + request_state=request_state or {}, ) def _new_invocation_context_for_live( diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py index 1170b2e1af..2f1ba5f22d 100644 --- a/src/google/adk/tools/mcp_tool/__init__.py +++ b/src/google/adk/tools/mcp_tool/__init__.py @@ -22,6 +22,7 @@ from .mcp_session_manager import StreamableHTTPConnectionParams from .mcp_tool import MCPTool from .mcp_tool import McpTool + from .mcp_toolset import create_session_state_header_provider from .mcp_toolset import MCPToolset from .mcp_toolset import McpToolset @@ -32,6 +33,7 @@ 'MCPTool', 'McpToolset', 'MCPToolset', + 'create_session_state_header_provider', 'SseConnectionParams', 'StdioConnectionParams', 'StreamableHTTPConnectionParams', diff --git a/src/google/adk/tools/mcp_tool/_internal.py b/src/google/adk/tools/mcp_tool/_internal.py new file mode 100644 index 0000000000..42f0c7528f --- /dev/null +++ b/src/google/adk/tools/mcp_tool/_internal.py @@ -0,0 +1,317 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal utilities for MCP tools. + +This module contains internal validation and sanitization utilities +that are not part of the public API and follow RFC 7230 properly. + +**Security Notes:** + +- Header validation implements RFC 7230 §3.2 for proper HTTP header format +- Only truly dangerous control characters are removed from header values +- Legitimate multi-line headers with proper folding are preserved +- Binary data handling is separate from text data for security +- All functions log security-relevant warnings when appropriate + +**RFC 7230 Compliance:** + +- Header names: only letters, digits, and hyphens allowed +- Header values: control characters (0x00-0x1F, 0x7F) are dangerous +- Header folding: CRLF sequences are preserved for legitimate use cases +- Binary data: handled separately with explicit allow_binary flag + +**Attack Prevention:** + +- HTTP header injection attacks via control character filtering +- Response splitting attacks through CRLF handling +- Log injection attacks via character sanitization +- Type confusion attacks through strict validation +""" + +from __future__ import annotations + +import logging +import re +from typing import Any + +logger = logging.getLogger("google_adk." + __name__) + +# RFC 7230 compliant header patterns +# Control characters and special characters not allowed in header names +_HEADER_NAME_FORBIDDEN = r'\x00-\x1F\x7F()<>@,;:\\"/[\]?={} \t' + +# Header whitespace characters (RFC 7230 §3.2.4) +_HEADER_WHITESPACE = "\r\n" + +# RFC 7230 compliant header name pattern (allows letters, digits, hyphens) +_HEADER_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9-]+\Z") + +# Truly dangerous characters that should never appear in header values +# These are characters that can break HTTP parsing or cause injection +_DANGEROUS_CHARS = { + "\x00", + "\x01", + "\x02", + "\x03", + "\x04", + "\x05", + "\x06", + "\x07", + "\x08", + "\x0b", + "\x0c", + "\x0e", + "\x0f", + "\x10", + "\x11", + "\x12", + "\x13", + "\x14", + "\x15", + "\x16", + "\x17", + "\x18", + "\x19", + "\x1a", + "\x1b", + "\x1c", + "\x1d", + "\x1e", + "\x1f", + "\x7f", +} + + +def _is_printable_ascii(char: str) -> bool: + """Check if character is printable ASCII.""" + try: + return 0x20 <= ord(char) <= 0x7E + except ValueError: + return False + + +def _is_control_char(char: str) -> bool: + """Check if character is a control character.""" + return char in _DANGEROUS_CHARS + + +def _is_whitespace(char: str) -> bool: + """Check if character is whitespace.""" + return char in _HEADER_WHITESPACE + + +def _get_forbidden_char_desc(char: str) -> str: + """Get description of forbidden character.""" + if char == "\r": + return "carriage return" + elif char == "\n": + return "line feed" + elif char == "\t": + return "horizontal tab" + elif _is_printable_ascii(char): + return f"non-printable ASCII: {repr(char)}" + else: + return f"control character: {repr(char)}" + + +def validate_header_name(header_name: str) -> None: + """Validates that a header name conforms to RFC 7230. + Only allows printable ASCII, no control chars, spaces, or separators. + Rejects header names containing invalid characters. + """ + if not header_name: + raise ValueError("Header name cannot be empty.") + + if not _HEADER_NAME_PATTERN.match(header_name): + raise ValueError( + f'Header name "{header_name}" contains invalid characters. ' + "Header names must conform to RFC 7230 and cannot contain " + 'control characters, spaces, or separators like ():<>@,;:\\"/[]?={}.' + ) + + +def _sanitize_header_value(value: str) -> str: + """Sanitizes a header value to prevent injection attacks. + + This function removes ONLY truly dangerous characters that could cause + header injection attacks, while remaining RFC 7230 compliant. + + Args: + value: The header value to sanitize. + + Returns: + The sanitized header value with dangerous characters removed. + """ + if not isinstance(value, str): + value = str(value) + + # Remove only characters that are truly dangerous for HTTP headers + # These are control characters that can break parsing or enable injection + # We DON'T remove all \r\n sequences as that would break legitimate multi-line headers + # and violate RFC 7230 §3.2.4 which allows header folding + sanitized_chars = [] + for char in value: + if char not in _DANGEROUS_CHARS: + sanitized_chars.append(char) + else: + logger.warning( + f"Removed dangerous character {repr(char)} from header value " + "for security reasons" + ) + + return "".join(sanitized_chars) + + +def _validate_header_value(value: Any, allow_binary: bool = False) -> None: + """Validates header values with RFC 7230 compliance and proper binary handling. + + Args: + value: The header value to validate. + allow_binary: Whether to allow binary data (bytes) in header values. + + Raises: + ValueError: If value contains dangerous characters. + """ + if value is None: + return + + if isinstance(value, bytes): + if not allow_binary: + raise ValueError("Binary data not allowed in HTTP header values") + # For binary data, check for dangerous bytes + for byte_val in value: + if byte_val < 128: # ASCII range + char = chr(byte_val) + if char in _DANGEROUS_CHARS: + raise ValueError( + f"Binary data contains dangerous byte: {repr(char)} " + f"({_get_forbidden_char_desc(char)})" + ) + return + + # For strings, check for dangerous characters that could enable injection + if isinstance(value, str): + for char in value: + if char in _DANGEROUS_CHARS: + raise ValueError( + f"Header value contains dangerous character: {repr(char)} " + f"({_get_forbidden_char_desc(char)})" + ) + return + + # For other types, convert to string and validate + str_value = str(value) + for char in str_value: + if char in _DANGEROUS_CHARS: + raise ValueError( + "Header value (converted to string) contains dangerous character: " + f"{repr(char)} ({_get_forbidden_char_desc(char)})" + ) + + +def sanitize_header_value(value: Any) -> str: + """Sanitizes a header value to prevent injection attacks. + + This is a wrapper that converts non-string values to strings and then + applies core sanitization logic. + + Args: + value: The header value to sanitize (any type). + + Returns: + The sanitized header value as a string. + """ + if not isinstance(value, str): + value = str(value) + + return _sanitize_header_value(value) + + +def validate_header_value( + state_key: str, value: Any, strict: bool = False +) -> None: + """Validates that a state value is suitable for use in a header. + + Args: + state_key: The key being validated. + value: The value to validate. + strict: If True, raises ValueError for non-primitive types. + + Raises: + ValueError: If strict=True and value is not a primitive type. + """ + if not isinstance(value, (str, int, float, bool)): + msg = ( + f'Value for state key "{state_key}" is of type ' + f"{type(value).__name__}, which may not serialize correctly into a " + "header. Consider pre-serializing complex values or using " + "state_header_format." + ) + if strict: + raise ValueError(msg) + else: + logger.warning(msg) + + +def create_session_state_header_provider( + state_key: str, + header_name: str = "Authorization", + header_format: str = "Bearer {value}", + default_value: str = None, + strict: bool = False, +): + """Creates a header provider that extracts values from session state. + + This utility function generates a header_provider callable that can be used + with McpToolset to automatically extract values from session state and + format them as HTTP headers for MCP server connections. + + .. warning:: + **Security Best Practice**: For sensitive, short-lived tokens like JWTs, + use ``request_state`` instead of ``session.state`` to avoid persisting + sensitive data to the database. Pass tokens via + ``RunAgentRequest.request_state``, which will override ``session.state`` + for the duration of the request without being persisted. + + Args: + state_key: The key to look up in session.state (or request_state). + header_name: The HTTP header name to set (default: 'Authorization'). + header_format: Format string for the header value. Use {value} as a + placeholder for the state value (default: 'Bearer {value}'). + default_value: Default value if state_key is not found in session state. + If None, the header is omitted when the key is missing. + strict: If True, raises ValueError when non-primitive types are + encountered. If False (default), logs a warning instead. + + Returns: + A callable that takes a ReadonlyContext and returns a dictionary of + headers to be used for the MCP session. + """ + # Validate header name upfront + validate_header_name(header_name) + + def provider(ctx) -> dict[str, str]: + value = ctx.state.get(state_key, default_value) + # Skip header if value is None or empty string + if value is None or value == "": + return {} + + validate_header_value(state_key, value, strict=strict) + formatted_value = header_format.format(value=value) + sanitized_value = sanitize_header_value(formatted_value) + + return {header_name: sanitized_value} + + return provider diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index b15f2c73fe..a6a84636fa 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -41,6 +41,7 @@ from ..tool_context import ToolContext from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors +from .types import HeaderProvider logger = logging.getLogger("google_adk." + __name__) @@ -63,9 +64,7 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, require_confirmation: Union[bool, Callable[..., bool]] = False, - header_provider: Optional[ - Callable[[ReadonlyContext], Dict[str, str]] - ] = None, + header_provider: Optional[HeaderProvider] = None, ): """Initializes an McpTool. @@ -81,6 +80,8 @@ def __init__( or a callable that takes the function's arguments and returns a boolean. If the callable returns True, the tool will require confirmation from the user. + header_provider: A callable that takes a ReadonlyContext and returns a + dictionary of headers to be used for the MCP session. Raises: ValueError: If mcp_tool or mcp_session_manager is None. diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 035b75878b..5254486eb5 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -38,16 +38,152 @@ from ..base_toolset import ToolPredicate from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig +from ._internal import sanitize_header_value +from ._internal import validate_header_name +from ._internal import validate_header_value from .mcp_session_manager import MCPSessionManager from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams from .mcp_tool import MCPTool +from .types import HeaderProvider logger = logging.getLogger("google_adk." + __name__) +def create_session_state_header_provider( + state_key: str, + header_name: str = "Authorization", + header_format: str = "Bearer {value}", + default_value: Optional[str] = None, + strict: bool = False, +) -> HeaderProvider: + """Creates a header provider that extracts values from session state. + + This utility function generates a header_provider callable that can be used + with McpToolset to automatically extract values from the session state and + format them as HTTP headers for MCP server connections. + + .. warning:: + **Security Best Practice**: For sensitive, short-lived tokens like JWTs, + use ``request_state`` instead of ``session.state`` to avoid persisting + sensitive data to the database. Pass tokens via + ``RunAgentRequest.request_state``, which will override ``session.state`` + for the duration of the request without being persisted. + + **Security Features:** + + - RFC 7230 compliant HTTP header validation and sanitization + - Automatic protection against header injection attacks + - Support for secure token propagation via session state + - Configurable strict validation for header values + + **Security Best Practices:** + + 1. **Token Security**: Use ``request_state`` for sensitive, short-lived tokens + (JWTs, API keys) instead of ``session.state`` to avoid persisting sensitive data. + + 2. **Header Validation**: Header names and values are automatically validated + according to RFC 7230 to prevent injection attacks. + + 3. **Complex Data**: For complex data structures, pre-serialize them or use + ``state_header_format`` to ensure proper string representation. + + 4. **Strict Mode**: Enable ``state_header_strict=True`` in configuration to + catch non-primitive type errors early. + + Args: + state_key: The key to look up in session.state (or request_state). + header_name: The HTTP header name to set (default: 'Authorization'). + header_format: Format string for the header value. Use {value} as a + placeholder for the state value (default: 'Bearer {value}'). + default_value: Default value if state_key is not found in session state. + If None, the header is omitted when the key is missing. + strict: If True, raises ValueError when non-primitive types are + encountered. If False (default), logs a warning instead. + + Returns: + A callable that takes a ReadonlyContext and returns a dictionary of + headers to be used for the MCP session. + + Raises: + ValueError: If strict=True and a non-primitive type is found in state, + or if header_name is invalid. + + Example:: + + # Example 1: Using request_state for JWT tokens (recommended) + toolset = McpToolset( + connection_params=StreamableHTTPConnectionParams( + url="http://api.example.com/mcp" + ), + header_provider=create_session_state_header_provider( + state_key="jwt_token", # Will read from request_state first + header_name="Authorization", + header_format="Bearer {value}" + ) + ) + + # Client sends request with ephemeral JWT + response = await agent.run( + RunAgentRequest( + session_id="user-123", + request_state={"jwt_token": "eyJhbG..."} # Ephemeral, not persisted + ) + ) + """ + # Validate header name upfront + validate_header_name(header_name) + + def provider(ctx: ReadonlyContext) -> Dict[str, str]: + value = ctx.state.get(state_key, default_value) + # Skip header if value is None or empty string + if value is None or value == "": + return {} + + validate_header_value(state_key, value, strict=strict) + formatted_value = header_format.format(value=value) + sanitized_value = sanitize_header_value(formatted_value) + + return {header_name: sanitized_value} + + return provider + + +def create_combined_header_provider( + providers: List[HeaderProvider], +) -> HeaderProvider: + """Creates a header provider that combines multiple providers. + + Args: + providers: A list of header providers to combine. + + Returns: + A single header provider that merges the results of all input providers. + """ + + def combined_provider(ctx: ReadonlyContext) -> Dict[str, str]: + headers = {} + num_providers = len(providers) + for i, provider in enumerate(providers): + try: + provider_headers = provider(ctx) + if provider_headers: + headers.update(provider_headers) + except Exception as e: + logger.error(f"Header provider {i+1}/{num_providers} failed: {e}") + raise + + if headers: + logger.debug( + f"Combined header provider generated {len(headers)} total headers" + ) + return headers + + return combined_provider + + class McpToolset(BaseToolset): """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. @@ -93,9 +229,7 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, require_confirmation: Union[bool, Callable[..., bool]] = False, - header_provider: Optional[ - Callable[[ReadonlyContext], Dict[str, str]] - ] = None, + header_provider: Optional[HeaderProvider] = None, ): """Initializes the McpToolset. @@ -223,12 +357,32 @@ def from_config( else: raise ValueError("No connection params found in McpToolsetConfig.") + # Create header_provider from state_header_mapping if specified + header_provider = None + if mcp_toolset_config.state_header_mapping: + state_mapping = mcp_toolset_config.state_header_mapping + state_format = mcp_toolset_config.state_header_format or {} + + providers = [ + create_session_state_header_provider( + state_key=state_key, + header_name=header_name, + header_format=state_format.get(header_name, "{value}"), + default_value=None, + strict=mcp_toolset_config.state_header_strict, + ) + for state_key, header_name in state_mapping.items() + ] + + header_provider = create_combined_header_provider(providers) + return cls( connection_params=connection_params, tool_filter=mcp_toolset_config.tool_filter, tool_name_prefix=mcp_toolset_config.tool_name_prefix, auth_scheme=mcp_toolset_config.auth_scheme, auth_credential=mcp_toolset_config.auth_credential, + header_provider=header_provider, ) @@ -265,6 +419,55 @@ class McpToolsetConfig(BaseToolConfig): auth_credential: Optional[AuthCredential] = None + state_header_mapping: Optional[Dict[str, str]] = None + """Maps session state keys to HTTP header names. + + When specified, values from the session state will be extracted and passed + as HTTP headers to the MCP server. This is useful for propagating + authentication tokens or other context from the ADK session to the MCP server. + + Example:: + + state_header_mapping: + jwt_token: Authorization + tenant_id: X-Tenant-ID + + This will read `session.state["jwt_token"]` and set it as the + "Authorization" header, and read `session.state["tenant_id"]` and + set it as the "X-Tenant-ID" header. + """ + + state_header_format: Optional[Dict[str, str]] = None + """Optional formatting for header values extracted from session state. + + Supports format strings with {value} as a placeholder for the state value. + Only applies to headers specified in state_header_mapping. + + Example:: + + state_header_format: + Authorization: "Bearer {value}" + X-API-Key: "key:{value}" + + If not specified for a particular header, the value from session state is + used as-is. + """ + + state_header_strict: bool = False + """Enable strict type validation for state header values. + + When True, raises ValueError if state values are non-primitive types + (not str, int, float, or bool). This helps catch configuration errors + early by preventing accidental serialization of complex objects into headers. + + When False (default), non-primitive types trigger a warning but are still + formatted into headers. + + Example:: + + state_header_strict: true # Raises error on non-primitive types + """ + @model_validator(mode="after") def _check_only_one_params_field(self): param_fields = [ diff --git a/src/google/adk/tools/mcp_tool/types.py b/src/google/adk/tools/mcp_tool/types.py new file mode 100644 index 0000000000..f4198fcd88 --- /dev/null +++ b/src/google/adk/tools/mcp_tool/types.py @@ -0,0 +1,22 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable +from typing import Dict + +from ...agents.readonly_context import ReadonlyContext + +HeaderProvider = Callable[[ReadonlyContext], Dict[str, str]] diff --git a/tests/unittests/agents/test_readonly_context_state.py b/tests/unittests/agents/test_readonly_context_state.py new file mode 100644 index 0000000000..b4c628a1c1 --- /dev/null +++ b/tests/unittests/agents/test_readonly_context_state.py @@ -0,0 +1,83 @@ +from collections import ChainMap +import unittest +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.sessions.session import Session + + +class TestReadonlyContextState(unittest.TestCase): + + def test_state_merging_precedence(self): + # Setup + mock_session = MagicMock(spec=Session) + mock_session.state = { + "persistent_key": "persistent_value", + "conflict_key": "persistent_value", + } + + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = mock_session + mock_invocation_context.request_state = { + "ephemeral_key": "ephemeral_value", + "conflict_key": "ephemeral_value", + } + + readonly_context = ReadonlyContext(mock_invocation_context) + + # Verify + state = readonly_context.state + + # Check that ephemeral keys are present + self.assertEqual(state["ephemeral_key"], "ephemeral_value") + + # Check that persistent keys are present + self.assertEqual(state["persistent_key"], "persistent_value") + + # Check that ephemeral keys override persistent keys + self.assertEqual(state["conflict_key"], "ephemeral_value") + + # Verify it behaves like a mapping + self.assertIn("ephemeral_key", state) + self.assertIn("persistent_key", state) + self.assertEqual(state.get("ephemeral_key"), "ephemeral_value") + + def test_state_merging_empty_request_state(self): + # Setup + mock_session = MagicMock(spec=Session) + mock_session.state = {"persistent_key": "persistent_value"} + + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = mock_session + mock_invocation_context.request_state = {} + + readonly_context = ReadonlyContext(mock_invocation_context) + + # Verify + state = readonly_context.state + self.assertEqual(state["persistent_key"], "persistent_value") + self.assertNotIn("ephemeral_key", state) + + def test_state_immutability(self): + # Setup + mock_session = MagicMock(spec=Session) + mock_session.state = {"key": "value"} + + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = mock_session + mock_invocation_context.request_state = {} + + readonly_context = ReadonlyContext(mock_invocation_context) + state = readonly_context.state + + # Verify it raises TypeError on assignment + with self.assertRaises(TypeError): + state["key"] = "new_value" + + with self.assertRaises(TypeError): + state["new_key"] = "value" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py b/tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py new file mode 100644 index 0000000000..4d49d1ef9e --- /dev/null +++ b/tests/unittests/tools/mcp_tool/test_jwt_token_propagation.py @@ -0,0 +1,658 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for JWT token propagation feature in MCP toolset.""" + +import sys +from unittest.mock import Mock + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from google.adk.agents.readonly_context import ReadonlyContext + from google.adk.tools.mcp_tool.mcp_toolset import create_session_state_header_provider + from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig + from mcp import StdioServerParameters +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + class DummyClass: + pass + + create_session_state_header_provider = DummyClass + McpToolsetConfig = DummyClass + StdioServerParameters = DummyClass + ReadonlyContext = DummyClass + else: + raise e + + +class TestCreateSessionStateHeaderProvider: + """Test suite for create_session_state_header_provider function.""" + + def test_extract_jwt_token_default_format(self): + """Test extracting JWT token with default Authorization Bearer format.""" + # Create mock context with JWT token in state + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = { + "jwt_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + } + + # Create header provider + provider = create_session_state_header_provider(state_key="jwt_token") + + # Call provider + headers = provider(mock_context) + + # Verify headers + assert headers == { + "Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + } + + def test_extract_with_custom_header_name(self): + """Test extracting token with custom header name.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"api_key": "secret-key-123"} + + provider = create_session_state_header_provider( + state_key="api_key", header_name="X-API-Key", header_format="{value}" + ) + + headers = provider(mock_context) + + assert headers == {"X-API-Key": "secret-key-123"} + + def test_extract_with_custom_format(self): + """Test extracting token with custom formatting.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"tenant_id": "tenant-123"} + + provider = create_session_state_header_provider( + state_key="tenant_id", + header_name="X-Tenant-ID", + header_format="tenant:{value}", + ) + + headers = provider(mock_context) + + assert headers == {"X-Tenant-ID": "tenant:tenant-123"} + + def test_missing_state_key_returns_empty(self): + """Test that missing state key returns empty dict when no default.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {} + + provider = create_session_state_header_provider(state_key="jwt_token") + + headers = provider(mock_context) + + assert headers == {} + + def test_missing_state_key_uses_default(self): + """Test that missing state key uses default value if provided.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {} + + provider = create_session_state_header_provider( + state_key="jwt_token", default_value="default-token" + ) + + headers = provider(mock_context) + + assert headers == {"Authorization": "Bearer default-token"} + + def test_none_value_in_state_returns_empty(self): + """Test that None value in state returns empty dict.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"jwt_token": None} + + provider = create_session_state_header_provider(state_key="jwt_token") + + headers = provider(mock_context) + + assert headers == {} + + def test_empty_string_value_returns_empty(self): + """Test that empty string value in state returns empty dict.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"jwt_token": ""} + + provider = create_session_state_header_provider(state_key="jwt_token") + + headers = provider(mock_context) + + assert headers == {} + + def test_strict_mode_with_primitive_types(self): + """Test that strict mode works properly with primitive types.""" + mock_context = Mock(spec=ReadonlyContext) + + # Test with string + mock_context.state = {"token": "my-token"} + provider = create_session_state_header_provider( + state_key="token", strict=True + ) + headers = provider(mock_context) + assert headers == {"Authorization": "Bearer my-token"} + + # Test with int + mock_context.state = {"count": 42} + provider = create_session_state_header_provider( + state_key="count", + header_name="X-Count", + header_format="{value}", + strict=True, + ) + headers = provider(mock_context) + assert headers == {"X-Count": "42"} + + def test_strict_mode_raises_on_non_primitive_types(self): + """Test that strict mode raises ValueError for non-primitive types.""" + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"complex_data": {"nested": "dict"}} + + provider = create_session_state_header_provider( + state_key="complex_data", strict=True + ) + + with pytest.raises(ValueError) as exc_info: + provider(mock_context) + + assert "complex_data" in str(exc_info.value) + assert "dict" in str(exc_info.value) + assert "may not serialize correctly" in str(exc_info.value) + + +class TestMcpToolsetConfigStateHeaderMapping: + """Test suite for state_header_mapping configuration.""" + + def test_config_with_single_state_mapping(self): + """Test config with single state to header mapping.""" + config = McpToolsetConfig( + stdio_server_params=StdioServerParameters( + command="test_command", args=[] + ), + state_header_mapping={"jwt_token": "Authorization"}, + state_header_format={"Authorization": "Bearer {value}"}, + ) + + assert config.state_header_mapping == {"jwt_token": "Authorization"} + assert config.state_header_format == {"Authorization": "Bearer {value}"} + + def test_config_with_multiple_state_mappings(self): + """Test config with multiple state to header mappings.""" + config = McpToolsetConfig( + stdio_server_params=StdioServerParameters( + command="test_command", args=[] + ), + state_header_mapping={ + "jwt_token": "Authorization", + "tenant_id": "X-Tenant-ID", + "api_key": "X-API-Key", + }, + state_header_format={ + "Authorization": "Bearer {value}", + "X-API-Key": "key:{value}", + }, + ) + + assert len(config.state_header_mapping) == 3 + assert config.state_header_mapping["jwt_token"] == "Authorization" + assert config.state_header_mapping["tenant_id"] == "X-Tenant-ID" + assert config.state_header_format["Authorization"] == "Bearer {value}" + + def test_config_without_state_mapping(self): + """Test config without state mapping (backward compatibility).""" + config = McpToolsetConfig( + stdio_server_params=StdioServerParameters( + command="test_command", args=[] + ) + ) + + assert config.state_header_mapping is None + assert config.state_header_format is None + + +class TestHeaderSecurityValidation: + """Test suite for header security validation features.""" + + def test_header_name_validation_valid_names(self): + """Test that valid header names are accepted.""" + from google.adk.tools.mcp_tool._internal import validate_header_name + + # Valid header names should not raise exceptions + valid_names = [ + "Authorization", + "X-API-Key", + "Content-Type", + "X-Custom-Header", + ] + + for name in valid_names: + validate_header_name(name) # Should not raise + + def test_header_name_validation_invalid_names(self): + """Test that invalid header names are rejected.""" + from google.adk.tools.mcp_tool._internal import validate_header_name + + # Invalid header names should raise ValueError + invalid_names = [ + "", # Empty string + "Authorization\n", # Newline + "X-API:Key", # Colon + "X-API Key", # Space + "X-API\x01Key", # Control character + ] + + for name in invalid_names: + with pytest.raises(ValueError) as exc_info: + validate_header_name(name) + assert ( + "invalid characters" in str(exc_info.value).lower() + or "empty" in str(exc_info.value).lower() + ) + + def test_header_value_sanitization_safe_values(self): + """Test that safe header values are unchanged.""" + from google.adk.tools.mcp_tool._internal import _sanitize_header_value + + safe_values = [ + "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + "api-key-123", + "tenant-456", + "Basic dXNlcjpwYXNz", # Base64 auth + ] + + for value in safe_values: + result = _sanitize_header_value(value) + assert result == value + + def test_header_value_sanitization_dangerous_values(self): + """Test that dangerous characters are removed from header values.""" + from google.adk.tools.mcp_tool._internal import _sanitize_header_value + + dangerous_values = [ + ("Bearer token\x00injected", "Bearer tokeninjected"), + ("api-key\x00malicious", "api-keymalicious"), + ("value\x00more", "valuemore"), + ("token\x00data", "tokendata"), + ] + + for input_val, expected in dangerous_values: + result = _sanitize_header_value(input_val) + assert result == expected + + def test_header_value_sanitization_non_string_values(self): + """Test that non-string values are converted to string.""" + from google.adk.tools.mcp_tool._internal import _sanitize_header_value + + result_int = _sanitize_header_value(123) + assert result_int == "123" + + result_bool = _sanitize_header_value(True) + assert result_bool == "True" + + def test_session_state_header_provider_with_invalid_header_name(self): + """Test that invalid header names raise ValueError during provider creation.""" + from google.adk.tools.mcp_tool.mcp_toolset import create_session_state_header_provider + + with pytest.raises(ValueError) as exc_info: + create_session_state_header_provider( + state_key="token", + header_name="Authorization\n", # Invalid header name + ) + + assert "invalid characters" in str(exc_info.value).lower() + + def test_session_state_header_provider_sanitizes_values(self): + """Test that header provider sanitizes values from state.""" + from google.adk.tools.mcp_tool.mcp_toolset import create_session_state_header_provider + + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"token": "Bearer\x00token\x01injected"} + + provider = create_session_state_header_provider( + state_key="token", header_name="Authorization", header_format="{value}" + ) + + headers = provider(mock_context) + + # The provider should sanitize the dangerous characters + assert headers == {"Authorization": "Bearertokeninjected"} + + +class TestMcpToolsetFromConfigWithStateMapping: + """Test suite for McpToolset.from_config with state header mapping.""" + + def test_from_config_creates_header_provider(self): + """Test that from_config creates header provider from state mapping.""" + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + from google.adk.tools.tool_configs import ToolArgsConfig + + config = ToolArgsConfig( + stdio_server_params={"command": "test_command", "args": []}, + state_header_mapping={"jwt_token": "Authorization"}, + state_header_format={"Authorization": "Bearer {value}"}, + ) + + toolset = McpToolset.from_config(config, "/fake/path") + + # Verify header_provider was created + assert toolset._header_provider is not None + + # Test the created header provider + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"jwt_token": "test-token-123"} + + headers = toolset._header_provider(mock_context) + + assert headers == {"Authorization": "Bearer test-token-123"} + + def test_from_config_multiple_headers(self): + """Test from_config with multiple header mappings.""" + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + from google.adk.tools.tool_configs import ToolArgsConfig + + config = ToolArgsConfig( + stdio_server_params={"command": "test_command", "args": []}, + state_header_mapping={ + "jwt_token": "Authorization", + "tenant_id": "X-Tenant-ID", + }, + state_header_format={"Authorization": "Bearer {value}"}, + ) + + toolset = McpToolset.from_config(config, "/fake/path") + + # Test the created header provider + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"jwt_token": "token-123", "tenant_id": "tenant-456"} + + headers = toolset._header_provider(mock_context) + + assert headers["Authorization"] == "Bearer token-123" + assert headers["X-Tenant-ID"] == "tenant-456" + + def test_from_config_omits_missing_state_keys(self): + """Test that missing state keys are omitted from headers.""" + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + from google.adk.tools.tool_configs import ToolArgsConfig + + config = ToolArgsConfig( + stdio_server_params={"command": "test_command", "args": []}, + state_header_mapping={ + "jwt_token": "Authorization", + "tenant_id": "X-Tenant-ID", + }, + ) + + toolset = McpToolset.from_config(config, "/fake/path") + + # Only include jwt_token in state + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"jwt_token": "token-123"} + + headers = toolset._header_provider(mock_context) + + # Only Authorization header should be present + assert "Authorization" in headers + assert "X-Tenant-ID" not in headers + + def test_from_config_no_state_mapping_no_provider(self): + """Test that no header provider is created when no state mapping.""" + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + from google.adk.tools.tool_configs import ToolArgsConfig + + config = ToolArgsConfig( + stdio_server_params={"command": "test_command", "args": []} + ) + + toolset = McpToolset.from_config(config, "/fake/path") + + # No header provider should be created + assert toolset._header_provider is None + + def test_from_config_with_strict_mode(self): + """Test that from_config respects state_header_strict setting.""" + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + from google.adk.tools.tool_configs import ToolArgsConfig + + config = ToolArgsConfig( + stdio_server_params={"command": "test_command", "args": []}, + state_header_mapping={"data": "X-Data"}, + state_header_strict=True, # Enable strict mode + ) + + toolset = McpToolset.from_config(config, "/fake/path") + + # Test with non-primitive type - should raise ValueError + mock_context = Mock(spec=ReadonlyContext) + mock_context.state = {"data": {"nested": "object"}} + + with pytest.raises(ValueError) as exc_info: + toolset._header_provider(mock_context) + + assert "data" in str(exc_info.value) + assert "dict" in str(exc_info.value) + + +class TestRFC7230Compliance: + """Test suite for RFC 7230 compliant header handling.""" + + def test_header_name_validation_rfc_compliant(self): + """Test that header name validation follows RFC 7230.""" + from google.adk.tools.mcp_tool._internal import validate_header_name + + # RFC 7230 compliant header names should be accepted + valid_names = [ + "Authorization", + "X-API-Key", + "Content-Type", + "X-Custom-Header-123", + "Accept-Encoding", + "User-Agent", + "If-Modified-Since", + ] + + for name in valid_names: + validate_header_name(name) # Should not raise + + # RFC 7230 invalid header names should be rejected + invalid_names = [ + "", # Empty + "Authorization\n", # Newline + "X-API:Key", # Colon + "X API Key", # Space + "X-API\x01Key", # Control character + "X-API()Key", # Parentheses + "X-API@Key", # At symbol + "X-API,Key", # Comma + "X-API;Key", # Semicolon + 'X-API"Key', # Double quote + "X-API\\Key", # Backslash + "X-API/Key", # Forward slash + "X-API[Key]", # Brackets + "X-API?Key", # Question mark + "X-API=Key", # Equals + "X-API{Key}", # Braces + "X-API\tKey", # Tab + ] + + for name in invalid_names: + with pytest.raises(ValueError) as exc_info: + validate_header_name(name) + assert ( + "invalid characters" in str(exc_info.value).lower() + or "empty" in str(exc_info.value).lower() + ) + + def test_header_value_sanitization_rfc_compliant(self): + """Test that header value sanitization is RFC 7230 compliant.""" + from google.adk.tools.mcp_tool._internal import _sanitize_header_value + + # Safe header values should remain unchanged + safe_values = [ + "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + "application/json", + "text/html; charset=utf-8", + "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW", + "Basic dXNlcjpwYXNz", # Base64 auth + "token-123456789", + "api-key-secret", + ] + + for value in safe_values: + result = _sanitize_header_value(value) + assert result == value + + # Only truly dangerous characters should be removed + dangerous_cases = [ + ("Bearer\x00token", "Bearertoken"), # Null byte + ("token\x01inject", "tokeninject"), # SOH + ("data\x02malicious", "datamalicious"), # STX + ("value\x03attack", "valueattack"), # ETX + ("header\x04break", "headerbreak"), # EOT + ("content\x05with\x06control", "contentwithcontrol"), # ENQ, ACK + ("test\x07bell\x08backspace", "testbellbackspace"), # BEL, BS + ("text\x0Bvertical\x0Ctab", "textverticaltab"), # VT, FF + ("shift\x0E\x0Fin", "shiftin"), # SO, SI + ("dle\x10control", "dlecontrol"), # DLE + ("data\x11\x12\x13\x14chars", "datachars"), # DC1-DC4 + ("nack\x15syn\x16etb\x17", "nacksynetb"), # NAK, SYN, ETB + ("can\x18em\x19sub\x1Aesc", "canemsubesc"), # CAN, EM, SUB, ESC + ("fs\x1ags\x1brs\x1cus", "fsgsrsus"), # FS, GS, RS, US + ("space\x20test", "space test"), # Space should be preserved + ( + "normal!@#$%^&*()test", + "normal!@#$%^&*()test", + ), # Special chars preserved + ] + + for input_val, expected in dangerous_cases: + result = _sanitize_header_value(input_val) + assert result == expected + + def test_header_value_preserves_rfc_folding(self): + """Test that legitimate CRLF sequences for header folding are preserved.""" + from google.adk.tools.mcp_tool._internal import _sanitize_header_value + + # Multi-line headers with proper folding should be preserved (RFC 7230 §3.2.4) + folding_cases = [ + ("Authorization: Bearer token", "Authorization: Bearer token"), + ("X-Custom: value1\r\n\tvalue2", "X-Custom: value1\r\n\tvalue2"), + ("X-Header: line1\r\n line2", "X-Header: line1\r\n line2"), + ( + "Content-Type: application/json\r\n\tcharset=utf-8", + "Content-Type: application/json\r\n\tcharset=utf-8", + ), + ] + + for folding_case in folding_cases: + result = _sanitize_header_value(folding_case[0]) + assert result == folding_case[1] + + def test_header_value_validation_rfc_compliant(self): + """Test that header value validation is RFC 7230 compliant.""" + from google.adk.tools.mcp_tool._internal import _validate_header_value + + # Valid header values should pass validation + valid_values = [ + "Bearer token123", + "application/json", + "text/html; charset=utf-8", + "Basic dXNlcjpwYXNz", + "multipart/form-data; boundary=something", + 123, # Numbers should be converted to string + True, # Boolean should be converted to string + 45.67, # Float should be converted to string + ] + + for value in valid_values: + try: + _validate_header_value(value) # Should not raise + except ValueError: + pytest.fail(f"Valid value {value!r} should not raise ValueError") + + # Invalid header values should raise ValueError + invalid_values = [ + "token\x00with\x01null", # Contains control characters + "data\x02with\x03control", # Contains control characters + b"binary\x00data", # Binary data with null bytes + # {"complex": "object"}, # Complex object (when converted to string) - REMOVED as it is valid + # ["list", "data"], # List (when converted to string) - REMOVED as it is valid + ] + + for value in invalid_values: + with pytest.raises(ValueError): + _validate_header_value(value) + + def test_session_state_header_provider_rfc_compliant(self): + """Test that session state header provider handles edge cases correctly.""" + from google.adk.tools.mcp_tool.mcp_toolset import create_session_state_header_provider + + mock_context = Mock(spec=ReadonlyContext) + + # Test with dangerous characters that get sanitized + mock_context.state = {"token": "Bearer\x00token\x01with\x02control"} + provider = create_session_state_header_provider( + state_key="token", header_name="Authorization", header_format="{value}" + ) + + headers = provider(mock_context) + assert headers == {"Authorization": "Bearertokenwithcontrol"} + + # Test with None and empty values + mock_context.state = {"token": None} + provider = create_session_state_header_provider(state_key="token") + headers = provider(mock_context) + assert headers == {} + + mock_context.state = {"token": ""} + provider = create_session_state_header_provider(state_key="token") + headers = provider(mock_context) + assert headers == {} + + # Test with default value + mock_context.state = {} + provider = create_session_state_header_provider( + state_key="token", default_value="default-token" + ) + headers = provider(mock_context) + assert headers == {"Authorization": "Bearer default-token"} + + def test_binary_data_handling(self): + """Test proper handling of binary data in header values.""" + from google.adk.tools.mcp_tool._internal import _validate_header_value + + # Binary data should be rejected by default + with pytest.raises(ValueError, match="Binary data not allowed"): + _validate_header_value(b"binary data") + + # Binary data should be accepted with allow_binary=True if no dangerous bytes + safe_binary = b"safe binary data" + try: + _validate_header_value(safe_binary, allow_binary=True) + except ValueError: + pytest.fail("Safe binary data should be allowed with allow_binary=True") + + # Binary data with dangerous bytes should be rejected even with allow_binary=True + dangerous_binary = b"dangerous\x00data\x01with\x02control" + with pytest.raises(ValueError, match="dangerous"): + _validate_header_value(dangerous_binary, allow_binary=True)