From ee479353aca4b5e5090ea9d848b51fe1202d6b99 Mon Sep 17 00:00:00 2001 From: Eyobyb Date: Fri, 7 Nov 2025 11:21:51 +0300 Subject: [PATCH 1/3] Implement LLM provider detection and response formatting in PromptTemplate class - Added methods to detect LLM provider based on the _llm_type attribute and class name. - Implemented provider-specific response formatting for OpenAI, Anthropic, Google, and Cohere. - Enhanced the format_response_format_for_provider method to return appropriate formats based on detected provider. - Updated tests to cover new functionality, including provider detection and response formatting for various LLMs. --- .../prompts/prompt_template_loader.py | 363 +++++++++++++++++- .../prompts/test_prompt_template_loader.py | 329 +++++++++++++++- 2 files changed, 681 insertions(+), 11 deletions(-) diff --git a/src/sherpa_ai/prompts/prompt_template_loader.py b/src/sherpa_ai/prompts/prompt_template_loader.py index a2c64489..f3719b43 100644 --- a/src/sherpa_ai/prompts/prompt_template_loader.py +++ b/src/sherpa_ai/prompts/prompt_template_loader.py @@ -5,10 +5,12 @@ substitution capabilities for different prompt types. """ -from typing import Dict, List, Optional, Union, Any +from typing import Dict, List, Optional, Union, Any, Type from sherpa_ai.prompts.Base import ChatPromptVersion, TextPromptVersion, JsonPromptVersion from sherpa_ai.prompts.prompt_loader import PromptLoader import copy +from pydantic import BaseModel, create_model +from langchain_core.language_models import BaseChatModel class PromptTemplate(PromptLoader): """Template loader and formatter for prompts. @@ -37,6 +39,144 @@ def __init__(self, json_file_path: str): """ super().__init__(json_file_path) + def _detect_provider(self, llm: Optional[BaseChatModel]) -> Optional[str]: + """Detect the LLM provider from an LLM instance. + + This method identifies the provider by checking the _llm_type attribute + and class name as fallback. + + Args: + llm (Optional[BaseChatModel]): The LLM instance to detect. + + Returns: + Optional[str]: Provider name ('openai', 'anthropic', 'google', 'cohere') + or None if provider cannot be detected. + + Example: + >>> template = PromptTemplate("prompts.json") + >>> llm = ChatOpenAI() + >>> provider = template._detect_provider(llm) + >>> print(provider) + 'openai' + """ + if llm is None: + return None + + # Check for wrapped models (e.g., ChatModelWithLogging) + actual_llm = llm + if hasattr(llm, 'llm'): + actual_llm = llm.llm + + # Try to get provider from _llm_type attribute + if hasattr(actual_llm, '_llm_type'): + llm_type = actual_llm._llm_type.lower() + if 'openai' in llm_type: + return 'openai' + elif 'anthropic' in llm_type or 'claude' in llm_type or 'bedrock' in llm_type: + return 'anthropic' + elif 'google' in llm_type or 'gemini' in llm_type or 'vertexai' in llm_type: + return 'google' + elif 'cohere' in llm_type: + return 'cohere' + + # Fallback to class name inspection + class_name = actual_llm.__class__.__name__.lower() + if 'openai' in class_name: + return 'openai' + elif 'anthropic' in class_name or 'claude' in class_name: + return 'anthropic' + elif 'google' in class_name or 'gemini' in class_name or 'vertex' in class_name: + return 'google' + elif 'cohere' in class_name: + return 'cohere' + + return None + + def _pydantic_to_json_schema(self, pydantic_model: Type[BaseModel]) -> Dict[str, Any]: + """Convert a Pydantic model to JSON Schema. + + This method uses Pydantic's built-in model_json_schema() method to + convert a Pydantic model to JSON Schema format. + + Args: + pydantic_model (Type[BaseModel]): The Pydantic model class to convert. + + Returns: + Dict[str, Any]: JSON Schema representation of the Pydantic model. + + Example: + >>> from pydantic import BaseModel + >>> class User(BaseModel): + ... name: str + ... age: int + >>> template = PromptTemplate("prompts.json") + >>> schema = template._pydantic_to_json_schema(User) + >>> print(schema['properties']['name']['type']) + 'string' + """ + if not issubclass(pydantic_model, BaseModel): + raise ValueError(f"Expected a Pydantic BaseModel, got {type(pydantic_model)}") + + return pydantic_model.model_json_schema() + + def _json_schema_to_pydantic(self, json_schema: Dict[str, Any], model_name: str = "DynamicModel") -> Type[BaseModel]: + """Convert a JSON Schema to a Pydantic model. + + This method dynamically creates a Pydantic model from a JSON Schema. + Note: This is a simplified implementation and may not handle all + JSON Schema features. + + Args: + json_schema (Dict[str, Any]): The JSON Schema to convert. + model_name (str): Name for the generated Pydantic model class. + + Returns: + Type[BaseModel]: A dynamically created Pydantic model class. + + Example: + >>> template = PromptTemplate("prompts.json") + >>> schema = { + ... "type": "object", + ... "properties": { + ... "name": {"type": "string"}, + ... "age": {"type": "integer"} + ... }, + ... "required": ["name"] + ... } + >>> Model = template._json_schema_to_pydantic(schema, "User") + >>> instance = Model(name="Alice", age=30) + """ + if json_schema.get("type") != "object": + raise ValueError("JSON Schema must be of type 'object' to convert to Pydantic model") + + properties = json_schema.get("properties", {}) + required = set(json_schema.get("required", [])) + + # Map JSON Schema types to Python types + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + + field_definitions = {} + for field_name, field_schema in properties.items(): + field_type = field_schema.get("type", "string") + python_type = type_mapping.get(field_type, str) + + # Handle Optional fields - in Pydantic v2, optional fields need a default value + if field_name not in required: + # Use None as default for optional fields + field_definitions[field_name] = (Optional[python_type], None) + else: + # Required fields use Ellipsis (...) + field_definitions[field_name] = (python_type, ...) + + return create_model(model_name, **field_definitions) + def format_prompt( self, prompt_parent_id: str, @@ -139,6 +279,191 @@ def replace_in_dict(data: Dict) -> Dict: else: raise ValueError(f"Unknown prompt version type: {type(prompt_version_obj)}") + def _format_for_openai(self, json_schema: Dict[str, Any], schema_name: str = "ResponseModel") -> Type[BaseModel]: + """Format response format for OpenAI/Azure OpenAI. + + OpenAI accepts Pydantic models directly. This method converts JSON Schema + to a Pydantic model. + + Args: + json_schema (Dict[str, Any]): JSON Schema to convert. + schema_name (str): Name for the generated Pydantic model. + + Returns: + Type[BaseModel]: Pydantic model class for OpenAI. + """ + # Extract schema if nested in json_schema structure + if "json_schema" in json_schema and "schema" in json_schema["json_schema"]: + schema = json_schema["json_schema"]["schema"] + if "name" in json_schema["json_schema"]: + schema_name = json_schema["json_schema"]["name"] + elif "schema" in json_schema: + schema = json_schema["schema"] + else: + schema = json_schema + + return self._json_schema_to_pydantic(schema, schema_name) + + def _format_for_anthropic(self, json_schema: Dict[str, Any], schema_name: str = "response") -> Dict[str, Any]: + """Format response format for Anthropic Claude. + + Anthropic uses JSON Schema in tool definitions with tool_choice. + + Args: + json_schema (Dict[str, Any]): JSON Schema to format. + schema_name (str): Name for the tool function. + + Returns: + Dict[str, Any]: Tool definition format for Anthropic. + """ + # Extract schema if nested + if "json_schema" in json_schema and "schema" in json_schema["json_schema"]: + schema = json_schema["json_schema"]["schema"] + if "name" in json_schema["json_schema"]: + schema_name = json_schema["json_schema"]["name"] + elif "schema" in json_schema: + schema = json_schema["schema"] + else: + schema = json_schema + + return { + "tools": [ + { + "type": "function", + "name": schema_name, + "input_schema": schema + } + ], + "tool_choice": { + "type": "tool", + "name": schema_name + } + } + + def _format_for_google(self, json_schema: Dict[str, Any]) -> Dict[str, Any]: + """Format response format for Google Gemini. + + Google Gemini uses JSON Schema via responseSchema parameter. + Uses OpenAPI 3.0 Schema subset. + + Args: + json_schema (Dict[str, Any]): JSON Schema to format. + + Returns: + Dict[str, Any]: responseSchema format for Google Gemini. + """ + # Extract schema if nested + if "json_schema" in json_schema and "schema" in json_schema["json_schema"]: + schema = json_schema["json_schema"]["schema"] + elif "schema" in json_schema: + schema = json_schema["schema"] + else: + schema = json_schema + + # Google Gemini expects OpenAPI 3.0 Schema format + # For most cases, the JSON Schema is compatible + return { + "responseSchema": schema + } + + def _format_for_cohere(self, json_schema: Dict[str, Any]) -> Dict[str, Any]: + """Format response format for Cohere. + + Cohere uses JSON Schema via response_format parameter. + + Args: + json_schema (Dict[str, Any]): JSON Schema to format. + + Returns: + Dict[str, Any]: response_format format for Cohere. + """ + # Extract schema if nested + if "json_schema" in json_schema and "schema" in json_schema["json_schema"]: + schema = json_schema["json_schema"]["schema"] + elif "schema" in json_schema: + schema = json_schema["schema"] + else: + schema = json_schema + + return { + "response_format": schema + } + + def format_response_format_for_provider( + self, + prompt_parent_id: str, + prompt_id: str, + version: str, + llm: Optional[BaseChatModel], + variables: Optional[Dict[str, Union[str, int, float, List]]] = None + ) -> Optional[Union[Dict[str, Any], Type[BaseModel]]]: + """Format response format schema for a specific LLM provider. + + This method detects the provider from the LLM instance and returns + the response format in the appropriate format for that provider. + + Args: + prompt_parent_id (str): Name of the wrapper containing the prompt. + prompt_id (str): ID of the prompt to format. + version (str): Version of the prompt to format. + llm (Optional[BaseChatModel]): LLM instance to detect provider from. + variables (Optional[Dict[str, Union[str, int, float, List]]]): Values to + substitute in the schema. If None, uses defaults from JSON. + + Returns: + Optional[Union[Dict[str, Any], Type[BaseModel]]]: + Provider-specific format: + - OpenAI: Pydantic model class + - Anthropic: Tool definition dict + - Google: responseSchema dict + - Cohere: response_format dict + - None/Unknown: JSON Schema dict (fallback) + + Example: + >>> template = PromptTemplate("prompts.json") + >>> llm = ChatOpenAI() + >>> formatted = template.format_response_format_for_provider( + ... prompt_parent_id="addition_prompts", + ... prompt_id="add_numbers_text", + ... version="1.0", + ... llm=llm + ... ) + >>> # Returns Pydantic model for OpenAI + """ + # First, get the formatted JSON Schema + json_schema = self.format_response_format(prompt_parent_id, prompt_id, version, variables) + if json_schema is None: + return None + + # Detect provider + provider = self._detect_provider(llm) + if provider is None: + # Fallback to JSON Schema if provider cannot be detected + return json_schema + + # Get schema name if available + schema_name = "ResponseModel" + if "json_schema" in json_schema and "name" in json_schema["json_schema"]: + schema_name = json_schema["json_schema"]["name"] + + # Format based on provider + try: + if provider == "openai": + return self._format_for_openai(json_schema, schema_name) + elif provider == "anthropic": + return self._format_for_anthropic(json_schema, schema_name) + elif provider == "google": + return self._format_for_google(json_schema) + elif provider == "cohere": + return self._format_for_cohere(json_schema) + else: + # Unknown provider, return JSON Schema as fallback + return json_schema + except (ValueError, KeyError) as e: + # If conversion fails, return original JSON Schema + # This maintains backward compatibility + return json_schema + def format_response_format( self, prompt_parent_id: str, @@ -228,8 +553,9 @@ def get_full_formatted_prompt( prompt_parent_id: str, prompt_id: str, version: str, - variables: Optional[Dict[str, Union[str, int, float, List]]] = None - ) -> Optional[Dict[str, Union[str, List[Dict[str, str]], Dict]]]: + variables: Optional[Dict[str, Union[str, int, float, List]]] = None, + llm: Optional[BaseChatModel] = None + ) -> Optional[Dict[str, Union[str, List[Dict[str, str]], Dict, Type[BaseModel]]]]: """Get a formatted prompt with metadata. This method formats a prompt and returns it along with its description @@ -242,11 +568,15 @@ def get_full_formatted_prompt( version (str): Version of the prompt to format. variables (Optional[Dict[str, Union[str, int, float, List]]]): Values to substitute in the prompt. If None, uses defaults from JSON. + llm (Optional[BaseChatModel]): Optional LLM instance for provider-specific + schema formatting. If provided, output_schema will be formatted for + the detected provider. Returns: - Optional[Dict[str, Union[str, List[Dict[str, str]], Dict]]]: + Optional[Dict[str, Union[str, List[Dict[str, str]], Dict, Type[BaseModel]]]]: Dictionary containing formatted content, description, and schema, - or None if prompt not found. + or None if prompt not found. The output_schema will be provider-specific + if llm is provided, otherwise it will be JSON Schema. Example: >>> template = PromptTemplate("prompts.json") @@ -258,6 +588,15 @@ def get_full_formatted_prompt( ... ) >>> print(result["description"]) 'Search query template' + >>> # With provider-specific formatting: + >>> llm = ChatOpenAI() + >>> result = template.get_full_formatted_prompt( + ... prompt_parent_id="addition_prompts", + ... prompt_id="add_numbers_text", + ... version="1.0", + ... llm=llm + ... ) + >>> # output_schema will be a Pydantic model for OpenAI """ target_prompt = None for pg in self.prompts: @@ -281,10 +620,16 @@ def get_full_formatted_prompt( if not formatted_content: return None - # Format the response format schema as well - formatted_response_format = self.format_response_format( - prompt_parent_id, prompt_id, version, variables - ) + # Format the response format schema + # Use provider-specific formatting if llm is provided + if llm is not None: + formatted_response_format = self.format_response_format_for_provider( + prompt_parent_id, prompt_id, version, llm, variables + ) + else: + formatted_response_format = self.format_response_format( + prompt_parent_id, prompt_id, version, variables + ) return { "description": target_prompt.description, diff --git a/src/tests/unit_tests/prompts/test_prompt_template_loader.py b/src/tests/unit_tests/prompts/test_prompt_template_loader.py index 63258624..22dcf048 100644 --- a/src/tests/unit_tests/prompts/test_prompt_template_loader.py +++ b/src/tests/unit_tests/prompts/test_prompt_template_loader.py @@ -1,5 +1,7 @@ -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock from sherpa_ai.prompts.prompt_template_loader import PromptTemplate +from pydantic import BaseModel +from langchain_core.language_models import BaseChatModel import json @@ -375,4 +377,327 @@ def test_dynamic_enum_in_response_schema(): # Content includes variable substitutions; list prints as Python list string assert "A: 10" in full_result["content"] assert "B: 20" in full_result["content"] - assert "['csv', 'yaml']" in full_result["content"] \ No newline at end of file + assert "['csv', 'yaml']" in full_result["content"] + + +# Tests for provider-specific formatting +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_openai(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock OpenAI LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "openai" + + provider = template._detect_provider(mock_llm) + assert provider == "openai" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_anthropic(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Anthropic LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "anthropic" + + provider = template._detect_provider(mock_llm) + assert provider == "anthropic" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_google(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Google LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "google" + + provider = template._detect_provider(mock_llm) + assert provider == "google" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_cohere(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Cohere LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "cohere" + + provider = template._detect_provider(mock_llm) + assert provider == "cohere" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_wrapped_model(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock wrapped LLM (e.g., ChatModelWithLogging) + mock_inner_llm = MagicMock(spec=BaseChatModel) + mock_inner_llm._llm_type = "openai" + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm.llm = mock_inner_llm + + provider = template._detect_provider(mock_llm) + assert provider == "openai" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_detect_provider_unknown(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock unknown LLM - _llm_type doesn't match any known provider + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "unknown_provider" + # Make sure class name also doesn't match + type(mock_llm).__name__ = "UnknownModel" + + provider = template._detect_provider(mock_llm) + assert provider is None + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_openai(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock OpenAI LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "openai" + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + # Should return a Pydantic model class for OpenAI + assert formatted is not None + assert issubclass(formatted, BaseModel) + # Verify the model has the expected fields + assert hasattr(formatted, 'model_fields') + assert 'result' in formatted.model_fields + assert 'explanation' in formatted.model_fields + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_anthropic(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Anthropic LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "anthropic" + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + # Should return tool definition format for Anthropic + assert formatted is not None + assert isinstance(formatted, dict) + assert "tools" in formatted + assert "tool_choice" in formatted + assert len(formatted["tools"]) == 1 + assert formatted["tools"][0]["type"] == "function" + assert formatted["tools"][0]["name"] == "addition_result" + assert "input_schema" in formatted["tools"][0] + assert formatted["tool_choice"]["type"] == "tool" + assert formatted["tool_choice"]["name"] == "addition_result" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_google(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Google LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "google" + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + # Should return responseSchema format for Google + assert formatted is not None + assert isinstance(formatted, dict) + assert "responseSchema" in formatted + assert formatted["responseSchema"]["type"] == "object" + assert "properties" in formatted["responseSchema"] + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_cohere(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Cohere LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "cohere" + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + # Should return response_format format for Cohere + assert formatted is not None + assert isinstance(formatted, dict) + assert "response_format" in formatted + assert formatted["response_format"]["type"] == "object" + assert "properties" in formatted["response_format"] + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_unknown(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock unknown LLM - _llm_type doesn't match any known provider + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "unknown_provider" + # Make sure class name also doesn't match + type(mock_llm).__name__ = "UnknownModel" + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + # Should fallback to JSON Schema for unknown provider + assert formatted is not None + assert isinstance(formatted, dict) + assert "type" in formatted or "json_schema" in formatted + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_format_response_format_for_provider_none_llm(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + formatted = template.format_response_format_for_provider( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=None + ) + + # Should fallback to JSON Schema when llm is None + assert formatted is not None + assert isinstance(formatted, dict) + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_get_full_formatted_prompt_with_provider_openai(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock OpenAI LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "openai" + + result = template.get_full_formatted_prompt( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + assert result is not None + assert "description" in result + assert "content" in result + assert "output_schema" in result + # For OpenAI, output_schema should be a Pydantic model + assert issubclass(result["output_schema"], BaseModel) + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_get_full_formatted_prompt_with_provider_anthropic(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Mock Anthropic LLM + mock_llm = MagicMock(spec=BaseChatModel) + mock_llm._llm_type = "anthropic" + + result = template.get_full_formatted_prompt( + prompt_parent_id="addition_prompts", + prompt_id="add_numbers_text", + version="1.0", + llm=mock_llm + ) + + assert result is not None + assert "description" in result + assert "content" in result + assert "output_schema" in result + # For Anthropic, output_schema should be tool definition format + assert isinstance(result["output_schema"], dict) + assert "tools" in result["output_schema"] + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_pydantic_to_json_schema(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + # Create a simple Pydantic model + class TestModel(BaseModel): + name: str + age: int + + schema = template._pydantic_to_json_schema(TestModel) + + assert schema is not None + assert isinstance(schema, dict) + assert "properties" in schema + assert "name" in schema["properties"] + assert "age" in schema["properties"] + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["age"]["type"] == "integer" + + +@patch('sherpa_ai.prompts.prompt_loader.load_json') +def test_json_schema_to_pydantic(mock_load_json): + mock_load_json.return_value = mock_json_data + template = PromptTemplate("./tests/data/prompts.json") + + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name"] + } + + Model = template._json_schema_to_pydantic(json_schema, "TestModel") + + assert Model is not None + assert issubclass(Model, BaseModel) + # Test creating an instance + instance = Model(name="Alice", age=30) + assert instance.name == "Alice" + assert instance.age == 30 + # Test optional field + instance2 = Model(name="Bob") + assert instance2.name == "Bob" \ No newline at end of file From 39c0088354fe823161042c99c046b74c54821353 Mon Sep 17 00:00:00 2001 From: Eyobyb Date: Fri, 7 Nov 2025 11:22:06 +0300 Subject: [PATCH 2/3] Refactor optional field handling in PromptTemplate class - Simplified the handling of optional fields by removing the explicit mention of default values in comments. - Updated comments for clarity regarding required and optional fields in the Pydantic model definition. --- src/sherpa_ai/prompts/prompt_template_loader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sherpa_ai/prompts/prompt_template_loader.py b/src/sherpa_ai/prompts/prompt_template_loader.py index f3719b43..170cb4bd 100644 --- a/src/sherpa_ai/prompts/prompt_template_loader.py +++ b/src/sherpa_ai/prompts/prompt_template_loader.py @@ -167,12 +167,10 @@ def _json_schema_to_pydantic(self, json_schema: Dict[str, Any], model_name: str field_type = field_schema.get("type", "string") python_type = type_mapping.get(field_type, str) - # Handle Optional fields - in Pydantic v2, optional fields need a default value + # Handle Optional fields if field_name not in required: - # Use None as default for optional fields field_definitions[field_name] = (Optional[python_type], None) else: - # Required fields use Ellipsis (...) field_definitions[field_name] = (python_type, ...) return create_model(model_name, **field_definitions) From bdc94a2643d5c9c2c5ef0824f683df183afff0b8 Mon Sep 17 00:00:00 2001 From: Eyobyb Date: Wed, 12 Nov 2025 12:56:47 +0300 Subject: [PATCH 3/3] Fix formatting in test_prompt_template_loader.py by adding a newline at the end of the file --- src/tests/unit_tests/prompts/test_prompt_template_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tests/unit_tests/prompts/test_prompt_template_loader.py b/src/tests/unit_tests/prompts/test_prompt_template_loader.py index 22dcf048..049c5301 100644 --- a/src/tests/unit_tests/prompts/test_prompt_template_loader.py +++ b/src/tests/unit_tests/prompts/test_prompt_template_loader.py @@ -700,4 +700,5 @@ def test_json_schema_to_pydantic(mock_load_json): assert instance.age == 30 # Test optional field instance2 = Model(name="Bob") - assert instance2.name == "Bob" \ No newline at end of file + assert instance2.name == "Bob" + \ No newline at end of file