diff --git a/DSL/Resql/rag-search/POST/mock-count-active-services.sql b/DSL/Resql/rag-search/POST/mock-count-active-services.sql new file mode 100644 index 0000000..d68d273 --- /dev/null +++ b/DSL/Resql/rag-search/POST/mock-count-active-services.sql @@ -0,0 +1,11 @@ +-- Count active services for tool classifier +-- Used by Service Workflow to determine search strategy: +-- - If count <= 50: Use all services for LLM context +-- - If count > 50: Use Qdrant semantic search for top 20 + +SELECT + COUNT(*) AS active_service_count +FROM + public.services +WHERE + current_state = 'active'; diff --git a/DSL/Resql/rag-search/POST/mock-get-all-active-services.sql b/DSL/Resql/rag-search/POST/mock-get-all-active-services.sql new file mode 100644 index 0000000..5bd981b --- /dev/null +++ b/DSL/Resql/rag-search/POST/mock-get-all-active-services.sql @@ -0,0 +1,20 @@ +-- Get all active services for intent detection +-- Used when active_service_count <= 50 +-- Returns all service metadata needed for LLM intent detection + +SELECT + service_id, + name, + description, + ruuter_type, + slot, + entities, + examples, + structure, + endpoints +FROM + public.services +WHERE + current_state = 'active' +ORDER BY + name ASC; diff --git a/DSL/Resql/rag-search/POST/mock-get-service-by-id.sql b/DSL/Resql/rag-search/POST/mock-get-service-by-id.sql new file mode 100644 index 0000000..dbf375a --- /dev/null +++ b/DSL/Resql/rag-search/POST/mock-get-service-by-id.sql @@ -0,0 +1,24 @@ +-- Get specific service by service_id for validation +-- Used after LLM detects intent to validate the service exists and is active +-- Returns all service details needed to trigger the external service call + +SELECT + id, + service_id, + name, + description, + ruuter_type, + current_state, + is_common, + slot, + entities, + examples, + structure, + endpoints, + created_at, + updated_at +FROM + public.services +WHERE + service_id = :serviceId + AND current_state = 'active'; diff --git a/DSL/Ruuter.public/rag-search/GET/services/get-services.yml b/DSL/Ruuter.public/rag-search/GET/services/get-services.yml new file mode 100644 index 0000000..01356d9 --- /dev/null +++ b/DSL/Ruuter.public/rag-search/GET/services/get-services.yml @@ -0,0 +1,60 @@ +declaration: + call: declare + version: 0.1 + description: "Get services for intent detection - returns all services if count <= 10, otherwise signals to use semantic search" + method: get + returns: json + namespace: rag-search + +# Step 1: Count active services +count_services: + call: http.post + args: + url: "[#RAG_SEARCH_RESQL]/mock-count-active-services" + body: {} + result: count_result + next: check_service_count + +# Step 2: Check if count > threshold (10) +check_service_count: + assign: + service_count: ${Number(count_result.response.body[0].active_service_count)} + switch: + - condition: "${service_count > 10}" + next: return_semantic_search_flag + next: fetch_all_services + +# Step 3a: If > 10, return flag for semantic search +return_semantic_search_flag: + assign: + semantic_search_response: + use_semantic_search: true + service_count: ${service_count} + message: "Service count exceeds threshold - use semantic search" + next: return_semantic_search_response + +return_semantic_search_response: + return: ${semantic_search_response} + next: end + +# Step 3b: If <= 10, fetch all services +fetch_all_services: + call: http.post + args: + url: "[#RAG_SEARCH_RESQL]/mock-get-all-active-services" + body: {} + result: services_result + next: return_all_services + +# Step 4: Return all services for LLM +return_all_services: + assign: + all_services_response: + use_semantic_search: false + service_count: ${services_result.response.body.length} + services: ${services_result.response.body} + next: return_all_services_response + +return_all_services_response: + return: ${all_services_response} + next: end diff --git a/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md b/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md new file mode 100644 index 0000000..15669e4 --- /dev/null +++ b/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md @@ -0,0 +1,660 @@ +# Tool Classifier and Service Workflow Architecture + +## Overview + +The Tool Classifier implements a **layer-wise fallback architecture** that routes user queries to the most appropriate workflow handler. The Service Workflow (Layer 1) handles external API/service calls with intelligent intent detection and entity extraction. + +--- + +## Tool Classifier - Layer Architecture + +### Design Pattern: Chain of Responsibility + +The classifier tries each layer sequentially. If a layer returns `None`, it falls back to the next layer: + +``` +Layer 1: SERVICE → External API calls (currency, weather, etc.) +Layer 2: CONTEXT → Greetings, conversation history queries +Layer 3: RAG → Knowledge base retrieval (documents, regulations) +Layer 4: OOD → Out-of-domain fallback (polite rejection) +``` + +### Layer Execution Flow + +```python +# Non-streaming mode +classification = await classifier.classify(query, history, language) +response = await classifier.route_to_workflow(classification, request, is_streaming=False) + +# Streaming mode +classification = await classifier.classify(query, history, language) +stream = await classifier.route_to_workflow(classification, request, is_streaming=True) +async for sse_chunk in stream: + yield sse_chunk +``` + +### Fallback Chain + +Each workflow's `execute_async()` or `execute_streaming()` can return: +- **OrchestrationResponse / AsyncIterator[str]**: Layer handled the query successfully +- **None**: Layer cannot handle → Fallback to next layer + +Example cascading: +``` +Query: "What is VAT rate?" +└─ SERVICE (Layer 1) → No matching service → Returns None + └─ CONTEXT (Layer 2) → Not a greeting → Returns None + └─ RAG (Layer 3) → Found in docs → Returns response ✓ +``` + +--- + +## Service Workflow (Layer 1) - Detailed Architecture + +### Purpose +Handle queries that require calling external services/APIs: +- Currency conversion: "How much is 100 EUR in USD?" +- Weather services: "What's the temperature in Tallinn?" +- Custom Ruuter endpoints: Any service registered in database + +### High-Level Flow + +``` +1. Service Discovery + ↓ +2. Service Selection (Semantic Search or LLM-based) + ↓ +3. Intent Detection (DSPy LLM Call) + ↓ +4. Entity Extraction (From LLM Output) + ↓ +5. Entity Validation (Against Service Schema) + ↓ +6. Entity Transformation (Dict → Ordered Array) + ↓ +7. Service Call (TODO: Ruuter endpoint invocation) +``` + +--- + +## 1. Service Discovery + +### Method: `_call_service_discovery()` + +Calls Ruuter public endpoint to fetch available services: + +```python +GET /rag-search/get-services-from-llm +``` + +**Response Structure:** +```json +{ + "response": { + "service_count": 15, + "use_semantic_search": true, + "services": [ + { + "serviceId": "currency_conversion_eur", + "name": "Currency Conversion (EUR Base)", + "description": "Convert EUR to other currencies", + "ruuterType": "POST", + "ruuterUrl": "/currency/convert", + "entities": ["target_currency"], + "examples": [ + "How much is 100 EUR in USD?", + "Convert EUR to JPY" + ] + } + ] + } +} +``` + +### Service Count Threshold Logic + +```python +SERVICE_COUNT_THRESHOLD = 10 + +if service_count <= 10: + # Few services → Use all services for LLM intent detection + services = response["services"] + +elif service_count > 10: + # Many services → Use semantic search to narrow down + services = await _semantic_search_services(query, top_k=5) +``` + +--- + +## 2. Service Selection + +### Semantic Search (When Many Services) + +**Method:** `_semantic_search_services()` + +Uses Qdrant vector database to find relevant services: + +```python +# 1. Generate embedding for user query +embedding = orchestration_service.create_embeddings_for_indexer([query]) + +# 2. Search Qdrant collection +search_payload = { + "vector": query_embedding, + "limit": 5, # Top 5 services + "score_threshold": 0.4, # Minimum similarity + "with_payload": True +} + +response = qdrant_client.post( + f"/collections/{QDRANT_COLLECTION}/points/search", + json=search_payload +) +``` + +**Returns:** Top-K most semantically relevant services for intent detection + +--- + +## 3. Intent Detection (LLM-Based) + +### Method: `_detect_service_intent()` + +Uses **DSPy + LLM** to intelligently match user query to a specific service and extract entities. + +### DSPy Module: `IntentDetectionModule` + +**Purpose:** Analyze user query against available services and extract structured information + +**Signature:** +```python +class ServiceIntentDetector(dspy.Signature): + # Inputs + user_query: str # "How much is 100 EUR in USD?" + available_services: str # JSON of service definitions + conversation_context: str # Recent 3 conversation turns + + # Output + intent_result: str # JSON: {matched_service_id, confidence, entities, reasoning} +``` + +### LLM Call Flow + +```python +# 1. Prepare service context +services_formatted = [ + { + "service_id": "currency_conversion_eur", + "name": "Currency Conversion", + "description": "Convert EUR to other currencies", + "required_entities": ["target_currency"], + "examples": ["How much is EUR in USD?", "Convert EUR to JPY"] + } +] + +# 2. Prepare conversation context (last 3 turns) +conversation_context = """ +user: Hello +assistant: Hi! How can I help? +user: How much is 100 EUR in USD? +""" + +# 3. Call DSPy module +intent_result = intent_detector.forward( + user_query="How much is 100 EUR in USD?", + services=services_formatted, + conversation_history=conversation_history +) +``` + +### LLM Output Format + +The LLM returns structured JSON: + +```json +{ + "matched_service_id": "currency_conversion_eur", + "confidence": 0.95, + "entities": { + "target_currency": "USD" + }, + "reasoning": "User wants to convert EUR to USD, matches currency conversion service" +} +``` + +### Confidence Threshold + +```python +if confidence < 0.7: + # Low confidence → Service workflow returns None → Fallback to RAG + return None +``` + +### Cost Tracking + +Intent detection is an LLM call, so costs are tracked: + +```python +# Before LLM call +history_length_before = len(dspy.settings.lm.history) + +# Call intent detector +intent_result = intent_module.forward(...) + +# After LLM call +usage_info = get_lm_usage_since(history_length_before) +costs_dict["intent_detection"] = usage_info + +# Later: orchestration_service.log_costs(costs_dict) +``` + +--- + +## 4. Entity Extraction + +### From LLM Output + +The LLM extracts entities directly from the user query: + +**User Query:** `"Palju saan 1 EUR eest THBdes?"` +(Estonian: "How much do I get for 1 EUR in THB?") + +**LLM Extraction:** +```json +{ + "entities": { + "target_currency": "THB" + } +} +``` + +### Entity Format + +Entities are extracted as **key-value pairs** where: +- **Key**: Entity name defined in service schema (`target_currency`) +- **Value**: Extracted value from user query (`"THB"`) + +### Multi-Entity Example + +**Service Schema:** +```json +{ + "serviceId": "weather_forecast", + "entities": ["location", "date"] +} +``` + +**User Query:** "What's the weather in Tallinn tomorrow?" + +**LLM Extraction:** +```json +{ + "entities": { + "location": "Tallinn", + "date": "tomorrow" + } +} +``` + +--- + +## 5. Entity Validation + +### Method: `_validate_entities()` + +Validates extracted entities against the service's expected schema. + +### Validation Checks + +#### 1. Missing Entities +Entities required by schema but not extracted by LLM: + +```python +service_schema = ["target_currency", "amount"] +extracted = {"target_currency": "USD"} + +# Missing: "amount" +missing_entities = ["amount"] +``` + +**Strategy:** Send empty string for missing entities (let service validate) + +#### 2. Extra Entities +Entities extracted but not in service schema: + +```python +service_schema = ["target_currency"] +extracted = {"target_currency": "USD", "random_field": "value"} + +# Extra: "random_field" +extra_entities = ["random_field"] +``` + +**Strategy:** Ignore extra entities (not sent to service) + +#### 3. Empty Values +Entities extracted but with empty values: + +```python +extracted = {"target_currency": ""} + +validation_errors = ["Entity 'target_currency' has empty value"] +``` + +**Strategy:** Log warning, proceed anyway (service validates) + +### Validation Result + +```python +{ + "is_valid": True, # Always true (lenient validation) + "missing_entities": ["amount"], # Will send empty strings + "extra_entities": ["random_field"], # Will be ignored + "validation_errors": [ # Warnings only + "Entity 'amount' has empty value" + ] +} +``` + +### Validation Philosophy + +**Lenient Approach:** +- Always returns `is_valid: True` +- Proceeds with partial entities +- Service endpoint validates required parameters +- Avoids false negatives from over-strict validation + +--- + +## 6. Entity Transformation + +### Method: `_transform_entities_to_array()` + +Transforms entity dictionary to **ordered array** matching service schema order. + +### Why Ordered Array? + +Ruuter services expect parameters in specific order: +```python +# Service schema defines order +entities_schema = ["target_currency", "source_currency", "amount"] + +# LLM extraction (unordered dict) +entities_dict = { + "amount": "100", + "target_currency": "USD", + "source_currency": "EUR" +} + +# Transform to ordered array +entities_array = ["USD", "EUR", "100"] +# ↑ ↑ ↑ +# [0] [1] [2] (matches schema order) +``` + +### Transformation Logic + +```python +def _transform_entities_to_array( + entities_dict: Dict[str, str], + entity_order: List[str] +) -> List[str]: + """Transform entity dict to ordered array.""" + ordered_array = [] + + for entity_key in entity_order: + # Get value from dict, or empty string if missing + value = entities_dict.get(entity_key, "") + ordered_array.append(value) + + return ordered_array +``` + +### Example + +**Service Schema:** +```json +["target_currency", "base_currency", "amount"] +``` + +**Extracted Entities:** +```json +{ + "target_currency": "JPY", + "amount": "500" +} +``` + +**Transformed Array:** +```python +["JPY", "", "500"] +# ↑ +# Missing "base_currency" → empty string +``` + +--- + +## 7. Service Call (TODO: Step 7) + +### Endpoint Construction + +```python +endpoint_url = f"{RUUTER_BASE_URL}/services/active{service_name}" +# Example: "http://ruuter:8080/services/active/currency-conversion" +# (Note: service_name from service metadata, e.g., "/currency-conversion") +``` + +### Payload Construction (Planned) + +```python +payload = { + "input": entities_array, # ["USD", "EUR", "100"] + "authorId": request.authorId, + "chatId": request.chatId +} +``` + +### HTTP Call (Planned) + +```python +# Non-streaming +response = await httpx.post( + endpoint_url, + json=payload, + timeout=5.0 +) + +# Streaming +async with httpx.stream("POST", endpoint_url, json=payload) as stream: + async for line in stream.aiter_lines(): + yield orchestration_service.format_sse(chat_id, line) +``` + +--- + +## Complete Example Flow + +### User Query +``` +"Palju saan 1 EUR eest THBdes?" +(How much do I get for 1 EUR in THB?) +``` + +### Step-by-Step Execution + +#### 1. Service Discovery +```json +{ + "service_count": 5, + "services": [ + { + "serviceId": "currency_conversion_eur", + "name": "Currency Conversion (EUR)", + "entities": ["target_currency"], + "examples": ["How much is EUR in USD?"] + } + ] +} +``` + +#### 2. Service Selection +```python +# Few services (5 <= 10) → Use all for intent detection +services = discovery_result["services"] +``` + +#### 3. Intent Detection (LLM Call) +```json +{ + "matched_service_id": "currency_conversion_eur", + "confidence": 0.92, + "entities": { + "target_currency": "THB" + }, + "reasoning": "User wants to convert EUR to THB" +} +``` + +#### 4. Entity Extraction +```python +entities_dict = {"target_currency": "THB"} +``` + +#### 5. Entity Validation +```python +validation_result = { + "is_valid": True, + "missing_entities": [], + "extra_entities": [], + "validation_errors": [] +} +``` + +#### 6. Entity Transformation +```python +# Schema: ["target_currency"] +# Dict: {"target_currency": "THB"} +# Array: ["THB"] +entities_array = ["THB"] +``` + +#### 7. Service Call (TODO) +```python +# Planned implementation +response = await call_service( + url="http://ruuter:8080/currency/convert", + method="POST", + payload={"input": ["THB"], "chatId": "..."} +) +``` + +--- + +## Cost Tracking + +Service workflow tracks LLM costs following the RAG workflow pattern: + +```python +# Create costs dict at workflow level +costs_dict: Dict[str, Dict[str, Any]] = {} + +# Intent detection captures costs +intent_result, intent_usage = await _detect_service_intent(...) +costs_dict["intent_detection"] = intent_usage + +# Log costs after workflow completes +orchestration_service.log_costs(costs_dict) +``` + +**Cost Breakdown Logged:** +``` +LLM USAGE COSTS BREAKDOWN: + intent_detection : $0.000120 (1 calls, 450 tokens) +``` + +--- + +## Fallback Behavior + +### When Service Workflow Returns None + +```python +# Scenario 1: No service match (confidence < 0.7) +if not intent_result or intent_result.get("confidence", 0) < 0.7: + return None # Fallback to CONTEXT layer + +# Scenario 2: Service validation failed +if not validated_service: + return None # Fallback to CONTEXT layer + +# Scenario 3: No services discovered +if not services: + return None # Fallback to CONTEXT layer +``` + +### Fallback Chain Result + +``` +Query: "What is VAT?" +└─ SERVICE → No service matches "VAT information" → None + └─ CONTEXT → Not a greeting → None + └─ RAG → Found in knowledge base → Response ✓ +``` + +--- + +## Configuration Constants + +```python +# Service discovery +RUUTER_BASE_URL = "http://ruuter.public:8080" +SERVICE_DISCOVERY_TIMEOUT = 5.0 # seconds + +# Service selection thresholds +SERVICE_COUNT_THRESHOLD = 10 # Switch to semantic search if exceeded +MAX_SERVICES_FOR_LLM_CONTEXT = 20 # Max services to pass to LLM + +# Semantic search +QDRANT_COLLECTION = "services_collection" +SEMANTIC_SEARCH_TOP_K = 5 # Top 5 relevant services +SEMANTIC_SEARCH_THRESHOLD = 0.4 # Minimum similarity score +QDRANT_TIMEOUT = 2.0 # seconds + +# Intent detection +INTENT_CONFIDENCE_THRESHOLD = 0.7 # Minimum confidence to proceed +``` + +--- + +## Key Design Decisions + +### 1. **Lenient Entity Validation** +- Proceeds with partial entities +- Service validates required parameters +- Reduces false negatives + +### 2. **Ordered Entity Arrays** +- Ruuter services expect positional parameters +- Schema defines canonical order +- Missing entities → empty strings + +### 3. **Two-Stage Service Selection** +- Few services (≤10): Pass all to LLM +- Many services (>10): Semantic search first + +### 4. **LLM-Based Intent Detection** +- Intelligent service matching +- Natural language understanding +- Multilingual support (Estonian, English, Russian) + +### 5. **Cost Tracking** +- Follows RAG workflow pattern +- Tracks intent detection LLM costs +- Integrated with budget system + +--- + +## Summary + +The Tool Classifier's layer architecture enables intelligent query routing with graceful fallbacks. The Service Workflow (Layer 1) uses **LLM-based intent detection** to match user queries to external services, extract entities, validate them against service schemas, and prepare them for service invocation—all while maintaining comprehensive cost tracking and seamless integration with the broader RAG pipeline. diff --git a/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md b/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md index 9dc87c8..38ce1f5 100644 --- a/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md +++ b/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md @@ -361,9 +361,9 @@ class MyCustomWorkflow(BaseWorkflow): # Stream result token-by-token async def stream_result(): for chunk in self._split_into_tokens(result): - yield self._format_sse(request.chatId, chunk) + yield self.format_sse(request.chatId, chunk) await asyncio.sleep(0.01) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") return stream_result() ``` diff --git a/enrich.yml.backup b/enrich.yml.backup deleted file mode 100644 index 28cd5b3..0000000 --- a/enrich.yml.backup +++ /dev/null @@ -1,157 +0,0 @@ -declaration: - call: declare - version: 0.1 - description: "Enrich service data and index in Qdrant" - method: post - accepts: json - returns: json - namespace: rag-search - allowlist: - body: - - field: service_id - type: string - description: "Unique service identifier" - - field: name - type: string - description: "Service name" - - field: description - type: string - description: "Service description" - - field: examples - type: array - description: "Example queries" - - field: entities - type: array - description: "Expected entity names" - - field: ruuter_type - type: string - description: "HTTP method (GET/POST)" - - field: current_state - type: string - description: "Service state (active/inactive/draft)" - - field: is_common - type: boolean - description: "Is common service" - -validate_request: - assign: - service_id: ${incoming.body.service_id} - service_name: ${incoming.body.name} - service_description: ${incoming.body.description} - next: check_required_fields - -check_required_fields: - switch: - - condition: ${!service_id} - next: assign_missing_service_id_error - - condition: ${!service_name} - next: assign_missing_name_error - - condition: ${!service_description} - next: assign_missing_description_error - next: prepare_service_data - -assign_missing_service_id_error: - assign: - error_response: { - success: false, - error: "MISSING_SERVICE_ID", - message: "service_id is required" - } - next: return_missing_service_id - -return_missing_service_id: - status: 400 - return: ${error_response} - next: end - -assign_missing_name_error: - assign: - error_response: { - success: false, - error: "MISSING_NAME", - message: "name is required" - } - next: return_missing_name - -return_missing_name: - status: 400 - return: ${error_response} - next: end - -assign_missing_description_error: - assign: - error_response: { - success: false, - error: "MISSING_DESCRIPTION", - message: "description is required" - } - next: return_missing_description - -return_missing_description: - status: 400 - return: ${error_response} - next: end - -prepare_service_data: - assign: - service_data: { - service_id: ${service_id}, - name: ${service_name}, - description: ${service_description}, - examples: ${incoming.body.examples || []}, - entities: ${incoming.body.entities || []}, - ruuter_type: ${incoming.body.ruuter_type || 'GET'}, - current_state: ${incoming.body.current_state || 'draft'}, - is_common: ${incoming.body.is_common || false} - } - next: stringify_service_data - -stringify_service_data: - assign: - service_json: ${JSON.stringify(service_data)} - next: execute_enrichment - -execute_enrichment: - call: http.post - args: - url: "[#RAG_SEARCH_CRON_MANAGER]/execute/service_enrichment/enrich_and_index" - query: - service_id: ${service_id} - service_data: ${service_json} - result: enrichment_result - next: assign_success_response - on_error: handle_enrichment_error - -handle_enrichment_error: - log: "ERROR: Service enrichment failed - ${enrichment_result.error || 'Unknown error'}" - next: assign_error_response - -assign_success_response: - assign: - success_response: { - success: true, - service_id: ${service_id}, - message: "Service enriched and indexed successfully", - enrichment_details: ${enrichment_result.response.body} - } - next: return_success - -assign_error_response: - assign: - error_response: { - success: false, - error: "ENRICHMENT_FAILED", - message: "Failed to enrich and index service", - details: ${enrichment_result.response.body || enrichment_result.error} - } - next: return_enrichment_error - -return_success: - status: 200 - return: ${success_response} - next: end - -return_enrichment_error: - status: 500 - return: ${error_response} - next: end diff --git a/src/intent_data_enrichment/constants.py b/src/intent_data_enrichment/constants.py index fd15a6a..f1f35f3 100644 --- a/src/intent_data_enrichment/constants.py +++ b/src/intent_data_enrichment/constants.py @@ -43,4 +43,6 @@ class EnrichmentConstants: - Related concepts - Common ways users might express this intent +IMPORTANT: Generate the context in the SAME LANGUAGE as the service description above. If the description is in Estonian, respond in Estonian. If in English, respond in English. If in Russian, respond in Russian. + Answer only with the enriched context and nothing else.""" diff --git a/src/intent_data_enrichment/main_enrichment.py b/src/intent_data_enrichment/main_enrichment.py index 2aedb26..d718678 100644 --- a/src/intent_data_enrichment/main_enrichment.py +++ b/src/intent_data_enrichment/main_enrichment.py @@ -91,12 +91,35 @@ async def enrich_service(service_data: ServiceData) -> EnrichmentResult: context = await api_client.generate_context(service_data) logger.success(f"Context generated: {len(context)} characters") - # Step 2: Create embedding for the context - logger.info("Step 2: Creating embedding vector") - embedding = await api_client.create_embedding(context) + # Step 2: Combine generated context with original metadata for embedding + logger.info("Step 2: Combining context with original service metadata") + combined_text_parts = [ + f"Service Name: {service_data.name}", + f"Description: {service_data.description}", + ] + + if service_data.examples: + combined_text_parts.append( + f"Example Queries: {' | '.join(service_data.examples)}" + ) + + if service_data.entities: + combined_text_parts.append( + f"Required Entities: {', '.join(service_data.entities)}" + ) + + # Add generated context last (enriched understanding) + combined_text_parts.append(f"Enriched Context: {context}") + + combined_text = "\n".join(combined_text_parts) + logger.info(f"Combined text length: {len(combined_text)} characters") + + # Step 3: Create embedding for combined text + logger.info("Step 3: Creating embedding vector for combined text") + embedding = await api_client.create_embedding(combined_text) logger.success(f"Embedding created: {len(embedding)}-dimensional vector") - # Step 3: Prepare enriched service + # Step 4: Prepare enriched service enriched_service = EnrichedService( id=service_data.service_id, name=service_data.name, @@ -107,8 +130,8 @@ async def enrich_service(service_data: ServiceData) -> EnrichmentResult: embedding=embedding, ) - # Step 4: Store in Qdrant - logger.info("Step 3: Storing in Qdrant") + # Step 5: Store in Qdrant + logger.info("Step 5: Storing in Qdrant") qdrant = QdrantManager() try: qdrant.connect() diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 3c059f5..e2eb0c9 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -274,7 +274,7 @@ async def process_orchestration_request( ) # Log final costs and return response - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget for the LLM connection @@ -331,7 +331,7 @@ async def process_orchestration_request( } ) langfuse.flush() - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget even on error @@ -405,8 +405,8 @@ async def stream_orchestration_response( ) # Yield SSE format error + END marker - yield self._format_sse(request.chatId, validation_msg) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, validation_msg) + yield self.format_sse(request.chatId, "END") return # Stop processing # Use StreamManager for centralized tracking and guaranteed cleanup @@ -441,11 +441,11 @@ async def stream_orchestration_response( f"[{request.chatId}] [{stream_ctx.stream_id}] Input blocked by guardrails: " f"{input_check_result.reason}" ) - yield self._format_sse( + yield self.format_sse( request.chatId, INPUT_GUARDRAIL_VIOLATION_MESSAGE ) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) + yield self.format_sse(request.chatId, "END") + self.log_costs(costs_dict) stream_ctx.mark_completed() return @@ -500,7 +500,7 @@ async def stream_orchestration_response( ) # Log costs and timings - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return # Exit after successful classifier routing @@ -546,10 +546,10 @@ async def stream_orchestration_response( logger, error_id, "streaming_orchestration", request.chatId, e ) - yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) + yield self.format_sse(request.chatId, "END") - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget even on outer exception @@ -645,9 +645,9 @@ async def _stream_rag_pipeline( localized_msg = get_localized_message( OUT_OF_SCOPE_MESSAGES, detected_language ) - yield self._format_sse(request.chatId, localized_msg) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) + yield self.format_sse(request.chatId, localized_msg) + yield self.format_sse(request.chatId, "END") + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return @@ -659,9 +659,9 @@ async def _stream_rag_pipeline( localized_msg = get_localized_message( OUT_OF_SCOPE_MESSAGES, detected_language ) - yield self._format_sse(request.chatId, localized_msg) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) + yield self.format_sse(request.chatId, localized_msg) + yield self.format_sse(request.chatId, "END") + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return @@ -690,9 +690,9 @@ async def _stream_rag_pipeline( localized_msg = get_localized_message( OUT_OF_SCOPE_MESSAGES, detected_language ) - yield self._format_sse(request.chatId, localized_msg) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) + yield self.format_sse(request.chatId, localized_msg) + yield self.format_sse(request.chatId, "END") + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return @@ -755,14 +755,14 @@ async def bot_response_generator() -> AsyncIterator[str]: f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded: " f"{stream_ctx.token_count} > {StreamConfig.MAX_TOKENS_PER_STREAM}" ) - yield self._format_sse( + yield self.format_sse( request.chatId, STREAM_TOKEN_LIMIT_MESSAGE ) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return @@ -784,20 +784,20 @@ async def bot_response_generator() -> AsyncIterator[str]: logger.warning( f"[{request.chatId}] [{stream_ctx.stream_id}] Guardrails violation detected" ) - yield self._format_sse( + yield self.format_sse( request.chatId, OUTPUT_GUARDRAIL_VIOLATION_MESSAGE ) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) stream_ctx.mark_completed() return # Yield the validated chunk to client - yield self._format_sse(request.chatId, validated_chunk) + yield self.format_sse(request.chatId, validated_chunk) except GeneratorExit: stream_ctx.mark_cancelled() logger.info( @@ -816,9 +816,9 @@ async def bot_response_generator() -> AsyncIterator[str]: f"{i + 1}. [{ref.document_url}]({ref.document_url})" for i, ref in enumerate(doc_references) ) - yield self._format_sse(request.chatId, refs_text) + yield self.format_sse(request.chatId, refs_text) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") else: # No guardrails - stream directly @@ -837,14 +837,14 @@ async def bot_response_generator() -> AsyncIterator[str]: logger.error( f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded (no guardrails)" ) - yield self._format_sse( + yield self.format_sse( request.chatId, STREAM_TOKEN_LIMIT_MESSAGE ) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") stream_ctx.mark_completed() return - yield self._format_sse(request.chatId, token) + yield self.format_sse(request.chatId, token) # Send document references before END token doc_references = self._extract_document_references(relevant_chunks) @@ -853,9 +853,9 @@ async def bot_response_generator() -> AsyncIterator[str]: f"{i + 1}. [{ref.document_url}]({ref.document_url})" for i, ref in enumerate(doc_references) ) - yield self._format_sse(request.chatId, refs_text) + yield self.format_sse(request.chatId, refs_text) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") # Extract usage information after streaming completes usage_info = get_lm_usage_since(history_length_before) @@ -872,7 +872,7 @@ async def bot_response_generator() -> AsyncIterator[str]: ) # Log costs and trace - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget @@ -935,7 +935,7 @@ async def bot_response_generator() -> AsyncIterator[str]: ) usage_info = get_lm_usage_since(history_length_before) costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget even on client disconnect @@ -953,12 +953,12 @@ async def bot_response_generator() -> AsyncIterator[str]: request.chatId, stream_error, ) - yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) + yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) + self.log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) # Update budget even on streaming error @@ -966,7 +966,7 @@ async def bot_response_generator() -> AsyncIterator[str]: request.connection_id, costs_dict, request.environment ) - def _format_sse(self, chat_id: str, content: str) -> str: + def format_sse(self, chat_id: str, content: str) -> str: """ Format SSE message with exact specification. @@ -1885,7 +1885,7 @@ async def _check_output_guardrails( usage={}, ) - def _log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: + def log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: """ Log cost information for tracking. diff --git a/src/tool_classifier/classifier.py b/src/tool_classifier/classifier.py index 71a4592..c8bef8a 100644 --- a/src/tool_classifier/classifier.py +++ b/src/tool_classifier/classifier.py @@ -55,6 +55,7 @@ def __init__( # Initialize workflow executors self.service_workflow = ServiceWorkflowExecutor( llm_manager=llm_manager, + orchestration_service=orchestration_service, ) self.context_workflow = ContextWorkflowExecutor( llm_manager=llm_manager, @@ -75,10 +76,11 @@ async def classify( """ Classify a user query to determine which workflow should handle it. - Implements layer-wise classification logic: - 1. Check if SERVICE workflow can handle (intent detection) - 2. Check if CONTEXT workflow can handle (greeting/history check) - 3. Default to RAG workflow (knowledge retrieval) + Implements layer-wise classification logic with fallback chain: + 1. SERVICE workflow (external API calls) + 2. CONTEXT workflow (greetings/conversation history) + 3. RAG workflow (knowledge base retrieval) + 4. OOD workflow (out-of-domain) Args: query: User's query string @@ -87,60 +89,15 @@ async def classify( Returns: ClassificationResult indicating which workflow to use - - Note: - In this skeleton, always defaults to RAG. Full implementation - will add Layer 1 and Layer 2 logic in separate tasks. """ logger.info(f"Classifying query: {query[:100]}...") - # TODO: LAYER 1 - SERVICE WORKFLOW DETECTION - # Implementation task: Service workflow implementation - # Logic: - # 1. Count active services in database - # 2. If count > 50: Use Qdrant semantic search for top 20 services - # 3. If count <= 50: Use all services - # 4. Call LLM to detect intent and extract entities - # 5. If intent detected and service valid: return SERVICE classification - # Example: - # service_check = await self._check_service_layer(query, language) - # if service_check.can_handle: - # return ClassificationResult( - # workflow=WorkflowType.SERVICE, - # confidence=service_check.confidence, - # metadata=service_check.metadata, - # reasoning="Service intent detected" - # ) - - # TODO: LAYER 2 - CONTEXT WORKFLOW DETECTION - # Implementation task: Context workflow implementation - # Logic: - # 1. Check if query is a greeting using LLM - # 2. If greeting: return CONTEXT classification - # 3. If conversation_history exists: Check if query references history - # 4. Call LLM to determine if history contains answer - # 5. If can answer from history: return CONTEXT classification - # Example: - # context_check = await self._check_context_layer( - # query, conversation_history, language - # ) - # if context_check.can_handle: - # return ClassificationResult( - # workflow=WorkflowType.CONTEXT, - # confidence=context_check.confidence, - # metadata=context_check.metadata, - # reasoning="Greeting or answerable from history" - # ) - - # LAYER 3 - RAG WORKFLOW (DEFAULT) - # Always defaults to RAG for now - # RAG workflow will handle the query or return OOD if no chunks found - logger.info("Defaulting to RAG workflow (Layers 1-2 not implemented)") + logger.info("Starting layer-wise fallback: ") return ClassificationResult( - workflow=WorkflowType.RAG, + workflow=WorkflowType.SERVICE, confidence=1.0, metadata={}, - reasoning="Default to RAG workflow (service and context layers not implemented)", + reasoning="Start with Service workflow - will cascade through layers", ) @overload @@ -235,10 +192,7 @@ async def _execute_with_fallback_async( """ Execute workflow with fallback to subsequent layers (non-streaming). - TODO: Implement full fallback chain logic - Currently just executes the primary workflow. - - Full implementation should: + Implementation: 1. Try primary workflow 2. If returns None, try next layer in WORKFLOW_LAYER_ORDER 3. Continue until workflow returns non-None result @@ -256,19 +210,39 @@ async def _execute_with_fallback_async( logger.info(f"[{chat_id}] {workflow_name} handled successfully") return result - # TODO: Implement fallback to next layer - # For now, if workflow returns None, call RAG as fallback - logger.warning( + # Implement layer-wise fallback chain + logger.info( f"[{chat_id}] {workflow_name} returned None, " - f"falling back to RAG workflow" + f"trying next layer in fallback chain" ) - rag_result = await self.rag_workflow.execute_async(request, {}) - if rag_result is not None: - return rag_result - else: - # This should never happen since RAG always returns a result - # But handle gracefully - raise RuntimeError("RAG workflow returned None unexpectedly") + + # Get the layer order starting from current layer + from tool_classifier.enums import WORKFLOW_LAYER_ORDER + + current_index = WORKFLOW_LAYER_ORDER.index(start_layer) + remaining_layers = WORKFLOW_LAYER_ORDER[current_index + 1 :] + + # Try each subsequent layer in order + for next_layer in remaining_layers: + next_workflow = self._get_workflow_executor(next_layer) + next_name = WORKFLOW_DISPLAY_NAMES.get(next_layer, next_layer.value) + + logger.info( + f"[{chat_id}] Falling back to {next_name} " + f"(Layer {WORKFLOW_LAYER_ORDER.index(next_layer) + 1})" + ) + + result = await next_workflow.execute_async(request, {}) + + if result is not None: + logger.info(f"[{chat_id}] {next_name} handled successfully") + return result + + logger.info(f"[{chat_id}] {next_name} returned None, continuing...") + current_index += 1 + + # This should never happen since RAG/OOD should always return result + raise RuntimeError("All workflows returned None (unexpected)") except Exception as e: logger.error(f"[{chat_id}] Error executing {workflow_name}: {e}") @@ -290,10 +264,7 @@ async def _execute_with_fallback_streaming( """ Execute workflow with fallback to subsequent layers (streaming). - TODO: Implement full fallback chain logic - Currently just executes the primary workflow. - - Full implementation should: + Implementation: 1. Try primary workflow 2. If returns None, try next layer in WORKFLOW_LAYER_ORDER 3. Stream from the first workflow that returns non-None @@ -313,18 +284,42 @@ async def _execute_with_fallback_streaming( yield chunk return - # TODO: Implement fallback to next layer - # For now, if workflow returns None, call RAG as fallback - logger.warning( + # Implement layer-wise fallback chain for streaming + logger.info( f"[{chat_id}] {workflow_name} returned None, " - f"falling back to RAG workflow streaming" + f"trying next layer in fallback chain" ) - streaming_result = await self.rag_workflow.execute_streaming(request, {}) - if streaming_result is not None: - async for chunk in streaming_result: - yield chunk - else: - raise RuntimeError("RAG workflow returned None unexpectedly") + + # Get the layer order starting from current layer + from tool_classifier.enums import WORKFLOW_LAYER_ORDER + + current_index = WORKFLOW_LAYER_ORDER.index(start_layer) + remaining_layers = WORKFLOW_LAYER_ORDER[current_index + 1 :] + + # Try each subsequent layer in order + for next_layer in remaining_layers: + next_workflow = self._get_workflow_executor(next_layer) + next_name = WORKFLOW_DISPLAY_NAMES.get(next_layer, next_layer.value) + + layer_number = WORKFLOW_LAYER_ORDER.index(next_layer) + 1 + logger.info( + f"[{chat_id}] Falling back to {next_name} streaming " + f"(Layer {layer_number})" + ) + + result = await next_workflow.execute_streaming(request, {}) + + if result is not None: + logger.info(f"[{chat_id}] {next_name} streaming started") + async for chunk in result: + yield chunk + return + + logger.info(f"[{chat_id}] {next_name} returned None, continuing...") + current_index += 1 + + # This should never happen + raise RuntimeError("All workflows returned None in streaming (unexpected)") except Exception as e: logger.error(f"[{chat_id}] Error executing {workflow_name} streaming: {e}") diff --git a/src/tool_classifier/constants.py b/src/tool_classifier/constants.py new file mode 100644 index 0000000..c885b52 --- /dev/null +++ b/src/tool_classifier/constants.py @@ -0,0 +1,60 @@ +"""Constants and configuration for tool classifier module.""" + + +# ============================================================================ +# Qdrant Vector Database Configuration +# ============================================================================ + +QDRANT_HOST = "qdrant" +"""Qdrant server hostname.""" + +QDRANT_PORT = 6333 +"""Qdrant server port.""" + +QDRANT_TIMEOUT = 10.0 +"""Qdrant HTTP client timeout in seconds.""" + + +# ============================================================================ +# Semantic Search Configuration +# ============================================================================ + +QDRANT_COLLECTION = "intent_collections" +"""Qdrant collection name for service intent search.""" + +SEMANTIC_SEARCH_TOP_K = 10 +"""Number of top services to return from semantic search.""" + +SEMANTIC_SEARCH_THRESHOLD = 0.2 +"""Minimum similarity score threshold for semantic search (0.0-1.0). +Lowered from 0.4 to handle broader queries.""" + + +# ============================================================================ +# Ruuter Service Configuration +# ============================================================================ + +RUUTER_BASE_URL = "http://ruuter-private:8086" +"""Base URL for Ruuter private service endpoints.""" + +RAG_SEARCH_RUUTER_PUBLIC = "http://ruuter-public:8086/rag-search" +"""Public Ruuter endpoint for RAG search service discovery.""" + +SERVICE_CALL_TIMEOUT = 10 +"""Timeout in seconds for external service calls via Ruuter.""" + +SERVICE_DISCOVERY_TIMEOUT = 10.0 +"""Timeout in seconds for service discovery calls.""" + + +# ============================================================================ +# Service Workflow Thresholds +# ============================================================================ + +MAX_SERVICES_FOR_LLM_CONTEXT = 50 +"""Maximum number of services to send to LLM without semantic filtering. +If service count exceeds this, semantic search is used to filter to top-K.""" + +SERVICE_COUNT_THRESHOLD = 10 +"""Threshold for triggering semantic search. If service count > this value, +semantic search is used instead of sending all services to LLM.""" diff --git a/src/tool_classifier/intent_detector.py b/src/tool_classifier/intent_detector.py new file mode 100644 index 0000000..24c1538 --- /dev/null +++ b/src/tool_classifier/intent_detector.py @@ -0,0 +1,133 @@ +"""Service intent detection using DSPy.""" + +import json +from typing import Any, Dict, List, Optional + +import dspy +from loguru import logger + + +class ServiceIntentDetector(dspy.Signature): + """Detect which service matches user intent and extract entities. + + CRITICAL LANGUAGE RULE: + - Understand Estonian, Russian, and English queries + - Extract entities in their original form from the query + + Rules: + - Match user query against available services + - Extract required entity values from the query + - Return valid JSON format strictly + - If no service matches well (confidence < 0.7), return null for matched_service_id + - Be conservative - only match when confident + - Prioritize services whose examples closely match the user query + """ + + user_query: str = dspy.InputField( + desc="User's question/request in Estonian, Russian, or English" + ) + available_services: str = dspy.InputField( + desc="JSON string of available services with id, name, description, entities, examples" + ) + conversation_context: str = dspy.InputField( + desc="Recent conversation history for context (optional, may be empty)" + ) + + intent_result: str = dspy.OutputField( + desc='Valid JSON only: {"matched_service_id": "id_string" or null, "confidence": 0.0-1.0, "entities": {}, "reasoning": "brief explanation"}' + ) + + +class IntentDetectionModule(dspy.Module): + """DSPy Module for service intent detection.""" + + def __init__(self) -> None: + """Initialize intent detection module with ChainOfThought.""" + super().__init__() + self.detector = dspy.ChainOfThought(ServiceIntentDetector) + + def forward( + self, + user_query: str, + services: List[Dict[str, Any]], + conversation_history: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + """ + Detect service intent using LLM via DSPy. + + Args: + user_query: User's query + services: List of service dicts with serviceId, name, description, entities, examples + conversation_history: Recent messages (optional) + + Returns: + Parsed intent result dict with matched_service_id, confidence, entities, reasoning + """ + # Format services for prompt (keep it concise) + services_formatted = [] + for s in services: + service_entry = { + "service_id": s.get("serviceId", s.get("service_id")), + "name": s.get("name", "Unknown"), + "description": s.get("description", ""), + "required_entities": s.get("entities", []), + "examples": s.get("examples", [])[:3], # Top 3 examples + } + services_formatted.append(service_entry) + + services_json = json.dumps(services_formatted, ensure_ascii=False, indent=2) + + # Format conversation history + if conversation_history: + history_lines = [] + for msg in conversation_history[-3:]: # Last 3 turns + role = msg.get("authorRole", "unknown") + content = msg.get("message", "") + if content: + history_lines.append(f"{role}: {content}") + history_text = "\n".join(history_lines) if history_lines else "(Empty)" + else: + history_text = "(No conversation history)" + + # Call DSPy detector with ChainOfThought + result = None + try: + result = self.detector( + user_query=user_query, + available_services=services_json, + conversation_context=history_text, + ) + + # Parse JSON response + intent_data = json.loads(result.intent_result) + + # Validate structure + if not isinstance(intent_data, dict): + raise ValueError("Intent result is not a dictionary") + + # Ensure required keys exist + intent_data.setdefault("matched_service_id", None) + intent_data.setdefault("confidence", 0.0) + intent_data.setdefault("entities", {}) + intent_data.setdefault("reasoning", "") + + return intent_data + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse intent JSON: {e}") + if result: + logger.error(f"Raw response: {result.intent_result}") + return { + "matched_service_id": None, + "confidence": 0.0, + "entities": {}, + "reasoning": f"JSON parse error: {e}", + } + except Exception as e: + logger.error(f"Intent detection forward failed: {e}", exc_info=True) + return { + "matched_service_id": None, + "confidence": 0.0, + "entities": {}, + "reasoning": f"Detection error: {e}", + } diff --git a/src/tool_classifier/workflows/ood_workflow.py b/src/tool_classifier/workflows/ood_workflow.py index fed467a..cd114f7 100644 --- a/src/tool_classifier/workflows/ood_workflow.py +++ b/src/tool_classifier/workflows/ood_workflow.py @@ -104,9 +104,9 @@ async def execute_streaming( # Stream message for UX consistency (no guardrails needed - fixed message) async def stream_ood_message(): for chunk in split_into_tokens(ood_message, chunk_size=5): - yield self._format_sse(request.chatId, chunk) + yield self.format_sse(request.chatId, chunk) await asyncio.sleep(0.01) - yield self._format_sse(request.chatId, "END") + yield self.format_sse(request.chatId, "END") return stream_ood_message() ``` diff --git a/src/tool_classifier/workflows/rag_workflow.py b/src/tool_classifier/workflows/rag_workflow.py index d83080a..6c58648 100644 --- a/src/tool_classifier/workflows/rag_workflow.py +++ b/src/tool_classifier/workflows/rag_workflow.py @@ -87,7 +87,7 @@ async def execute_async( ) # Log costs and timings - self.orchestration_service._log_costs(costs_dict) + self.orchestration_service.log_costs(costs_dict) from src.utils.time_tracker import log_step_timings log_step_timings(timing_dict, request.chatId) diff --git a/src/tool_classifier/workflows/service_workflow.py b/src/tool_classifier/workflows/service_workflow.py index 8a6889b..d71e2d9 100644 --- a/src/tool_classifier/workflows/service_workflow.py +++ b/src/tool_classifier/workflows/service_workflow.py @@ -1,137 +1,796 @@ """Service workflow executor - Layer 1: External service/API calls.""" -from typing import Any, AsyncIterator, Dict, Optional +from typing import Any, AsyncIterator, Dict, List, Optional, Protocol + +import dspy +import httpx from loguru import logger -from models.request_models import OrchestrationRequest, OrchestrationResponse +from src.utils.cost_utils import get_lm_usage_since + +from models.request_models import ( + OrchestrationRequest, + OrchestrationResponse, +) from tool_classifier.base_workflow import BaseWorkflow +from tool_classifier.constants import ( + MAX_SERVICES_FOR_LLM_CONTEXT, + QDRANT_COLLECTION, + QDRANT_HOST, + QDRANT_PORT, + QDRANT_TIMEOUT, + RAG_SEARCH_RUUTER_PUBLIC, + RUUTER_BASE_URL, + SEMANTIC_SEARCH_THRESHOLD, + SEMANTIC_SEARCH_TOP_K, + SERVICE_COUNT_THRESHOLD, + SERVICE_DISCOVERY_TIMEOUT, +) +from tool_classifier.intent_detector import IntentDetectionModule -class ServiceWorkflowExecutor(BaseWorkflow): - """ - Executes external service calls via Ruuter endpoints (Layer 1). - - This workflow handles queries that require calling external government - services or APIs. It performs: - 1. Service discovery (semantic search if >50 services) - 2. Intent detection using LLM - 3. Entity extraction from query - 4. Service validation against database - 5. External API call via Ruuter - 6. Output guardrails validation - - Examples of Service queries: - - "What's the EUR to USD exchange rate?" - - "Check my document status" - - "Submit a tax declaration" - - Implementation Status: SKELETON - Returns None (triggers fallback to Context workflow) - - TODO - Full Implementation (Separate Task): - - Service discovery logic (Qdrant semantic search) - - Intent detection (LLM-based) - - Entity extraction and transformation - - Service validation (database lookup) - - Ruuter API integration - - Output guardrails for service responses - """ - - def __init__(self, llm_manager: Any): +class LLMServiceProtocol(Protocol): + """Protocol defining interface for LLM service embedding operations.""" + + def create_embeddings_for_indexer( + self, + texts: List[str], + environment: str = "production", + connection_id: Optional[str] = None, + batch_size: int = 10, + ) -> Dict[str, Any]: + """Create embeddings for text inputs using the configured embedding model. + + Args: + texts: List of text strings to embed + environment: Environment for model resolution + connection_id: Optional connection ID for service selection + batch_size: Number of texts to process in each batch + + Returns: + Dictionary containing embeddings list and metadata + """ + ... + + def format_sse(self, chat_id: str, content: str) -> str: + """Format content as SSE message. + + Args: + chat_id: Chat/channel identifier + content: Content to send (token, "END", error message, etc.) + + Returns: + SSE-formatted string: "data: {json}\\n\\n" """ - Initialize service workflow executor. + ... + + def log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: + """Log cost information for tracking. Args: - llm_manager: LLM manager for intent detection + costs_dict: Dictionary of costs per component """ + ... + + +class ServiceWorkflowExecutor(BaseWorkflow): + """Executes external service calls via Ruuter endpoints (Layer 1).""" + + def __init__( + self, + llm_manager: Any, + orchestration_service: Optional[LLMServiceProtocol] = None, + ) -> None: + """Initialize service workflow executor.""" self.llm_manager = llm_manager - logger.info("Service workflow executor initialized (skeleton)") + self.orchestration_service = orchestration_service - async def execute_async( + async def _semantic_search_services( self, + query: str, request: OrchestrationRequest, - context: Dict[str, Any], - ) -> Optional[OrchestrationResponse]: + chat_id: str, + top_k: int = SEMANTIC_SEARCH_TOP_K, + ) -> Optional[List[Dict[str, Any]]]: + """Search services using semantic search via Qdrant. + + Creates a new httpx.AsyncClient per request to ensure proper resource cleanup. + This is safe and efficient since semantic search is infrequent (only when many services exist). """ - Execute service workflow in non-streaming mode. - - TODO: Implement service workflow logic: - 1. Extract service metadata from context (service_id, intent, entities) - 2. Validate service exists and is active in database - 3. Transform entities to array format for service call - 4. Call Ruuter endpoint: POST {RUUTER_BASE_URL}/services/active{ServiceName} - 5. Validate response with output guardrails - 6. Return OrchestrationResponse with service result - - Failure scenarios: - - No service_id in context → return None (fallback to Context) - - Service not found/inactive → return None (fallback to Context) - - Service call timeout → return error response - - Output guardrails blocked → return violation response or None + if not self.orchestration_service: + logger.error( + f"[{chat_id}] Semantic search unavailable: orchestration service not provided" + ) + return None - Args: - request: Orchestration request with user query - context: Metadata with service_id, intent, entities + try: + # Generate embedding using orchestration service + embedding_result = self.orchestration_service.create_embeddings_for_indexer( + texts=[query], + environment=request.environment, + connection_id=request.connection_id, + batch_size=1, + ) + + embeddings = embedding_result.get("embeddings", []) + if not embeddings or len(embeddings) == 0: + logger.error(f"[{chat_id}] No embedding returned for query") + return None + + query_embedding = embeddings[0] + + # Create Qdrant client with proper resource cleanup via context manager + qdrant_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}" + async with httpx.AsyncClient( + base_url=qdrant_url, timeout=QDRANT_TIMEOUT + ) as client: + # Verify collection exists and has data + try: + collection_info = await client.get( + f"/collections/{QDRANT_COLLECTION}" + ) + if collection_info.status_code == 200: + info = collection_info.json() + points_count = info.get("result", {}).get("points_count", 0) + if points_count == 0: + logger.error(f"[{chat_id}] Collection is empty") + return None + except Exception as e: + logger.warning(f"[{chat_id}] Could not verify collection: {e}") + + # Search Qdrant collection + search_payload = { + "vector": query_embedding, + "limit": top_k, + "score_threshold": SEMANTIC_SEARCH_THRESHOLD, + "with_payload": True, + } + + response = await client.post( + f"/collections/{QDRANT_COLLECTION}/points/search", + json=search_payload, + ) + + if response.status_code != 200: + logger.error( + f"[{chat_id}] Qdrant search failed: HTTP {response.status_code}" + ) + return None + + search_results = response.json() + points = search_results.get("result", []) + + if len(points) == 0: + logger.warning( + f"[{chat_id}] No services matched (threshold={SEMANTIC_SEARCH_THRESHOLD})" + ) + return None + + # Transform Qdrant results to service format + services: List[Dict[str, Any]] = [] + for point in points: + payload = point.get("payload", {}) + score = float(point.get("score", 0)) + + service = { + "serviceId": payload.get("service_id"), + "service_id": payload.get("service_id"), + "name": payload.get("name"), + "description": payload.get("description"), + "examples": payload.get("examples", []), + "entities": payload.get("entities", []), + # Note: endpoint not stored in intent_collections, + # will be resolved via database lookup if needed + "similarity_score": score, + } + services.append(service) + + logger.info( + f"[{chat_id}] Found {len(services)} services via semantic search" + ) + return services + + except Exception as e: + logger.error(f"[{chat_id}] Semantic search failed: {e}", exc_info=True) + return None + + async def _call_service_discovery(self, chat_id: str) -> Optional[Dict[str, Any]]: + """Call Ruuter endpoint to get services for intent detection.""" + endpoint = f"{RAG_SEARCH_RUUTER_PUBLIC}/services/get-services" + + try: + async with httpx.AsyncClient(timeout=SERVICE_DISCOVERY_TIMEOUT) as client: + response = await client.get(endpoint) + response.raise_for_status() + data = response.json() + return data + except httpx.TimeoutException: + logger.error( + f"[{chat_id}] Service discovery timeout after {SERVICE_DISCOVERY_TIMEOUT}s" + ) + return None + except httpx.HTTPStatusError as e: + logger.error( + f"[{chat_id}] Service discovery HTTP error: {e.response.status_code}" + ) + return None + except Exception as e: + logger.error(f"[{chat_id}] Service discovery failed: {e}", exc_info=True) + return None + + async def _detect_service_intent( + self, + user_query: str, + services: List[Dict[str, Any]], + conversation_history: List[Any], + chat_id: str, + ) -> tuple[Optional[Dict[str, Any]], Dict[str, Any]]: + """Use DSPy + LLMManager to detect service intent and extract entities. Returns: - OrchestrationResponse with service result or None to fallback + Tuple of (intent_result, usage_info): + - intent_result: Intent detection result dict (or None on error) + - usage_info: Cost and token usage information """ - logger.debug( - f"[{request.chatId}] Service workflow execute_async called " - f"(not implemented - returning None)" - ) + try: + # Ensure DSPy is configured with LLMManager + if self.llm_manager: + self.llm_manager.ensure_global_config() + else: + logger.error(f"[{chat_id}] LLM Manager not available") + return None, {} - # TODO: Implement service workflow logic here - # For now, return None to trigger fallback to next layer + # Capture history length before LLM call for cost tracking + lm = dspy.settings.lm + history_length_before = ( + len(lm.history) if lm and hasattr(lm, "history") else 0 + ) + + # Create DSPy module + intent_module = IntentDetectionModule() + + # Convert conversation history to dict format + history_dicts = [ + {"authorRole": msg.authorRole, "message": msg.message} + for msg in conversation_history + if hasattr(msg, "authorRole") and hasattr(msg, "message") + ] + + # Call DSPy forward with task-local config + with self.llm_manager.use_task_local(): + intent_result = intent_module.forward( + user_query=user_query, + services=services, + conversation_history=history_dicts, + ) + + # Extract usage information after LLM call + usage_info = get_lm_usage_since(history_length_before) + + return intent_result, usage_info + + except Exception as e: + logger.error(f"[{chat_id}] Intent detection failed: {e}", exc_info=True) + return None, {} + + def _validate_detected_service( + self, + matched_service_id: str, + services: List[Dict[str, Any]], + chat_id: str, + ) -> Optional[Dict[str, Any]]: + """Validate that detected service exists in active services list.""" + for service in services: + service_id = service.get("serviceId", service.get("service_id")) + if service_id == matched_service_id: + return service + + logger.warning( + f"[{chat_id}] Service validation failed: '{matched_service_id}' not found" + ) return None - async def execute_streaming( + async def _process_intent_detection( self, + services: List[Dict[str, Any]], request: OrchestrationRequest, + chat_id: str, context: Dict[str, Any], - ) -> Optional[AsyncIterator[str]]: + costs_dict: Dict[str, Dict[str, Any]], + ) -> None: + """Detect intent, validate service, and populate context. + + This helper method encapsulates the common logic of: + 1. Calling intent detection (LLM) + 2. Tracking costs + 3. Validating matched service + 4. Populating context with service metadata + + Args: + services: List of services to match against + request: Orchestration request + chat_id: Chat ID for logging + context: Context dict to populate with results + costs_dict: Dictionary to track LLM costs """ - Execute service workflow in streaming mode. - - TODO: Implement service workflow streaming: - 1. Execute service call (same as non-streaming) - 2. Get complete service response - 3. Validate with output guardrails (validation-first) - 4. If blocked: yield violation message + END - 5. If allowed: chunk response and stream token-by-token - 6. Simulate streaming for consistent UX with RAG - - Streaming approach (validation-first): - ```python - # Get complete response - service_response = await call_service(...) - - # Validate BEFORE streaming - is_safe = await guardrails.check_output_async(service_response) - if not is_safe: - yield format_sse(chatId, VIOLATION_MESSAGE) - yield format_sse(chatId, "END") - return - - # Stream validated response - for chunk in split_into_tokens(service_response, chunk_size=5): - yield format_sse(chatId, chunk) - await asyncio.sleep(0.01) - yield format_sse(chatId, "END") - ``` + intent_result, intent_usage = await self._detect_service_intent( + user_query=request.message, + services=services, + conversation_history=request.conversationHistory, + chat_id=chat_id, + ) + costs_dict["intent_detection"] = intent_usage + + if intent_result and intent_result.get("matched_service_id"): + service_id = intent_result["matched_service_id"] + logger.info(f"[{chat_id}] Matched: {service_id}") + + validated_service = self._validate_detected_service( + matched_service_id=service_id, + services=services, + chat_id=chat_id, + ) + + if validated_service: + context["service_id"] = service_id + context["confidence"] = intent_result.get("confidence", 0.0) + context["entities"] = intent_result.get("entities", {}) + context["service_data"] = validated_service + + def _extract_service_metadata( + self, context: Dict[str, Any], chat_id: str + ) -> Optional[Dict[str, Any]]: + """Extract service and entity metadata from context.""" + # Check if service_id exists + service_id = context.get("service_id") + if not service_id: + logger.error(f"[{chat_id}] Missing service_id in context") + return None + + # Check if service_data exists + service_data = context.get("service_data") + if not service_data: + logger.error(f"[{chat_id}] Missing service_data in context") + return None + + # Extract entities dict from context (LLM extracted) + entities_dict = context.get("entities", {}) + + # Extract entity schema from service_data (expected order) + entity_schema = service_data.get("entities", []) + if entity_schema is None: + entity_schema = [] + + # Extract service name + service_name = service_data.get("name", service_id) + + # Extract HTTP method (ruuter_type) - defaults to GET if not specified + ruuter_type = service_data.get("ruuter_type", "GET") + + return { + "service_id": service_id, + "service_name": service_name, + "entities_dict": entities_dict, + "entity_schema": entity_schema, + "ruuter_type": ruuter_type, + "service_data": service_data, + } + + def _validate_entities( + self, + extracted_entities: Dict[str, str], + service_schema: List[str], + service_name: str, + chat_id: str, + ) -> Dict[str, Any]: + """ + Validate extracted entities against service schema. Args: - request: Orchestration request with user query - context: Metadata with service_id, intent, entities + extracted_entities: Entity key-value pairs from LLM + service_schema: Expected entity keys from database + service_name: Service name for logging + chat_id: For logging Returns: - AsyncIterator yielding SSE strings or None to fallback + Dict with validation results: + - is_valid: Overall validation status + - missing_entities: List of schema entities not extracted + - extra_entities: List of extracted entities not in schema + - validation_errors: List of error messages """ - logger.debug( - f"[{request.chatId}] Service workflow execute_streaming called " - f"(not implemented - returning None)" + missing_entities = [] + extra_entities = [] + validation_errors = [] + + # Check for missing entities (in schema but not extracted) + for schema_key in service_schema: + if schema_key not in extracted_entities: + missing_entities.append(schema_key) + elif extracted_entities[schema_key] == "": + # Entity extracted but value is empty + validation_errors.append(f"Entity '{schema_key}' has empty value") + + # Check for extra entities (extracted but not in schema) + for entity_key in extracted_entities: + if entity_key not in service_schema: + extra_entities.append(entity_key) + + # Determine overall validity + # We consider it valid even with missing entities (will send empty strings) + # Let the external service validate required parameters + is_valid = True # Always true - we proceed with partial entities + + return { + "is_valid": is_valid, + "missing_entities": missing_entities, + "extra_entities": extra_entities, + "validation_errors": validation_errors, + } + + def _transform_entities_to_array( + self, entities_dict: Dict[str, str], entity_order: List[str] + ) -> List[str]: + """Transform entity dictionary to ordered array based on service schema.""" + if not entity_order: + return [] + + # Transform to ordered array, filling missing with empty strings + return [entities_dict.get(key, "") for key in entity_order] + + def _construct_service_endpoint(self, service_name: str, chat_id: str) -> str: + """Construct the full service endpoint URL for Ruuter.""" + return f"{RUUTER_BASE_URL}/services/active{service_name}" + + def _format_debug_response( + self, + service_name: str, + endpoint_url: str, + http_method: str, + entities_array: List[str], + ) -> str: + """Format debug information for testing (temporary before Step 7 implementation).""" + entities_str = ", ".join(f'"{e}"' for e in entities_array) + return ( + f" Service Validated: {service_name}\n" + f" Endpoint URL: {endpoint_url}\n" + f" HTTP Method: {http_method}\n" + f" Extracted Entities: [{entities_str}]\n\n" ) - # TODO: Implement service streaming logic here - # For now, return None to trigger fallback to next layer - return None + async def _log_request_details( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + mode: str, + costs_dict: Dict[str, Dict[str, Any]], + ) -> None: + """Log request details and perform service discovery. + + Args: + request: The orchestration request + context: Workflow context dictionary + mode: Execution mode ("streaming" or "non-streaming") + costs_dict: Dictionary to accumulate cost tracking information + """ + chat_id = request.chatId + logger.info(f"[{chat_id}] SERVICE WORKFLOW ({mode}): {request.message}") + + # Service Discovery + discovery_result = await self._call_service_discovery(chat_id) + + if discovery_result: + # Extract data from nested response structure + response_data = discovery_result.get("response", {}) + use_semantic = response_data.get("use_semantic_search", False) + service_count = response_data.get("service_count", 0) + + # Handle service_count if it's a string or NaN + if isinstance(service_count, str): + try: + service_count = int(service_count) + except (ValueError, TypeError): + service_count = 0 + + services_from_ruuter = response_data.get("services", []) + + # Use semantic search if count > threshold + if service_count > SERVICE_COUNT_THRESHOLD: + use_semantic = True + + if use_semantic: + # Use semantic search to find relevant services + services = await self._semantic_search_services( + query=request.message, + request=request, + chat_id=chat_id, + top_k=SEMANTIC_SEARCH_TOP_K, + ) + + if not services: + logger.warning(f"[{chat_id}] Semantic search failed") + + if services_from_ruuter: + services = services_from_ruuter + elif service_count <= MAX_SERVICES_FOR_LLM_CONTEXT: + fallback_result = await self._call_service_discovery(chat_id) + if fallback_result: + fallback_data = fallback_result.get("response", {}) + services = fallback_data.get("services", []) + else: + services = [] + else: + logger.error(f"[{chat_id}] Too many services ({service_count})") + services = [] + + if services: + await self._process_intent_detection( + services=services, + request=request, + chat_id=chat_id, + context=context, + costs_dict=costs_dict, + ) + else: + services = response_data.get("services", []) + + if services: + await self._process_intent_detection( + services=services, + request=request, + chat_id=chat_id, + context=context, + costs_dict=costs_dict, + ) + else: + logger.warning(f"[{chat_id}] Service discovery failed") + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """Execute service workflow in non-streaming mode.""" + chat_id = request.chatId + + # Create costs tracking dictionary (follows RAG workflow pattern) + costs_dict: Dict[str, Dict[str, Any]] = {} + + # Log comprehensive request details and perform service discovery + await self._log_request_details( + request, context, mode="non-streaming", costs_dict=costs_dict + ) + + # Check if service was detected and validated + if not context.get("service_id"): + logger.info( + f"[{chat_id}] No service detected or validated - " + f"returning None to fallback to next layer" + ) + return None + + # Entity Transformation & Validation + logger.info(f"[{chat_id}] Entity Transformation:") + + # Step 1: Extract service metadata from context + service_metadata = self._extract_service_metadata(context, chat_id) + if not service_metadata: + logger.error( + f"[{chat_id}] - Metadata extraction failed - " + f"returning None to fallback" + ) + return None + + logger.info(f"[{chat_id}] - Service: {service_metadata['service_name']}") + logger.info( + f"[{chat_id}] - Schema entities: {service_metadata['entity_schema']}" + ) + logger.info( + f"[{chat_id}] - Extracted entities: {service_metadata['entities_dict']}" + ) + + # Step 2: Validate entities against schema + validation_result = self._validate_entities( + extracted_entities=service_metadata["entities_dict"], + service_schema=service_metadata["entity_schema"], + service_name=service_metadata["service_name"], + chat_id=chat_id, + ) + + logger.info( + f"[{chat_id}] - Validation status: " + f"{'PASSED ✓' if validation_result['is_valid'] else 'FAILED ✗'}" + ) + + if validation_result["missing_entities"]: + logger.warning( + f"[{chat_id}] - Missing entities (will send empty strings): " + f"{validation_result['missing_entities']}" + ) + + if validation_result["extra_entities"]: + logger.info( + f"[{chat_id}] - Extra entities (ignored): " + f"{validation_result['extra_entities']}" + ) + + if validation_result["validation_errors"]: + for error in validation_result["validation_errors"]: + logger.warning(f"[{chat_id}] - Validation warning: {error}") + + # Step 3: Transform entities dict to ordered array + entities_array = self._transform_entities_to_array( + entities_dict=service_metadata["entities_dict"], + entity_order=service_metadata["entity_schema"], + ) + + context["entities_array"] = entities_array + context["validation_result"] = validation_result + + # Construct service endpoint URL + endpoint_url = self._construct_service_endpoint( + service_name=service_metadata["service_name"], chat_id=chat_id + ) + + context["endpoint_url"] = endpoint_url + context["http_method"] = service_metadata["ruuter_type"] + + logger.info(f"[{chat_id}] Service prepared: {endpoint_url}") + + # TODO: STEP 7 - Call Ruuter service endpoint and return response + # 1. Build payload: {"input": entities_array, "authorId": request.authorId, "chatId": request.chatId} + # 2. Call endpoint using http_method (POST/GET) with SERVICE_CALL_TIMEOUT + # 3. Parse Ruuter response and extract result + # 4. Return OrchestrationResponse with actual service result + # 5. Handle errors (timeout, HTTP errors, malformed JSON) + + # STEP 6: Return debug response (temporary until Step 7 - Ruuter call implemented) + # REMOVE THIS BLOCK AFTER STEP 7 IMPLEMENTATION (START) + debug_content = self._format_debug_response( + service_name=service_metadata["service_name"], + endpoint_url=endpoint_url, + http_method=service_metadata["ruuter_type"], + entities_array=entities_array, + ) + + logger.info(f"[{chat_id}] Returning debug response (Step 7 pending)") + + # Log costs after service workflow completes (follows RAG workflow pattern) + if self.orchestration_service: + self.orchestration_service.log_costs(costs_dict) + + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=debug_content, + ) + # REMOVE THIS BLOCK AFTER STEP 7 IMPLEMENTATION (END) + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """Execute service workflow in streaming mode.""" + chat_id = request.chatId + + # Create costs tracking dictionary (follows RAG workflow pattern) + costs_dict: Dict[str, Dict[str, Any]] = {} + + # Log comprehensive request details and perform service discovery + await self._log_request_details( + request, context, mode="streaming", costs_dict=costs_dict + ) + + # Check if service was detected and validated + if not context.get("service_id"): + logger.info( + f"[{chat_id}] No service detected or validated - " + f"returning None to fallback to next layer" + ) + return None + + # Entity Transformation & Validation + logger.info(f"[{chat_id}] Entity Transformation:") + + # Step 1: Extract service metadata from context + service_metadata = self._extract_service_metadata(context, chat_id) + if not service_metadata: + logger.error( + f"[{chat_id}] - Metadata extraction failed - " + f"returning None to fallback" + ) + return None + + logger.info(f"[{chat_id}] - Service: {service_metadata['service_name']}") + logger.info( + f"[{chat_id}] - Schema entities: {service_metadata['entity_schema']}" + ) + logger.info( + f"[{chat_id}] - Extracted entities: {service_metadata['entities_dict']}" + ) + + # Step 2: Validate entities against schema + validation_result = self._validate_entities( + extracted_entities=service_metadata["entities_dict"], + service_schema=service_metadata["entity_schema"], + service_name=service_metadata["service_name"], + chat_id=chat_id, + ) + + logger.info( + f"[{chat_id}] - Validation status: " + f"{'PASSED ✓' if validation_result['is_valid'] else 'FAILED ✗'}" + ) + + if validation_result["missing_entities"]: + logger.warning( + f"[{chat_id}] - Missing entities (will send empty strings): " + f"{validation_result['missing_entities']}" + ) + + if validation_result["extra_entities"]: + logger.info( + f"[{chat_id}] - Extra entities (ignored): " + f"{validation_result['extra_entities']}" + ) + + if validation_result["validation_errors"]: + for error in validation_result["validation_errors"]: + logger.warning(f"[{chat_id}] - Validation warning: {error}") + + # Step 3: Transform entities dict to ordered array + entities_array = self._transform_entities_to_array( + entities_dict=service_metadata["entities_dict"], + entity_order=service_metadata["entity_schema"], + ) + + context["entities_array"] = entities_array + context["validation_result"] = validation_result + + # Construct service endpoint URL + endpoint_url = self._construct_service_endpoint( + service_name=service_metadata["service_name"], chat_id=chat_id + ) + + context["endpoint_url"] = endpoint_url + context["http_method"] = service_metadata["ruuter_type"] + + logger.info(f"[{chat_id}] Service prepared: {endpoint_url}") + + # TODO: STEP 7 - Call Ruuter service endpoint and stream response + # 1. Build payload: {"input": entities_array, "authorId": request.authorId, "chatId": request.chatId} + # 2. Call endpoint using http_method (POST/GET) with SERVICE_CALL_TIMEOUT + # 3. Parse Ruuter response and extract result + # 4. Format result as SSE and yield chunks + # 5. Handle errors (timeout, HTTP errors, malformed JSON) + + # STEP 6: Return debug response as async iterator (temporary until Step 7) + # REMOVE THIS BLOCK AFTER STEP 7 IMPLEMENTATION (START) + debug_content = self._format_debug_response( + service_name=service_metadata["service_name"], + endpoint_url=endpoint_url, + http_method=service_metadata["ruuter_type"], + entities_array=entities_array, + ) + + logger.info(f"[{chat_id}] Streaming debug response (Step 7 pending)") + + if self.orchestration_service is None: + raise RuntimeError("Orchestration service not initialized for streaming") + + # Store reference for closure (helps type checker) + orchestration_service = self.orchestration_service + + async def debug_stream() -> AsyncIterator[str]: + yield orchestration_service.format_sse(chat_id, debug_content) + yield orchestration_service.format_sse(chat_id, "END") + + # Log costs after streaming completes (follows RAG workflow pattern) + # Must be inside generator because costs are accumulated during streaming + orchestration_service.log_costs(costs_dict) + + return debug_stream() + # REMOVE THIS BLOCK AFTER STEP 7 IMPLEMENTATION (END)