diff --git a/a2a_agents/python/a2ui_agent/src/a2ui/extension/send_a2ui_to_client_toolset.py b/a2a_agents/python/a2ui_agent/src/a2ui/extension/send_a2ui_to_client_toolset.py index 1101869cf..baba339f4 100644 --- a/a2a_agents/python/a2ui_agent/src/a2ui/extension/send_a2ui_to_client_toolset.py +++ b/a2a_agents/python/a2ui_agent/src/a2ui/extension/send_a2ui_to_client_toolset.py @@ -87,6 +87,7 @@ async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]: from a2a import types as a2a_types from a2ui.extension.a2ui_extension import create_a2ui_part from a2ui.extension.a2ui_schema_utils import wrap_as_json_array +from a2ui.extension.validation import validate_a2ui_json from google.adk.a2a.converters import part_converter from google.adk.agents.readonly_context import ReadonlyContext from google.adk.models import LlmRequest @@ -262,7 +263,7 @@ async def run_async( a2ui_json_payload = [a2ui_json_payload] a2ui_schema = await self.get_a2ui_schema(tool_context) - jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema) + validate_a2ui_json(a2ui_json_payload, a2ui_schema) logger.info( f"Validated call to tool {self.TOOL_NAME} with {self.A2UI_JSON_ARG_NAME}" diff --git a/a2a_agents/python/a2ui_agent/src/a2ui/extension/validation.py b/a2a_agents/python/a2ui_agent/src/a2ui/extension/validation.py new file mode 100644 index 000000000..db98e16f9 --- /dev/null +++ b/a2a_agents/python/a2ui_agent/src/a2ui/extension/validation.py @@ -0,0 +1,290 @@ +from typing import Any, Dict, Iterator, List, Set, Tuple, Union +import jsonschema +import re + +# RFC 6901 compliant regex for JSON Pointer +JSON_POINTER_PATTERN = re.compile(r"^(?:\/(?:[^~\/]|~[01])*)*$") + +# Recursion Limits +MAX_GLOBAL_DEPTH = 50 +MAX_FUNC_CALL_DEPTH = 5 + +# Constants +COMPONENTS = "components" +ID = "id" +COMPONENT_PROPERTIES = "componentProperties" +ROOT = "root" +PATH = "path" +FUNCTION_CALL = "functionCall" +CALL = "call" +ARGS = "args" + + +def validate_a2ui_json( + a2ui_json: Union[Dict[str, Any], List[Any]], a2ui_schema: Dict[str, Any] +) -> None: + """ + Validates the A2UI JSON payload against the provided schema and checks for integrity. + + Checks performed: + 1. **JSON Schema Validation**: Ensures payload adheres to the A2UI schema. + 2. **Component Integrity**: + - All component IDs are unique. + - A 'root' component exists. + - All unique component references point to valid IDs. + 3. **Topology**: + - No circular references (including self-references). + - No orphaned components (all components must be reachable from 'root'). + 4. **Recursion Limits**: + - Global recursion depth limit (50). + - FunctionCall recursion depth limit (5). + 5. **Path Syntax**: + - Validates JSON Pointer syntax for data paths. + + Args: + a2ui_json: The JSON payload to validate. + a2ui_schema: The schema to validate against. + + Raises: + jsonschema.ValidationError: If the payload does not match the schema. + ValueError: If integrity, topology, or recursion checks fail. + """ + jsonschema.validate(instance=a2ui_json, schema=a2ui_schema) + + # Normalize to list for iteration + messages = a2ui_json if isinstance(a2ui_json, list) else [a2ui_json] + + for message in messages: + if not isinstance(message, dict): + continue + + # Check for SurfaceUpdate which has 'components' + if COMPONENTS in message: + ref_map = _extract_component_ref_fields(a2ui_schema) + _validate_component_integrity(message[COMPONENTS], ref_map) + _validate_topology(message[COMPONENTS], ref_map) + + _validate_recursion_and_paths(message) + + +def _validate_component_integrity( + components: List[Dict[str, Any]], + ref_fields_map: Dict[str, tuple[Set[str], Set[str]]], +) -> None: + """ + Validates that: + 1. All component IDs are unique. + 2. A 'root' component exists. + 3. All references (children, child, etc.) point to existing IDs. + """ + ids: Set[str] = set() + + # 1. Collect IDs and check for duplicates + for comp in components: + comp_id = comp.get(ID) + if comp_id is None: + continue + + if comp_id in ids: + raise ValueError(f"Duplicate component ID found: '{comp_id}'") + ids.add(comp_id) + + # 2. Check for root component + if ROOT not in ids: + raise ValueError( + f"Missing '{ROOT}' component: One component must have '{ID}' set to '{ROOT}'." + ) + + # 3. Check for dangling references using helper + for comp in components: + for ref_id, field_name in _get_component_references(comp, ref_fields_map): + if ref_id not in ids: + raise ValueError( + f"Component '{comp.get(ID)}' references missing ID '{ref_id}' in field" + f" '{field_name}'" + ) + + +def _validate_topology( + components: List[Dict[str, Any]], + ref_fields_map: Dict[str, tuple[Set[str], Set[str]]], +) -> None: + """ + Validates the topology of the component tree: + 1. No circular references (including self-references). + 2. No orphaned components (all components must be reachable from 'root'). + """ + adj_list: Dict[str, List[str]] = {} + all_ids: Set[str] = set() + + # Build Adjacency List + for comp in components: + comp_id = comp.get(ID) + if comp_id is None: + continue + + all_ids.add(comp_id) + if comp_id not in adj_list: + adj_list[comp_id] = [] + + for ref_id, field_name in _get_component_references(comp, ref_fields_map): + if ref_id == comp_id: + raise ValueError( + f"Self-reference detected: Component '{comp_id}' references itself in field" + f" '{field_name}'" + ) + adj_list[comp_id].append(ref_id) + + # Detect Cycles using DFS + visited: Set[str] = set() + recursion_stack: Set[str] = set() + + def dfs(node_id: str): + visited.add(node_id) + recursion_stack.add(node_id) + + for neighbor in adj_list.get(node_id, []): + if neighbor not in visited: + dfs(neighbor) + elif neighbor in recursion_stack: + raise ValueError( + f"Circular reference detected involving component '{neighbor}'" + ) + + recursion_stack.remove(node_id) + + if ROOT in all_ids: + dfs(ROOT) + + # Check for Orphans + orphans = all_ids - visited + if orphans: + sorted_orphans = sorted(list(orphans)) + raise ValueError( + f"Orphaned components detected (not reachable from '{ROOT}'): {sorted_orphans}" + ) + + +def _extract_component_ref_fields( + schema: Dict[str, Any], +) -> Dict[str, tuple[Set[str], Set[str]]]: + """ + Parses the JSON schema to identify which component properties reference other components. + Returns a map: { component_name: (set_of_single_ref_fields, set_of_list_ref_fields) } + """ + ref_map = {} + + root_defs = schema.get("$defs") or schema.get("definitions", {}) + + # Helper to check if a property schema looks like a ComponentId reference + def is_component_id_ref(prop_schema: Dict[str, Any]) -> bool: + ref = prop_schema.get("$ref", "") + if ref.endswith("ComponentId"): + return True + return False + + def is_child_list_ref(prop_schema: Dict[str, Any]) -> bool: + ref = prop_schema.get("$ref", "") + if ref.endswith("ChildList"): + return True + # Or array of ComponentIds + if prop_schema.get("type") == "array": + items = prop_schema.get("items", {}) + if is_component_id_ref(items): + return True + return False + + comps_schema = schema.get("properties", {}).get(COMPONENTS, {}) + items_schema = comps_schema.get("items", {}) + comp_props_schema = items_schema.get("properties", {}).get(COMPONENT_PROPERTIES, {}) + all_components = comp_props_schema.get("properties", {}) + + for comp_name, comp_schema in all_components.items(): + single_refs = set() + list_refs = set() + + props = comp_schema.get("properties", {}) + for prop_name, prop_schema in props.items(): + if is_component_id_ref(prop_schema): + single_refs.add(prop_name) + elif is_child_list_ref(prop_schema): + list_refs.add(prop_name) + + if single_refs or list_refs: + ref_map[comp_name] = (single_refs, list_refs) + + return ref_map + + +def _get_component_references( + component: Dict[str, Any], ref_fields_map: Dict[str, tuple[Set[str], Set[str]]] +) -> Iterator[Tuple[str, str]]: + """ + Helper to extract all referenced component IDs from a component. + Yields (referenced_id, field_name). + """ + comp_props_container = component.get(COMPONENT_PROPERTIES) + if not isinstance(comp_props_container, dict): + return + + for comp_type, props in comp_props_container.items(): + if not isinstance(props, dict): + continue + + single_refs, list_refs = ref_fields_map.get(comp_type, (set(), set())) + + for key, value in props.items(): + if key in single_refs: + if isinstance(value, str): + yield value, key + elif key in list_refs: + if isinstance(value, list): + for item in value: + if isinstance(item, str): + yield item, key + + +def _validate_recursion_and_paths(data: Any) -> None: + """ + Validates: + 1. Global recursion depth limit (50). + 2. FunctionCall recursion depth limit (5). + 3. Path syntax for DataBindings/DataModelUpdates. + """ + + def traverse(item: Any, global_depth: int, func_depth: int): + if global_depth > MAX_GLOBAL_DEPTH: + raise ValueError(f"Global recursion limit exceeded: Depth > {MAX_GLOBAL_DEPTH}") + + if isinstance(item, list): + for x in item: + traverse(x, global_depth + 1, func_depth) + return + + if isinstance(item, dict): + # Check for path + if PATH in item and isinstance(item[PATH], str): + path = item[PATH] + if not re.fullmatch(JSON_POINTER_PATTERN, path): + raise ValueError(f"Invalid JSON Pointer syntax: '{path}'") + + # Check for FunctionCall + is_func = CALL in item and ARGS in item + + if is_func: + if func_depth >= MAX_FUNC_CALL_DEPTH: + raise ValueError( + f"Recursion limit exceeded: {FUNCTION_CALL} depth > {MAX_FUNC_CALL_DEPTH}" + ) + + # Increment func_depth only for 'args', but global_depth matches traversal + for k, v in item.items(): + if k == ARGS: + traverse(v, global_depth + 1, func_depth + 1) + else: + traverse(v, global_depth + 1, func_depth) + else: + for v in item.values(): + traverse(v, global_depth + 1, func_depth) + + traverse(data, 0, 0) diff --git a/a2a_agents/python/a2ui_agent/tests/test_validation.py b/a2a_agents/python/a2ui_agent/tests/test_validation.py new file mode 100644 index 000000000..afd642e58 --- /dev/null +++ b/a2a_agents/python/a2ui_agent/tests/test_validation.py @@ -0,0 +1,358 @@ +import pytest +import jsonschema +from a2ui.extension.validation import validate_a2ui_json + + +# Fixture for the schema +@pytest.fixture +def schema(): + return { + "type": "object", + "$defs": { + "ComponentId": {"type": "string"}, + "ChildList": {"type": "array", "items": {"$ref": "#/$defs/ComponentId"}}, + }, + "properties": { + "components": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"$ref": "#/$defs/ComponentId"}, + "componentProperties": { + "type": "object", + "properties": { + "Column": { + "type": "object", + "properties": { + "children": {"$ref": "#/$defs/ChildList"} + }, + }, + "Row": { + "type": "object", + "properties": { + "children": {"$ref": "#/$defs/ChildList"} + }, + }, + "Container": { + "type": "object", + "properties": { + "children": {"$ref": "#/$defs/ChildList"} + }, + }, + "Card": { + "type": "object", + "properties": { + "child": {"$ref": "#/$defs/ComponentId"} + }, + }, + "Button": { + "type": "object", + "properties": { + "child": {"$ref": "#/$defs/ComponentId"}, + "action": { + "properties": { + "functionCall": { + "properties": { + "call": {"type": "string"}, + "args": {"type": "object"}, + } + } + } + }, + }, + }, + "Text": { + "type": "object", + "properties": { + "text": { + "oneOf": [ + {"type": "string"}, + {"type": "object"}, + ] + } + }, + }, + }, + }, + }, + "required": ["id"], + }, + } + }, + } + + +def test_validate_a2ui_json_valid_integrity(schema): + payload = { + "components": [ + {"id": "root", "componentProperties": {"Column": {"children": ["child1"]}}}, + {"id": "child1", "componentProperties": {"Text": {"text": "Hello"}}}, + ] + } + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_duplicate_ids(schema): + payload = { + "components": [ + {"id": "root", "componentProperties": {}}, + {"id": "root", "componentProperties": {}}, + ] + } + with pytest.raises(ValueError, match="Duplicate component ID found: 'root'"): + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_missing_root(schema): + payload = {"components": [{"id": "not-root", "componentProperties": {}}]} + with pytest.raises(ValueError, match="Missing 'root' component"): + validate_a2ui_json(payload, schema) + + +@pytest.mark.parametrize( + "component_type, field_name, ids_to_ref", + [ + ("Card", "child", "missing_child"), + ("Column", "children", ["child1", "missing_child"]), + ], +) +def test_validate_a2ui_json_dangling_references( + schema, component_type, field_name, ids_to_ref +): + """Test dangling references for both single and list fields.""" + # Construct payload dynamically + props = {field_name: ids_to_ref} + payload = { + "components": [{"id": "root", "componentProperties": {component_type: props}}] + } + if isinstance(ids_to_ref, list): + # Add valid children if any + for child_id in ids_to_ref: + if child_id != "missing_child": + payload["components"].append({"id": child_id, "componentProperties": {}}) + + with pytest.raises( + ValueError, + match=( + "Component 'root' references missing ID 'missing_child' in field" + f" '{field_name}'" + ), + ): + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_self_reference(schema): + payload = { + "components": [ + {"id": "root", "componentProperties": {"Container": {"children": ["root"]}}} + ] + } + with pytest.raises( + ValueError, + match=( + "Self-reference detected: Component 'root' references itself in field" + " 'children'" + ), + ): + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_circular_reference(schema): + payload = { + "components": [ + { + "id": "root", + "componentProperties": {"Container": {"children": ["child1"]}}, + }, + { + "id": "child1", + "componentProperties": {"Container": {"children": ["root"]}}, + }, + ] + } + with pytest.raises( + ValueError, match="Circular reference detected involving component" + ): + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_orphaned_component(schema): + payload = { + "components": [ + {"id": "root", "componentProperties": {"Container": {"children": []}}}, + {"id": "orphan", "componentProperties": {}}, + ] + } + with pytest.raises( + ValueError, + match=r"Orphaned components detected \(not reachable from 'root'\): \['orphan'\]", + ): + validate_a2ui_json(payload, schema) + + +def test_validate_a2ui_json_valid_topology_complex(schema): + """Test a valid topology with multiple levels.""" + payload = { + "components": [ + { + "id": "root", + "componentProperties": {"Container": {"children": ["child1", "child2"]}}, + }, + {"id": "child1", "componentProperties": {"Text": {"text": "Hello"}}}, + { + "id": "child2", + "componentProperties": {"Container": {"children": ["child3"]}}, + }, + {"id": "child3", "componentProperties": {"Text": {"text": "World"}}}, + ] + } + validate_a2ui_json(payload, schema) + + +def test_validate_recursion_limit_exceeded(schema): + """Test that recursion depth > 5 raises ValueError.""" + # Construct deep function call + args = {} + current = args + for i in range(5): # Depth 0 to 5 (6 levels) + current["arg"] = {"call": f"fn{i}", "args": {}} + current = current["arg"]["args"] + + payload = { + "components": [{ + "id": "root", + "componentProperties": { + "Button": { + "label": "Click me", + "action": {"functionCall": {"call": "fn_top", "args": args}}, + } + }, + }] + } + with pytest.raises(ValueError, match="Recursion limit exceeded"): + validate_a2ui_json(payload, schema) + + +def test_validate_recursion_limit_valid(schema): + """Test that recursion depth <= 5 is allowed.""" + # Construct max depth function call (Depth 5) + args = {} + current = args + for i in range(4): # Depth 0 to 4 (5 levels) + current["arg"] = {"call": f"fn{i}", "args": {}} + current = current["arg"]["args"] + + payload = { + "components": [{ + "id": "root", + "componentProperties": { + "Button": { + "label": "Click me", + "action": {"functionCall": {"call": "fn_top", "args": args}}, + } + }, + }] + } + validate_a2ui_json(payload, schema) + + +@pytest.mark.parametrize( + "payload", + [ + { + "updateDataModel": { + "surfaceId": "surface1", + "path": "invalid//path", + "value": "data", + } + }, + { + "components": [{ + "id": "root", + "componentProperties": { + "Text": {"text": {"path": "invalid path with spaces"}} + }, + }] + }, + { + "updateDataModel": { + "surfaceId": "surface1", + "path": "/invalid/escape/~2", + "value": "data", + } + }, + ], +) +def test_validate_invalid_paths(schema, payload): + """Test various invalid paths (JSON Pointer syntax).""" + with pytest.raises(ValueError, match="Invalid JSON Pointer syntax"): + validate_a2ui_json(payload, schema) + + +def test_validate_global_recursion_limit_exceeded(schema): + """Test that global recursion depth > 50 raises ValueError.""" + # Create a deeply nested dictionary + deep_payload = {"level": 0} + current = deep_payload + for i in range(55): + current["next"] = {"level": i + 1} + current = current["next"] + + with pytest.raises(ValueError, match="Global recursion limit exceeded"): + validate_a2ui_json(deep_payload, schema) + + +def test_validate_custom_schema_reference(): + """Test validation with a custom schema where a component has a non-standard reference field.""" + # Custom schema extending the base one + custom_schema = { + "type": "object", + "$defs": { + "ComponentId": {"type": "string"}, + "ChildList": {"type": "array", "items": {"$ref": "#/$defs/ComponentId"}}, + }, + "properties": { + "components": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"$ref": "#/$defs/ComponentId"}, + "componentProperties": { + "type": "object", + "properties": { + "CustomLink": { + "type": "object", + "properties": { + "linkedComponentId": { + "$ref": "#/$defs/ComponentId" + } + }, + } + }, + }, + }, + "required": ["id"], + }, + } + }, + } + + payload = { + "components": [{ + "id": "root", + "componentProperties": { + "CustomLink": {"linkedComponentId": "missing_target"} + }, + }] + } + + with pytest.raises( + ValueError, + match=( + "Component 'root' references missing ID 'missing_target' in field" + " 'linkedComponentId'" + ), + ): + validate_a2ui_json(payload, custom_schema)