diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..6edcec2 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -16,42 +16,60 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: return param_schema.get("type", "string") -def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any], visited_refs: set = None) -> Dict[str, Any]: """ - Resolve schema references in OpenAPI schemas. + Resolve schema references in OpenAPI schemas with circular reference protection. Args: schema_part: The part of the schema being processed that may contain references reference_schema: The complete schema used to resolve references from + visited_refs: Set of already visited references to prevent circular recursion Returns: The schema with references resolved """ + # Initialize visited_refs on first call + if visited_refs is None: + visited_refs = set() + # Make a copy to avoid modifying the input schema schema_part = schema_part.copy() # Handle $ref directly in the schema if "$ref" in schema_part: ref_path = schema_part["$ref"] + + # Check if we've already visited this reference (circular reference detection) + if ref_path in visited_refs: + # Return just the reference without resolving to prevent infinite recursion + return {"$ref": ref_path} + # Standard OpenAPI references are in the format "#/components/schemas/ModelName" if ref_path.startswith("#/components/schemas/"): model_name = ref_path.split("/")[-1] if "components" in reference_schema and "schemas" in reference_schema["components"]: if model_name in reference_schema["components"]["schemas"]: + # Mark this reference as visited + visited_refs.add(ref_path) + # Replace with the resolved schema ref_schema = reference_schema["components"]["schemas"][model_name].copy() # Remove the $ref key and merge with the original schema schema_part.pop("$ref") schema_part.update(ref_schema) + + # Continue resolving with the updated visited set + # but don't add to visited_refs permanently - remove after processing + visited_refs_copy = visited_refs.copy() # Recursively resolve references in all dictionary values for key, value in schema_part.items(): if isinstance(value, dict): - schema_part[key] = resolve_schema_references(value, reference_schema) + schema_part[key] = resolve_schema_references(value, reference_schema, visited_refs) elif isinstance(value, list): # Only process list items that are dictionaries since only they can contain refs schema_part[key] = [ - resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value + resolve_schema_references(item, reference_schema, visited_refs) if isinstance(item, dict) else item for item in value ] return schema_part