From 078122cc1bd185e93627fd7aacb52571e75dce78 Mon Sep 17 00:00:00 2001 From: erangi-ar <111747955+erangi-ar@users.noreply.github.com> Date: Fri, 20 Feb 2026 10:29:00 +0530 Subject: [PATCH 1/2] Streaming response formatting (#280) * refactor: update SSE connection URL to use environment variable * foramt markdown of the llm response * feat: add markdown support to MessageContent component * title fix * prompt coniguration backend to be testing * custom prompt configuration update and fixed Pyright issues * fixed copilot reviews * pre validation step added when user query is inserted * added more validation cases * fixed review comments * resolved pr comments --------- Co-authored-by: erangi-ar Co-authored-by: nuwangeek Co-authored-by: Charith Nuwan Bimsara <59943919+nuwangeek@users.noreply.github.com> Co-authored-by: Thiru Dinesh Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Thiru Dinesh <56014038+Thirunayan22@users.noreply.github.com> --- GUI/.env.development | 4 +- .../MessageContent/MessageContent.scss | 127 ++++++--- GUI/src/components/MessageContent/index.tsx | 108 ++------ GUI/src/hooks/useStreamingResponse.tsx | 17 +- .../TestProductionLLM/TestProductionLLM.scss | 28 ++ GUI/src/pages/TestProductionLLM/index.tsx | 260 ++++++++++-------- GUI/translations/en/common.json | 15 + GUI/translations/et/common.json | 15 + 8 files changed, 337 insertions(+), 237 deletions(-) diff --git a/GUI/.env.development b/GUI/.env.development index 39f5e47a..ae5b1356 100644 --- a/GUI/.env.development +++ b/GUI/.env.development @@ -2,6 +2,6 @@ REACT_APP_RUUTER_API_URL=http://localhost:8086 REACT_APP_RUUTER_PRIVATE_API_URL=http://localhost:8088 REACT_APP_CUSTOMER_SERVICE_LOGIN=http://localhost:3004/et/dev-auth REACT_APP_SERVICE_ID=conversations,settings,monitoring -REACT_APP_NOTIFICATION_NODE_URL=http://localhost:3005 -REACT_APP_CSP=upgrade-insecure-requests; default-src 'self'; font-src 'self' data:; img-src 'self' data:; script-src 'self' 
'unsafe-eval' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; object-src 'none'; connect-src 'self' http://localhost:8086 http://localhost:8088 http://localhost:3004 http://localhost:3005 ws://localhost; +REACT_APP_NOTIFICATION_NODE_URL=http://localhost:4040 +REACT_APP_CSP=upgrade-insecure-requests; default-src 'self'; font-src 'self' data:; img-src 'self' data:; script-src 'self' 'unsafe-eval' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; object-src 'none'; connect-src 'self' http://localhost:8086 http://localhost:8088 http://localhost:3004 http://localhost:4040 ws://localhost; REACT_APP_ENABLE_HIDDEN_FEATURES=TRUE \ No newline at end of file diff --git a/GUI/src/components/MessageContent/MessageContent.scss b/GUI/src/components/MessageContent/MessageContent.scss index 7b4eea5c..aec56409 100644 --- a/GUI/src/components/MessageContent/MessageContent.scss +++ b/GUI/src/components/MessageContent/MessageContent.scss @@ -1,61 +1,112 @@ .message-content-wrapper { width: 100%; + line-height: 1.6; - .message-text { - margin-bottom: 12px; - line-height: 1.6; + // Markdown text styling + p { + margin: 0 0 12px 0; white-space: pre-wrap; word-wrap: break-word; + + &:last-child { + margin-bottom: 0; + } + } + + // Bold text + .markdown-bold, + strong { + font-weight: 600; + } + + // Ordered lists (for references) + .markdown-list, + ol { + margin: 16px 0 0 0; + padding-left: 20px; + list-style-type: decimal; } - .message-references { - margin-top: 16px; - padding-top: 12px; - border-top: 1px solid rgba(0, 0, 0, 0.1); + // List items + .markdown-list-item, + li { + margin-bottom: 6px; + line-height: 1.5; - .references-title { - display: block; - font-weight: 600; - margin-bottom: 8px; - font-size: 14px; + &:last-child { + margin-bottom: 0; } + } - .references-list { - margin: 0; - padding-left: 20px; - list-style-type: decimal; + // Links + a { + color: #0066cc; + text-decoration: none; + word-break: break-all; + transition: color 0.2s ease; - li { - margin-bottom: 6px; 
- line-height: 1.5; + &:hover { + color: #0052a3; + text-decoration: underline; + } - &:last-child { - margin-bottom: 0; - } - } + &:visited { + color: #551a8b; + } + } - .reference-link { - color: #0066cc; - text-decoration: none; - word-break: break-all; - transition: color 0.2s ease; + // Inline code + code { + background-color: rgba(0, 0, 0, 0.05); + padding: 2px 6px; + border-radius: 3px; + font-family: monospace; + font-size: 0.9em; + } - &:hover { - color: #0052a3; - text-decoration: underline; - } + // Code blocks + pre { + background-color: rgba(0, 0, 0, 0.05); + padding: 12px; + border-radius: 6px; + overflow-x: auto; + margin: 12px 0; - &:visited { - color: #551a8b; - } - } + code { + background-color: transparent; + padding: 0; } } + + // Headings + h1, h2, h3, h4, h5, h6 { + margin: 16px 0 8px 0; + font-weight: 600; + } + + // Blockquotes + blockquote { + border-left: 4px solid rgba(0, 0, 0, 0.2); + padding-left: 12px; + margin: 12px 0; + color: rgba(0, 0, 0, 0.7); + } } // Dark mode support .test-production-llm__message--bot { - .message-references { - border-top-color: rgba(255, 255, 255, 0.1); + .message-content-wrapper { + code { + background-color: rgba(255, 255, 255, 0.1); + } + + pre { + background-color: rgba(255, 255, 255, 0.1); + } + + blockquote { + border-left-color: rgba(255, 255, 255, 0.2); + color: rgba(255, 255, 255, 0.7); + } } } diff --git a/GUI/src/components/MessageContent/index.tsx b/GUI/src/components/MessageContent/index.tsx index 63ff7f2a..69b7ffe3 100644 --- a/GUI/src/components/MessageContent/index.tsx +++ b/GUI/src/components/MessageContent/index.tsx @@ -1,4 +1,6 @@ import { FC } from 'react'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; import './MessageContent.scss'; interface MessageContentProps { @@ -6,85 +8,33 @@ interface MessageContentProps { } const MessageContent: FC = ({ content }) => { - // Function to parse and render message content with proper formatting - const renderContent 
= () => { - // Split by **References:** pattern - const referencesMatch = content.match(/\*\*References:\*\*([\s\S]*)/); - - if (!referencesMatch) { - // No references, return plain content with line breaks - return ( -
- {content.split('\n').map((line, index) => ( - - {line} - {index < content.split('\n').length - 1 &&
} -
- ))} -
- ); - } - - // Split content into main text and references - const mainText = content.substring(0, referencesMatch.index); - const referencesText = referencesMatch[1].trim(); - - // Parse numbered references with URLs - const referenceLines = referencesText - .split('\n') - .filter(line => line.trim()) - .map(line => { - // Match pattern: "1. https://url" or "1. url" - const match = line.match(/^(\d+)\.\s+(https?:\/\/[^\s]+)/); - if (match) { - return { - number: match[1], - url: match[2], - }; - } - return null; - }) - .filter(Boolean); - - return ( -
- {/* Main text */} - {mainText && ( -
- {mainText.split('\n').map((line, index) => ( - - {line} - {index < mainText.split('\n').length - 1 &&
} -
- ))} -
- )} - - {/* References section */} - {referenceLines.length > 0 && ( -
- References: -
    - {referenceLines.map((ref, index) => ( -
  1. - - {ref!.url} - -
  2. - ))} -
-
- )} -
- ); - }; - - return <>{renderContent()}; + return ( +
+ ( + + ), + // Style strong/bold text + strong: ({ node, ...props }) => ( + + ), + // Style ordered lists + ol: ({ node, ...props }) => ( +
    + ), + // Style list items + li: ({ node, ...props }) => ( +
  1. + ), + }} + > + {content} + +
+ ); }; export default MessageContent; diff --git a/GUI/src/hooks/useStreamingResponse.tsx b/GUI/src/hooks/useStreamingResponse.tsx index 211d44f5..8a9d7792 100644 --- a/GUI/src/hooks/useStreamingResponse.tsx +++ b/GUI/src/hooks/useStreamingResponse.tsx @@ -1,6 +1,19 @@ import { useState, useRef, useCallback, useEffect } from 'react'; import axios from 'axios'; +const getNotificationNodeUrl = (): string => { + const value = import.meta.env.REACT_APP_NOTIFICATION_NODE_URL; + if (!value) { + throw new Error( + 'Environment variable REACT_APP_NOTIFICATION_NODE_URL is not defined. ' + + 'Please set it to the base URL of the notification service to enable streaming responses.' + ); + } + return value; +}; +const notificationNodeUrl = getNotificationNodeUrl(); +console.log(notificationNodeUrl); + interface StreamingOptions { authorId: string; conversationHistory: Array<{ authorRole: string; message: string; timestamp: string }>; @@ -50,7 +63,7 @@ export const useStreamingResponse = (channelId: string): UseStreamingResponseRet try { // Step 1: Open SSE connection FIRST - const sseUrl = `https://est-rag-rtc.rootcode.software/notifications-server/sse/stream/${channelId}`; + const sseUrl = `${notificationNodeUrl}/sse/stream/${channelId}`; console.log('[SSE] Connecting to:', sseUrl); const eventSource = new EventSource(sseUrl); @@ -102,7 +115,7 @@ export const useStreamingResponse = (channelId: string): UseStreamingResponseRet await new Promise(resolve => setTimeout(resolve, 500)); // Step 3: POST to trigger streaming - const postUrl = `https://est-rag-rtc.rootcode.software/notifications-server/channels/${channelId}/orchestrate/stream`; + const postUrl = `${notificationNodeUrl}/channels/${channelId}/orchestrate/stream`; console.log('[API] Triggering stream:', postUrl); await axios.post(postUrl, { diff --git a/GUI/src/pages/TestProductionLLM/TestProductionLLM.scss b/GUI/src/pages/TestProductionLLM/TestProductionLLM.scss index 1bd8e0f1..df51e327 100644 --- 
a/GUI/src/pages/TestProductionLLM/TestProductionLLM.scss +++ b/GUI/src/pages/TestProductionLLM/TestProductionLLM.scss @@ -77,6 +77,34 @@ border-radius: 18px 18px 18px 4px; } } + + &--error { + .test-production-llm__message-content { + border-color: #f44336; + background-color: #ffebee; + } + } + } + + &__message-error { + display: flex; + align-items: flex-start; + gap: 0.5rem; + margin-top: 0.5rem; + padding-top: 0.5rem; + border-top: 1px solid #ffcdd2; + font-size: 0.85rem; + color: #c62828; + } + + &__message-error-icon { + flex-shrink: 0; + font-size: 1rem; + } + + &__message-error-text { + flex: 1; + line-height: 1.3; } &__message-content { diff --git a/GUI/src/pages/TestProductionLLM/index.tsx b/GUI/src/pages/TestProductionLLM/index.tsx index a9c14935..d978ba16 100644 --- a/GUI/src/pages/TestProductionLLM/index.tsx +++ b/GUI/src/pages/TestProductionLLM/index.tsx @@ -1,153 +1,169 @@ -import { FC, useState, useRef, useEffect } from 'react'; +import { FC, useState, useRef, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { Button, FormTextarea, Section } from 'components'; -import { productionInference, ProductionInferenceRequest } from 'services/inference'; +import { Button, FormTextarea } from 'components'; import { useToast } from 'hooks/useToast'; +import { useStreamingResponse } from 'hooks/useStreamingResponse'; import './TestProductionLLM.scss'; - +import MessageContent from 'components/MessageContent'; interface Message { id: string; content: string; isUser: boolean; timestamp: string; + hasError?: boolean; + errorMessage?: string; } const TestProductionLLM: FC = () => { const { t } = useTranslation(); const toast = useToast(); - const [message, setMessage] = useState(''); + const [inputMessage, setInputMessage] = useState(''); const [messages, setMessages] = useState([]); const [isLoading, setIsLoading] = useState(false); const messagesEndRef = useRef(null); - const scrollToBottom = () => { - 
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); - }; + // Generate a unique channel ID for this session + const channelId = useMemo(() => `channel-${Math.random().toString(36).substring(2, 15)}`, []); + const { startStreaming, stopStreaming, isStreaming } = useStreamingResponse(channelId); + // Auto-scroll to bottom useEffect(() => { - scrollToBottom(); + messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]); + // Cleanup incomplete messages on unmount if streaming is active + useEffect(() => { + return () => { + if (isStreaming) { + stopStreaming(); + // Remove incomplete bot messages on unmount + setMessages(prev => prev.filter(msg => msg.isUser || !msg.content.trim() === false)); + } + }; + }, [isStreaming, stopStreaming]); + const handleSendMessage = async () => { - if (!message.trim()) { + if (!inputMessage.trim()) { toast.open({ type: 'warning', - title: t('warningTitle'), - message: t('emptyMessageWarning'), + title: t('testProductionLLM.warningTitle'), + message: t('testProductionLLM.emptyMessageWarning'), }); return; } + const userMessageText = inputMessage.trim(); + + // Add user message const userMessage: Message = { id: `user-${Date.now()}`, - content: message.trim(), + content: userMessageText, isUser: true, timestamp: new Date().toISOString(), }; - // Add user message to chat setMessages(prev => [...prev, userMessage]); - setMessage(''); + setInputMessage(''); setIsLoading(true); - try { - // Hardcoded values as requested - const request: ProductionInferenceRequest = { - chatId: 'test-chat-001', - message: userMessage.content, - authorId: 'test-author-001', - conversationHistory: messages.map(msg => ({ - authorRole: msg.isUser ? 
'user' : 'bot', - message: msg.content, - timestamp: msg.timestamp, - })), - url: 'https://test-url.example.com', - }; - - let response; - let attemptCount = 0; - const maxAttempts = 2; - - // Retry logic - while (attemptCount < maxAttempts) { - try { - attemptCount++; - console.log(`Production Inference Attempt ${attemptCount}/${maxAttempts}`); - response = await productionInference(request); - - // If we get a successful response, break out of retry loop - if (!response.status || response.status < 400) { - break; - } - - // If first attempt failed with error status, retry once more - if (attemptCount < maxAttempts && response.status >= 400) { - console.log('Retrying due to error status...'); - continue; - } - } catch (err) { - // If first attempt threw an error, retry once more - if (attemptCount < maxAttempts) { - console.log('Retrying due to exception...'); - continue; - } - throw err; // Re-throw on final attempt - } - } + // Create bot message ID + const botMessageId = `bot-${Date.now()}`; - console.log('Production Inference Response:', response); + // Prepare conversation history (exclude the current user message) + const conversationHistory = messages.map(msg => ({ + authorRole: msg.isUser ? 
'user' : 'bot', + message: msg.content, + timestamp: msg.timestamp, + })); - // Create bot response message - let botContent = ''; - let botMessageType: 'success' | 'error' = 'success'; + const streamingOptions = { + authorId: 'test-user-456', + conversationHistory, + url: 'opensearch-dashboard-test', + }; - if (response.status && response.status >= 400) { - // Error response - botContent = response.content || 'An error occurred while processing your request.'; - botMessageType = 'error'; - } else { - // Success response - botContent = response?.response?.content || 'Response received successfully.'; + // Callbacks for streaming + const onToken = (token: string) => { + console.log('[Component] Received token:', token); + + setMessages(prev => { + // Find the bot message + const botMsgIndex = prev.findIndex(msg => msg.id === botMessageId); - if (response.questionOutOfLlmScope) { - botContent += ' (Note: This question appears to be outside the LLM scope)'; + if (botMsgIndex === -1) { + // First token - add the bot message + console.log('[Component] Adding bot message with first token'); + return [ + ...prev, + { + id: botMessageId, + content: token, + isUser: false, + timestamp: new Date().toISOString(), + } + ]; + } else { + // Append token to existing message + console.log('[Component] Appending token to existing message'); + const updated = [...prev]; + updated[botMsgIndex] = { + ...updated[botMsgIndex], + content: updated[botMsgIndex].content + token, + }; + return updated; } - } - - const botMessage: Message = { - id: `bot-${Date.now()}`, - content: botContent, - isUser: false, - timestamp: new Date().toISOString(), - }; - - setMessages(prev => [...prev, botMessage]); + }); + }; - // Show toast notification - // toast.open({ - // type: botMessageType, - // title: t('errorOccurred'), - // message: t('errorMessage'), - // }); + const onComplete = () => { + console.log('[Component] Stream completed'); + // Always reset loading state on completion + 
setIsLoading(false); + }; - } catch (error) { - console.error('Error sending message:', error); + const onError = (error: string) => { + console.error('[Component] Stream error:', error); + // Always reset loading state on error + setIsLoading(false); + + // Handle incomplete bot message + setMessages(prev => { + const botMsgIndex = prev.findIndex(msg => msg.id === botMessageId); + + if (botMsgIndex !== -1) { + const botMessage = prev[botMsgIndex]; + + // If the bot message has content, mark it as errored + if (botMessage.content.trim()) { + const updated = [...prev]; + updated[botMsgIndex] = { + ...botMessage, + hasError: true, + errorMessage: error, + }; + return updated; + } else { + // If no content, remove the empty bot message + return prev.filter(msg => msg.id !== botMessageId); + } + } + + return prev; + }); - const errorMessage: Message = { - id: `error-${Date.now()}`, - content: 'Failed to send message. Please check your connection and try again.', - isUser: false, - timestamp: new Date().toISOString(), - }; - - setMessages(prev => [...prev, errorMessage]); - toast.open({ type: 'error', - title: 'Connection Error', - message: 'Unable to connect to the production LLM service.', + title: t('testProductionLLM.streamingErrorTitle'), + message: error, }); - } finally { + }; + + // Start streaming + try { + await startStreaming(userMessageText, streamingOptions, onToken, onComplete, onError); + } catch (error) { + console.error('[Component] Failed to start streaming:', error); + // Reset loading state if streaming fails to start setIsLoading(false); } }; @@ -161,10 +177,11 @@ const TestProductionLLM: FC = () => { const clearChat = () => { setMessages([]); + stopStreaming(); toast.open({ type: 'info', - title: 'Chat Cleared', - message: 'All messages have been cleared.', + title: t('testProductionLLM.chatClearedTitle'), + message: t('testProductionLLM.chatClearedMessage'), }); }; @@ -172,9 +189,9 @@ const TestProductionLLM: FC = () => {
-

{t('Test Production LLM')}

+

{t('testProductionLLM.title')}

@@ -182,8 +199,8 @@ const TestProductionLLM: FC = () => {
{messages.length === 0 && (
-

Welcome to Production LLM Testing

-

Start a conversation by typing a message below.

+

{t('testProductionLLM.welcomeTitle')}

+

{t('testProductionLLM.welcomeSubtitle')}

)} @@ -192,10 +209,21 @@ const TestProductionLLM: FC = () => { key={msg.id} className={`test-production-llm__message ${ msg.isUser ? 'test-production-llm__message--user' : 'test-production-llm__message--bot' + } ${ + msg.hasError ? 'test-production-llm__message--error' : '' }`} >
- {msg.content} + + {msg.hasError && ( +
+ ⚠️ + + {t('testProductionLLM.incompleteMessageError', { defaultValue: 'This message is incomplete due to an error' })} + {msg.errorMessage && `: ${msg.errorMessage}`} + +
+ )}
{new Date(msg.timestamp).toLocaleTimeString()} @@ -220,22 +248,22 @@ const TestProductionLLM: FC = () => {
setMessage(e.target.value)} + value={inputMessage} + onChange={(e) => setInputMessage(e.target.value)} onKeyDown={handleKeyPress} - placeholder="Type your message here... (Press Enter to send, Shift+Enter for new line)" + placeholder={t('testProductionLLM.messagePlaceholder')??""} hideLabel maxRows={4} - disabled={isLoading} + disabled={isLoading || isStreaming} />
diff --git a/GUI/translations/en/common.json b/GUI/translations/en/common.json index 8c2cac8a..a71a2f3e 100644 --- a/GUI/translations/en/common.json +++ b/GUI/translations/en/common.json @@ -414,6 +414,21 @@ "azure": "Azure OpenAI" } }, + "testProductionLLM": { + "title": "Test Production LLM", + "clearChat": "Clear Chat", + "welcomeTitle": "Welcome to Production LLM Testing", + "welcomeSubtitle": "Start a conversation by typing a message below.", + "messageLabel": "Message", + "messagePlaceholder": "Type your message here... (Press Enter to send, Shift+Enter for new line)", + "sendButton": "Send", + "sendingButton": "Sending...", + "warningTitle": "Warning", + "emptyMessageWarning": "Please enter a message", + "streamingErrorTitle": "Streaming Error", + "chatClearedTitle": "Chat Cleared", + "chatClearedMessage": "All messages have been cleared." + }, "promptConfigurations": { "title": "Prompt Configurations", "subtitle": "Configure and manage your prompt templates", diff --git a/GUI/translations/et/common.json b/GUI/translations/et/common.json index 1c093b6f..b1030db5 100644 --- a/GUI/translations/et/common.json +++ b/GUI/translations/et/common.json @@ -415,6 +415,21 @@ "azure": "Azure OpenAI" } }, + "testProductionLLM": { + "title": "Testi Tootmise LLM", + "clearChat": "Tühjenda Vestlus", + "welcomeTitle": "Tere tulemast Tootmise LLM Testimisse", + "welcomeSubtitle": "Alusta vestlust, kirjutades allpool sõnumi.", + "messageLabel": "Sõnum", + "messagePlaceholder": "Kirjuta oma sõnum siia... (Vajuta Enter saatmiseks, Shift+Enter uue rea jaoks)", + "sendButton": "Saada", + "sendingButton": "Saatmine...", + "warningTitle": "Hoiatus", + "emptyMessageWarning": "Palun sisesta sõnum", + "streamingErrorTitle": "Voogedastuse Viga", + "chatClearedTitle": "Vestlus Tühjendatud", + "chatClearedMessage": "Kõik sõnumid on tühjendatud." 
+ }, "promptConfigurations": { "title": "Viiba Seaded", "subtitle": "Seadista ja halda oma viiba malle", From 05f0f94ff6de5ce06d0abf3ffc3594e487bcbe83 Mon Sep 17 00:00:00 2001 From: Charith Nuwan Bimsara <59943919+nuwangeek@users.noreply.github.com> Date: Fri, 20 Feb 2026 13:29:32 +0530 Subject: [PATCH 2/2] Implement multi-layer Tool classification agent workflow routing skeleton with BaseWorkflow abstract class (#318) * prompt coniguration backend to be testing * custom prompt configuration update and fixed Pyright issues * fixed copilot reviews * pre validation step added when user query is inserted * added more validation cases * fixed review comments * implement tool classification orchestration agent skeleton * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fixed copilot suggested changes * fixed issue * added skills * fixed issue --------- Co-authored-by: Thiru Dinesh <56014038+Thirunayan22@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Thiru Dinesh --- .github/copilot-instructions.md | 304 ++++++ .github/skills/code-review/SKILL.md | 4 + .gitignore | 1 + docs/TOOL_CLASSIFIER_SKELETON_USAGE.md | 542 ++++++++++ pyproject.toml | 2 +- src/llm_orchestration_service.py | 982 ++++++++++-------- src/llm_orchestration_service_api.py | 10 +- src/llm_orchestrator_config/feature_flags.py | 82 ++ src/tool_classifier/__init__.py | 20 + src/tool_classifier/base_workflow.py | 118 +++ src/tool_classifier/classifier.py | 338 ++++++ src/tool_classifier/enums.py | 39 + src/tool_classifier/models.py | 81 ++ src/tool_classifier/workflows/__init__.py | 13 + .../workflows/context_workflow.py | 86 ++ src/tool_classifier/workflows/ood_workflow.py | 131 +++ src/tool_classifier/workflows/rag_workflow.py | 172 +++ .../workflows/service_workflow.py | 137 +++ 18 files changed, 2625 insertions(+), 437 deletions(-) create mode 100644 .github/copilot-instructions.md create 
mode 100644 .github/skills/code-review/SKILL.md create mode 100644 docs/TOOL_CLASSIFIER_SKELETON_USAGE.md create mode 100644 src/llm_orchestrator_config/feature_flags.py create mode 100644 src/tool_classifier/__init__.py create mode 100644 src/tool_classifier/base_workflow.py create mode 100644 src/tool_classifier/classifier.py create mode 100644 src/tool_classifier/enums.py create mode 100644 src/tool_classifier/models.py create mode 100644 src/tool_classifier/workflows/__init__.py create mode 100644 src/tool_classifier/workflows/context_workflow.py create mode 100644 src/tool_classifier/workflows/ood_workflow.py create mode 100644 src/tool_classifier/workflows/rag_workflow.py create mode 100644 src/tool_classifier/workflows/service_workflow.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000..f71218be --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,304 @@ +# BYK-RAG Module - Copilot Instructions + +## Project Overview + +BYK-RAG is a Retrieval-Augmented Generation module for Estonian government digital services (Bürokratt ecosystem). It provides secure, multilingual AI-powered responses by integrating multiple LLM providers, contextual retrieval, and guardrails. 
+ +## Build, Test, and Lint Commands + +### Environment Setup +```bash +# Install Python 3.12.10 and create virtual environment +uv python install 3.12.10 +uv sync --frozen + +# Install pre-commit hooks +uv run pre-commit install +``` + +### Running Services +```bash +# Always use uv run for Python scripts (whether venv is activated or not) +uv run python + +# Start all services with Docker Compose +docker compose up + +# Run FastAPI orchestration service locally +uv run uvicorn src.llm_orchestration_service_api:app --reload +``` + +### Testing +```bash +# Run all tests +uv run pytest + +# Run specific test file +uv run pytest tests/test_query_validator.py -v + +# Run integration tests (requires Docker and secrets) +uv run pytest tests/integration_tests/ -v --tb=short --log-cli-level=INFO + +# Run deepeval tests +uv run pytest tests/deepeval_tests/standard_tests.py -v --tb=short +``` + +### Linting and Formatting +```bash +# Check code formatting (does NOT modify files) +uv run ruff format --check + +# Apply code formatting (SAFE - layout only, no logic changes) +uv run ruff format + +# Check linting issues (manual fixes required) +uv run ruff check . + +# Get explanation for specific lint rule +uv run ruff rule # e.g., ANN204 + +# NEVER use ruff check --fix (can alter logic/control flow) +``` + +### Type Checking +```bash +# Run Pyright type checker (runs on src/ only, not tests/) +uv run pyright +``` + +### Pre-commit Hooks +```bash +# Run all pre-commit hooks manually +uv run pre-commit run --all-files +``` + +## Architecture + +### Core Components + +1. **LLM Orchestration Service** (`src/llm_orchestration_service.py`) + - Central business logic for RAG orchestration + - Coordinates prompt refinement, retrieval, generation, and guardrails + - Integrates with Langfuse for observability + +2. 
**FastAPI Application** (`src/llm_orchestration_service_api.py`) + - HTTP API layer exposing `/orchestrate` endpoint + - Handles streaming responses and rate limiting + - Request/response validation via Pydantic models + +3. **Contextual Retrieval** (`src/contextual_retrieval/`) + - Implements Anthropic's Contextual Retrieval methodology + - Hybrid search: Vector (semantic) + BM25 (lexical) with RRF fusion + - Multi-query expansion (6 refined queries per user query) + - Qdrant vector database integration + +4. **Prompt Refinement** (`src/prompt_refine_manager/`) + - DSPy-based query expansion + - Generates 5 refined variations + original query + +5. **Response Generation** (`src/response_generator/`) + - DSPy-based response synthesis + - Supports streaming via SSE (Server-Sent Events) + - Uses top-K retrieved chunks (default: 10) + +6. **Guardrails** (`src/guardrails/`) + - NeMo Guardrails integration with DSPy + - Input guardrails (pre-refinement) and output guardrails (post-generation) + - Blocks out-of-scope queries and harmful content + +7. **LLM Manager** (`src/llm_orchestrator_config/llm_manager.py`) + - Multi-provider support: AWS Bedrock, Azure OpenAI, Google Cloud, OpenAI, Anthropic + - HashiCorp Vault integration for secret management + - RSA-2048 encrypted credentials storage + +8. 
**Vector Indexer** (`src/vector_indexer/`) + - Qdrant collection management + - Embedding generation and indexing + - BM25 index creation + +### Supporting Services (Docker Compose) + +- **Ruuter** (Public/Private): API gateway and routing +- **DataMapper**: Data transformation layer +- **Resql**: PostgreSQL query builder +- **CronManager**: Scheduled jobs (knowledge base sync) +- **Qdrant**: Vector database +- **MinIO**: S3-compatible object storage +- **HashiCorp Vault**: Secret management +- **Grafana Loki**: Log aggregation +- **Langfuse**: LLM observability dashboard + +### Key Data Flow + +``` +User Query + ↓ +Input Guardrails (NeMo Rails) + ↓ +Prompt Refinement (DSPy) → 6 queries + ↓ +Parallel Hybrid Search (each query) + ├─→ Semantic Search (Qdrant, top-40 per query, threshold ≥0.4) + └─→ BM25 Search (top-40 per query) + ↓ +RRF Fusion → Top-K chunks (10 default) + ↓ +Response Generation (DSPy) + ↓ +Output Guardrails (NeMo Rails) + ↓ +Response to User (JSON or SSE stream) +``` + +## Key Conventions + +### Dependency Management + +- **ALWAYS use `uv add `** to add dependencies (never `pip install`) +- **ALWAYS commit both `pyproject.toml` AND `uv.lock`** together +- Use bounded version ranges: `uv add "package>=x.y,` for explanations +- Autofixes can alter control flow/logic unintentionally + +### Formatting (Ruff Formatter) + +- Double quotes for strings +- Spaces for indentation (4 spaces) +- Respects magic trailing commas +- Auto-detects line endings (LF/CRLF) +- Does NOT reformat docstring code blocks +- `uv run ruff format` is SAFE (layout only, no logic changes) + +### DSPy Usage + +- Used for prompt refinement (multi-query expansion) and response generation +- Custom LLM adapters integrate DSPy with NeMo Guardrails +- Optimization modules under `src/optimization/` for tuning prompts/metrics +- Models loaded via `optimized_module_loader.py` for compiled DSPy modules + +### HashiCorp Vault Integration + +- Secrets stored at `secret/users///` +- Each 
connection has `provider`, `environment`, and provider-specific keys +- RSA-2048 encryption layer BEFORE Vault storage +- GUI encrypts with public key; CronManager decrypts with private key +- Vault unavailable = graceful degradation (fail securely) + +### Logging + +- **loguru** for application logging +- Grafana Loki integration for centralized logs +- Use `logger.info()`, `logger.warning()`, `logger.error()` (NOT `print()`) +- Loki logger available at `grafana-configs/loki_logger.py` + +### Streaming Responses + +- Implemented via Server-Sent Events (SSE) in FastAPI +- `StreamConfig` and `stream_manager` coordinate streaming state +- `stream_response_native()` in response_generator yields tokens +- Timeout handling via `stream_timeout` utility +- Environment-gated: check `STREAMING_ALLOWED_ENVS` + +### Configuration Loading + +- `PromptConfigurationLoader` fetches prompt configs from Ruuter endpoint +- Cache TTL: `PROMPT_CONFIG_CACHE_TTL` +- Custom prompts per user/organization (stored in Vault/database) +- Fallback to defaults if Ruuter unavailable + +### Error Handling + +- `generate_error_id()` creates unique error IDs for tracking +- `log_error_with_context()` for structured error logging +- Localized error messages via `get_localized_message()` (multilingual support) +- Predefined message constants in `llm_orchestrator_constants.py` + +### Testing Conventions + +- Test files under `tests/` (unit, integration, deepeval) +- Integration tests use `testcontainers` for Docker orchestration +- Secrets required for integration tests (Azure OpenAI keys, etc.) +- Mock data in `tests/mocks/` and `tests/data/` + +### CI/CD Checks + +1. **uv-env-check**: Lockfile vs. pyproject.toml consistency +2. **pyright-type-check**: Type checking on src/ (strict mode) +3. **ruff-format-check**: Code formatting compliance +4. **ruff-lint-check**: Linting standards +5. **pytest-integration-check**: Full integration tests (requires secrets) +6. 
**deepeval-tests**: LLM evaluation metrics +7. **gitleaks-check**: Secret detection (pre-commit + CI) + +### Pre-commit Hooks + +Configured in `.pre-commit-config.yaml`: +- **gitleaks**: Secret scanning +- **uv-lock**: Ensures lockfile consistency + +### Constants and Thresholds + +Key retrieval constants (`src/vector_indexer/constants.py` and contextual retrieval): +- **Semantic search top-K**: 40 per query +- **Semantic threshold**: 0.4 (cosine similarity ≥0.4 = 50-60% alignment) +- **BM25 top-K**: 40 per query +- **Response generation top-K**: 10 chunks (after RRF fusion) +- **Query refinement count**: 5 variations + original = 6 total +- **Search timeout**: 2 seconds per query + +### Docker and Services + +- Use `docker compose` (not `docker-compose`) +- Services communicate via `bykstack` network +- Shared volumes: `shared-volume`, `cron_data` +- Vault agent containers per service (llm, gui, cron) +- Resource limits: CPU and memory constraints defined in docker-compose.yml + +## Important Notes + +- **Python version pinned to 3.12.10** (see `pyproject.toml` and `.python-version`) +- **Line length: 88** (Black-compatible, enforced by Ruff) +- **No print() statements** in production code (use loguru logger) +- **Pydantic for runtime validation** at API boundaries (FastAPI endpoints) +- **Langfuse tracing** for observability (public/secret keys from Vault) +- **Rate limiting** via `RateLimiter` utility (token and request budgets) +- **Cost tracking** via `calculate_total_costs()` and budget tracker +- **Language detection** for multilingual support (Estonian primary) diff --git a/.github/skills/code-review/SKILL.md b/.github/skills/code-review/SKILL.md new file mode 100644 index 00000000..b4e54798 --- /dev/null +++ b/.github/skills/code-review/SKILL.md @@ -0,0 +1,4 @@ +--- +name: code-review +description: Make sure all Python coding standards in the pyproject.toml file are followed, and that the code is clean, well-structured, maintainable, and efficient. 
Provide constructive feedback and suggestions for improvement. +--- diff --git a/.gitignore b/.gitignore index d0dc8cb8..77ec7863 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ datasets logs/ data_sets vault/agent-out +.vscode/ # RSA Private Keys - DO NOT COMMIT vault/keys/rsa_private_key.pem diff --git a/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md b/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md new file mode 100644 index 00000000..9dc87c88 --- /dev/null +++ b/docs/TOOL_CLASSIFIER_SKELETON_USAGE.md @@ -0,0 +1,542 @@ +# Tool Classifier Skeleton - Usage Guide + +**Version**: 1.0 +**Date**: February 17, 2026 +**Status**: Skeleton Implementation + +--- + +## Overview + +This skeleton implements the **framework** for a multi-workflow routing system based on the [TOOL_CLASSIFIER_EXTENSION_SPEC.md](./TOOL_CLASSIFIER_EXTENSION_SPEC.md) specification. + +### Current Status + +✅ **Implemented (Skeleton)**: +- Abstract base classes and interfaces +- Workflow executor skeletons (Service, Context, RAG, OOD) +- Tool classifier with classification and routing logic +- Feature flags for safe deployment +- Integration into LLMOrchestrationService + +❌ **Not Implemented (Separate Tasks)**: +- Service discovery logic (Layer 1) +- Context analysis logic (Layer 2) +- Actual LLM calls in workflows +- Output guardrails integration for new workflows +- Database schema changes + +### Current Behavior + +When `TOOL_CLASSIFIER_ENABLED=false` (default): +- System works exactly as before (RAG-only pipeline) +- No changes to existing functionality + +When `TOOL_CLASSIFIER_ENABLED=true`: +- Classifier routes queries (currently always to RAG) +- Service and Context workflows return `None` (fallback to RAG) +- RAG workflow wraps existing pipeline +- All queries ultimately handled by RAG + +--- + +## Architecture + +### Layer-Wise Workflow Routing + +``` +User Query + ↓ +Input Guardrails + ↓ +Tool Classifier + ↓ +┌────────────────┐ +│ Classification │ +└────────┬───────┘ + ↓ + ┌─────┴──────┐ + │ 
Routing │ + └─────┬──────┘ + ↓ + ╔═══════════════════════════════════╗ + ║ Layer 1: Service Workflow ║ → (returns None - not implemented) + ╚═══════════════════════════════════╝ + ↓ (fallback) + ╔═══════════════════════════════════╗ + ║ Layer 2: Context Workflow ║ → (returns None - not implemented) + ╚═══════════════════════════════════╝ + ↓ (fallback) + ╔═══════════════════════════════════╗ + ║ Layer 3: RAG Workflow ║ → Handles query (existing pipeline) + ╚═══════════════════════════════════╝ + ↓ + Response to User +``` + +### Component Structure + +``` +src/tool_classifier/ +├── __init__.py # Module exports +├── enums.py # WorkflowType enum +├── models.py # ClassificationResult models +├── base_workflow.py # Abstract BaseWorkflow class +├── classifier.py # Main ToolClassifier +└── workflows/ + ├── __init__.py + ├── service_workflow.py # Layer 1 (skeleton) + ├── context_workflow.py # Layer 2 (skeleton) + ├── rag_workflow.py # Layer 3 (complete) + └── ood_workflow.py # Layer 4 (skeleton) +``` + +### Abstract Base Class Pattern + +The system uses **BaseWorkflow** as an abstract base class to ensure all workflows follow the same contract. + +#### How It Works + +1. **BaseWorkflow defines the contract**: + - Every workflow MUST implement two methods: `execute_async()` and `execute_streaming()` + - Both methods return `Optional[...]` to support the fallback pattern (return `None` → next layer) + - Python's `@abstractmethod` decorator enforces this at instantiation time + +2. **All workflows inherit from BaseWorkflow**: + - ServiceWorkflowExecutor extends BaseWorkflow → implements both methods + - ContextWorkflowExecutor extends BaseWorkflow → implements both methods + - RAGWorkflowExecutor extends BaseWorkflow → implements both methods + - OODWorkflowExecutor extends BaseWorkflow → implements both methods + +3. 
**Classifier treats all workflows uniformly**: + - The `ToolClassifier.route_to_workflow()` method doesn't need to know which specific workflow it's calling + - It just calls `workflow.execute_async()` or `workflow.execute_streaming()` + - This is **polymorphism** - same interface, different behavior + +4. **Benefits**: + - **Consistency**: All workflows have the same interface + - **Enforcement**: Can't create a workflow without implementing required methods + - **Flexibility**: Easy to add new workflows - just extend BaseWorkflow + - **Testability**: Each workflow can be tested independently + - **Fallback Pattern**: `Optional` return type enables layer chaining + +#### Example Flow + +``` +ToolClassifier needs to execute a workflow + ↓ +Gets workflow object (could be Service, Context, RAG, or OOD) + ↓ +Calls workflow.execute_async(request, context) + ↓ +BaseWorkflow contract guarantees this method exists + ↓ +Each workflow implements its own logic + ↓ +Returns OrchestrationResponse or None (fallback to next layer) +``` + +The abstract class is like a **blueprint** that says: "Any workflow in this system MUST be able to do these two things: execute normally and execute with streaming. I don't care *how* you do it, but you must provide these capabilities." 
+ +--- + +## Feature Flags + +### Environment Variables + +```bash +# Master switch (default: false for safe deployment) +TOOL_CLASSIFIER_ENABLED=false + +# Individual workflow toggles (only apply when classifier enabled) +SERVICE_WORKFLOW_ENABLED=true +CONTEXT_WORKFLOW_ENABLED=true +``` + +### Configuration Class + +```python +from src.llm_orchestrator_config.feature_flags import FeatureFlags + +# Check if classifier is enabled +if FeatureFlags.TOOL_CLASSIFIER_ENABLED: + # Use tool classifier + pass + +# Check specific workflow +if FeatureFlags.is_workflow_enabled("service"): + # Service workflow logic + pass + +# Log current configuration +FeatureFlags.log_configuration() +``` + +--- + +## How It Works + +### 1. Non-Streaming Endpoint (`/orchestrate`) + +#### Current Flow (TOOL_CLASSIFIER_ENABLED=false) + +```python +POST /orchestrate + ↓ +LLMOrchestrationService.process_orchestration_request() + ↓ +Initialize components (LLM, guardrails, retriever, generator) + ↓ +Execute RAG pipeline + ↓ +Return OrchestrationResponse +``` + +#### With Classifier (TOOL_CLASSIFIER_ENABLED=true) + +```python +POST /orchestrate + ↓ +LLMOrchestrationService.process_orchestration_request() + ↓ +Initialize components + ↓ +Tool Classifier Integration: + 1. Initialize ToolClassifier (if first time) + 2. Classify query → ClassificationResult + - Currently always returns: WorkflowType.RAG + 3. Route to workflow: + - ServiceWorkflow.execute_async() → returns None + - ContextWorkflow.execute_async() → returns None + - RAGWorkflow.execute_async() → returns response + ↓ +Return OrchestrationResponse +``` + +### 2. 
Streaming Endpoint (`/orchestrate/stream`) + +#### Current Flow (TOOL_CLASSIFIER_ENABLED=false) + +```python +POST /orchestrate/stream + ↓ +LLMOrchestrationService.stream_orchestration_response() + ↓ +Initialize components + ↓ +Check input guardrails + ↓ +Refine prompt → Retrieve chunks → Stream through NeMo + ↓ +Yield SSE strings +``` + +#### With Classifier (TOOL_CLASSIFIER_ENABLED=true) + +```python +POST /orchestrate/stream + ↓ +LLMOrchestrationService.stream_orchestration_response() + ↓ +Initialize components + ↓ +Check input guardrails + ↓ +Tool Classifier Integration: + 1. Initialize ToolClassifier (if first time) + 2. Classify query → ClassificationResult + 3. Route to streaming workflow: + - ServiceWorkflow.execute_streaming() → returns None + - ContextWorkflow.execute_streaming() → returns None + - RAGWorkflow.execute_streaming() → yields SSE + ↓ +Yield SSE strings +``` + +### 3. Test Endpoint (`/orchestrate/test`) + +Works identically to `/orchestrate`: +- Converts `TestOrchestrationRequest` → `OrchestrationRequest` +- Routes through classifier (if enabled) +- Converts response back to `TestOrchestrationResponse` + +--- + +## Code Examples + +### Using the Classification System + +```python +from src.tool_classifier import ToolClassifier, WorkflowType, ClassificationResult + +# Initialize classifier +classifier = ToolClassifier( + llm_manager=llm_manager, + orchestration_service=service, +) + +# Classify a query +classification = await classifier.classify( + query="Hello, how are you?", + conversation_history=[], + language="en", +) + +# Check result +print(classification.workflow) # WorkflowType.RAG (in skeleton) +print(classification.confidence) # 1.0 +print(classification.reasoning) # "Default to RAG workflow..." 
+ +# Route to workflow +response = await classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, +) +``` + +### Implementing a Workflow (Example) + +```python +from src.tool_classifier.base_workflow import BaseWorkflow +from models.request_models import OrchestrationRequest, OrchestrationResponse + +class MyCustomWorkflow(BaseWorkflow): + """Custom workflow implementation.""" + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """Handle query in non-streaming mode.""" + + # Check if this workflow can handle the query + can_handle = await self._check_if_applicable(request.message) + + if not can_handle: + # Return None to trigger fallback to next layer + return None + + # Execute workflow logic + result = await self._process_query(request.message) + + # Validate with output guardrails (TODO) + # is_safe = await guardrails.check_output_async(result) + # if not is_safe: + # return None or violation_response + + # Return response + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=result, + ) + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """Handle query in streaming mode.""" + + # Check if applicable + can_handle = await self._check_if_applicable(request.message) + + if not can_handle: + return None # Fallback + + # Get complete result + result = await self._process_query(request.message) + + # Validate with guardrails (TODO) + # is_safe = await guardrails.check_output_async(result) + # if not is_safe: + # yield format_sse(chatId, VIOLATION_MESSAGE) + # yield format_sse(chatId, "END") + # return + + # Stream result token-by-token + async def stream_result(): + for chunk in self._split_into_tokens(result): + yield self._format_sse(request.chatId, 
chunk) + await asyncio.sleep(0.01) + yield self._format_sse(request.chatId, "END") + + return stream_result() +``` + +--- + +## Deployment Strategy + +### Phase 1: Testing (Current State) + +```bash +# Keep classifier disabled +TOOL_CLASSIFIER_ENABLED=false +``` + +**Result**: System works exactly as before (RAG-only) + +### Phase 2: Enable Classifier (No Impact) + +```bash +# Enable classifier (but workflows not implemented) +TOOL_CLASSIFIER_ENABLED=true +SERVICE_WORKFLOW_ENABLED=true +CONTEXT_WORKFLOW_ENABLED=true +``` + +**Result**: +- Classifier runs but always routes to RAG +- Service/Context return `None` → fallback to RAG +- Functionally identical to Phase 1 +- Validates integration works + +### Phase 3: Implement Service Workflow + +1. Implement service discovery logic (separate task) +2. Deploy with `SERVICE_WORKFLOW_ENABLED=true` +3. Monitor service routing behavior +4. Rollback flag if issues occur + +### Phase 4: Implement Context Workflow + +1. Implement context analysis logic (separate task) +2. Deploy with `CONTEXT_WORKFLOW_ENABLED=true` +3. Monitor greeting/context detection +4. Rollback flag if issues occur + +### Phase 5: Production + +All workflows operational, full layer-wise routing active. + +--- + +## Extending the System + +### Adding a New Workflow + +1. **Create Workflow Executor**: + +```python +# src/tool_classifier/workflows/custom_workflow.py + +from src.tool_classifier.base_workflow import BaseWorkflow + +class CustomWorkflowExecutor(BaseWorkflow): + """Your custom workflow.""" + + async def execute_async(self, request, context): + # Implement logic + pass + + async def execute_streaming(self, request, context): + # Implement streaming logic + pass +``` + +2. 
**Register in Classifier**: + +```python +# src/tool_classifier/enums.py + +class WorkflowType(Enum): + SERVICE = "service" + CONTEXT = "context" + RAG = "rag" + CUSTOM = "custom" # Add new type + OOD = "ood" + +# Update layer order +WORKFLOW_LAYER_ORDER = [ + WorkflowType.SERVICE, + WorkflowType.CONTEXT, + WorkflowType.CUSTOM, # Add to chain + WorkflowType.RAG, + WorkflowType.OOD, +] +``` + +3. **Initialize in ToolClassifier**: + +```python +# src/tool_classifier/classifier.py + +def __init__(self, ...): + # ... existing workflows ... + self.custom_workflow = CustomWorkflowExecutor(...) +``` + +4. **Add Feature Flag**: + +```python +# src/llm_orchestrator_config/feature_flags.py + +CUSTOM_WORKFLOW_ENABLED = ( + os.getenv("CUSTOM_WORKFLOW_ENABLED", "true").lower() == "true" +) +``` + +--- + +## Key Concepts + +### 1. None Return Pattern + +Workflows return `None` when they cannot handle a query: + +```python +if not can_handle: + return None # Triggers fallback to next layer +``` + +This enables the fallback chain: Service → Context → RAG → OOD + +### 2. Validation-First Streaming + +For Service and Context workflows (complete responses): + +```python +# 1. Get complete response +response = await call_service(...) + +# 2. Validate BEFORE streaming +is_safe = await guardrails.check_output_async(response) + +if not is_safe: + yield format_sse(chatId, VIOLATION_MESSAGE) + yield format_sse(chatId, "END") + return + +# 3. Stream validated response +for chunk in split_into_tokens(response): + yield format_sse(chatId, chunk) +yield format_sse(chatId, "END") +``` + +### 3. 
Two Execution Methods + +Every workflow implements both: +- `execute_async()` → For `/orchestrate` (returns complete response) +- `execute_streaming()` → For `/orchestrate/stream` (yields SSE strings) + +--- + +## Summary + +This skeleton provides: + + **Complete framework** for multi-workflow routing + **Safe deployment** with feature flags + **Extensible architecture** using OOP patterns + **Backward compatibility** (disabled by default) + **Clear contracts** via abstract base classes + **Documentation** for implementation tasks + +The system is ready for workflow implementation in separate, independent tasks. + +--- diff --git a/pyproject.toml b/pyproject.toml index dd8f876c..56e14264 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,4 +123,4 @@ exclude = [ ] # --- Global strictness --- -typeCheckingMode = "standard" # Standard typechecking mode \ No newline at end of file +typeCheckingMode = "standard" # Standard typechecking mode diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 92dd7b02..3c059f59 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -60,6 +60,8 @@ ContextualRetrieverInitializationError, ContextualRetrievalFailureError, ) +from src.llm_orchestrator_config.feature_flags import FeatureFlags +from src.tool_classifier import ToolClassifier class LangfuseConfig: @@ -128,8 +130,15 @@ def __init__(self) -> None: f"Service will continue with default behavior." 
) + # Initialize tool classifier (lazy initialization - will be created when first needed) + # This allows components to be initialized per-request with proper context + self.tool_classifier = None + + # Log feature flag configuration + FeatureFlags.log_configuration() + @observe(name="orchestration_request", as_type="agent") - def process_orchestration_request( + async def process_orchestration_request( self, request: OrchestrationRequest ) -> Union[OrchestrationResponse, TestOrchestrationResponse]: """ @@ -204,10 +213,65 @@ def process_orchestration_request( # Initialize all service components (only for valid queries) components = self._initialize_service_components(request) - # Execute the orchestration pipeline - response = self._execute_orchestration_pipeline( - request, components, costs_dict, timing_dict - ) + # TOOL CLASSIFIER INTEGRATION + # Route through tool classifier if enabled, otherwise use existing RAG pipeline + if FeatureFlags.TOOL_CLASSIFIER_ENABLED: + try: + logger.info( + f"[{request.chatId}] Tool classifier enabled - routing query" + ) + + # Initialize tool classifier if not already done + if self.tool_classifier is None: + self.tool_classifier = ToolClassifier( + llm_manager=components["llm_manager"], + orchestration_service=self, + ) + logger.info("Tool classifier initialized") + + # Classify query to determine workflow + classification = await self.tool_classifier.classify( + query=request.message, + conversation_history=request.conversationHistory, + language=detected_language, + ) + + logger.info( + f"[{request.chatId}] Classification: {classification.workflow.value} " + f"(confidence: {classification.confidence:.2f})" + ) + + # Route to appropriate workflow + response = await self.tool_classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=False, + ) + + except Exception as classifier_error: + logger.error( + f"[{request.chatId}] Tool classifier error: {classifier_error}", + exc_info=True, + ) + 
+ if FeatureFlags.FALLBACK_TO_RAG_ON_ERROR: + logger.info( + f"[{request.chatId}] Falling back to RAG pipeline due to classifier error" + ) + # Execute existing RAG pipeline as fallback + response = await self._execute_orchestration_pipeline( + request, components, costs_dict, timing_dict + ) + else: + raise + else: + # Tool classifier disabled - use existing RAG pipeline + logger.debug( + f"[{request.chatId}] Tool classifier disabled - using RAG pipeline" + ) + response = await self._execute_orchestration_pipeline( + request, components, costs_dict, timing_dict + ) # Log final costs and return response self._log_costs(costs_dict) @@ -317,7 +381,6 @@ async def stream_orchestration_response( # Track costs after streaming completes costs_dict: Dict[str, Dict[str, Any]] = {} timing_dict: Dict[str, float] = {} - streaming_start_time = datetime.now() # STEP 0: Detect language from user message detected_language = detect_language(request.message) @@ -390,465 +453,518 @@ async def stream_orchestration_response( f"[{request.chatId}] [{stream_ctx.stream_id}] Input guardrails passed " ) - # STEP 2: REFINE USER PROMPT (blocking) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Step 2: Refining user prompt" - ) + # TOOL CLASSIFIER INTEGRATION (STREAMING) + # Route through tool classifier if enabled, otherwise use existing RAG pipeline + if FeatureFlags.TOOL_CLASSIFIER_ENABLED: + try: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Tool classifier enabled - routing query (streaming)" + ) - start_time = time.time() - refined_output, refiner_usage = self._refine_user_prompt( - llm_manager=components["llm_manager"], - original_message=request.message, - conversation_history=request.conversationHistory, - ) - timing_dict["prompt_refiner"] = time.time() - start_time - costs_dict["prompt_refiner"] = refiner_usage + # Initialize tool classifier if not already done + if self.tool_classifier is None: + self.tool_classifier = ToolClassifier( + 
llm_manager=components["llm_manager"], + orchestration_service=self, + ) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Tool classifier initialized" + ) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Prompt refinement complete " - ) + # Classify query to determine workflow + classification = await self.tool_classifier.classify( + query=request.message, + conversation_history=request.conversationHistory, + language=detected_language, + ) - # STEP 3: RETRIEVE CONTEXT CHUNKS (blocking) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Step 3: Retrieving context chunks" - ) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Classification: {classification.workflow.value} " + f"(confidence: {classification.confidence:.2f})" + ) - try: - start_time = time.time() - relevant_chunks = await self._safe_retrieve_contextual_chunks( - components["contextual_retriever"], refined_output, request - ) - timing_dict["contextual_retrieval"] = time.time() - start_time - except ( - ContextualRetrieverInitializationError, - ContextualRetrievalFailureError, - ) as e: - logger.warning( - f"[{request.chatId}] [{stream_ctx.stream_id}] Contextual retrieval failed: {str(e)}" - ) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Returning out-of-scope due to retrieval failure" - ) - yield self._format_sse(request.chatId, OUT_OF_SCOPE_MESSAGE) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - stream_ctx.mark_completed() - return - - if len(relevant_chunks) == 0: - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] No relevant chunks - out of scope" - ) - detected_lang = getattr(request, "_detected_language", "en") - localized_msg = get_localized_message( - OUT_OF_SCOPE_MESSAGES, detected_lang + # Route to appropriate workflow (streaming) + # route_to_workflow returns AsyncIterator[str] when is_streaming=True + stream_result = await 
self.tool_classifier.route_to_workflow( + classification=classification, + request=request, + is_streaming=True, + ) + + async for sse_chunk in stream_result: + yield sse_chunk + + # Successfully completed streaming through classifier + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Tool classifier streaming completed" + ) + + # Log costs and timings + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + stream_ctx.mark_completed() + return # Exit after successful classifier routing + + except Exception as classifier_error: + logger.error( + f"[{request.chatId}] [{stream_ctx.stream_id}] Tool classifier error: {classifier_error}", + exc_info=True, + ) + + if not FeatureFlags.FALLBACK_TO_RAG_ON_ERROR: + # Don't fallback - raise error + raise + + # Fallback to RAG pipeline below + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Falling back to RAG streaming due to classifier error" + ) + # Continue to existing RAG streaming pipeline below + else: + logger.debug( + f"[{request.chatId}] [{stream_ctx.stream_id}] Tool classifier disabled - using RAG streaming" ) - yield self._format_sse(request.chatId, localized_msg) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - stream_ctx.mark_completed() - return - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Retrieved {len(relevant_chunks)} chunks " - ) + # Execute core RAG streaming pipeline + # NOTE: This only executes if tool classifier is disabled or fallback occurred + async for sse_chunk in self._stream_rag_pipeline( + request=request, + components=components, + stream_ctx=stream_ctx, + costs_dict=costs_dict, + timing_dict=timing_dict, + ): + yield sse_chunk + + # Pipeline completed successfully + return - # STEP 4: QUICK OUT-OF-SCOPE CHECK (blocking) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Step 4: Checking if question is in scope" + except Exception as e: + 
error_id = generate_error_id() + stream_ctx.mark_error(error_id) + log_error_with_context( + logger, error_id, "streaming_orchestration", request.chatId, e ) - start_time = time.time() - is_out_of_scope = await components[ - "response_generator" - ].check_scope_quick( - question=refined_output.original_question, - chunks=relevant_chunks, - max_blocks=ResponseGenerationConstants.DEFAULT_MAX_BLOCKS, + yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) + yield self._format_sse(request.chatId, "END") + + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + + # Update budget even on outer exception + self._update_connection_budget( + request.connection_id, costs_dict, request.environment ) - timing_dict["scope_check"] = time.time() - start_time - if is_out_of_scope: - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Question out of scope" - ) - detected_lang = getattr(request, "_detected_language", "en") - localized_msg = get_localized_message( - OUT_OF_SCOPE_MESSAGES, detected_lang + if self.langfuse_config.langfuse_client: + langfuse = self.langfuse_config.langfuse_client + langfuse.update_current_generation( + metadata={ + "error_id": error_id, + "error_type": type(e).__name__, + "streaming": True, + "streaming_failed": True, + "stream_id": stream_ctx.stream_id, + } ) - yield self._format_sse(request.chatId, localized_msg) - yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - stream_ctx.mark_completed() - return + langfuse.flush() - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Question is in scope " - ) + async def _stream_rag_pipeline( + self, + request: OrchestrationRequest, + components: Dict[str, Any], + stream_ctx: Any, + costs_dict: Dict[str, Dict[str, Any]], + timing_dict: Dict[str, float], + ) -> AsyncIterator[str]: + """ + Core RAG streaming pipeline without classifier routing. 
- # STEP 5: STREAM THROUGH NEMO GUARDRAILS (validation-first) - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Step 5: Starting streaming through NeMo Guardrails " - f"(validation-first, chunk_size=200)" - ) + This method contains the RAG pipeline logic that can be called directly + by workflows to avoid infinite recursion when the tool classifier is enabled. - streaming_step_start = time.time() + Pipeline Steps: + 1. Refine user prompt (blocking) + 2. Retrieve context chunks (blocking) + 3. Out-of-scope check (blocking) + 4. Stream through NeMo Guardrails (validation-first) - # Record history length before streaming - lm = dspy.settings.lm - history_length_before = ( - len(lm.history) if lm and hasattr(lm, "history") else 0 - ) + Args: + request: Orchestration request + components: Initialized service components (LLM, retriever, generator, guardrails) + stream_ctx: Stream context for tracking + costs_dict: Dictionary to accumulate costs + timing_dict: Dictionary to accumulate timings - async def bot_response_generator() -> AsyncIterator[str]: - """Generator that yields tokens from NATIVE DSPy LLM streaming.""" - async for token in stream_response_native( - agent=components["response_generator"], - question=refined_output.original_question, - chunks=relevant_chunks, - max_blocks=ResponseGenerationConstants.DEFAULT_MAX_BLOCKS, - ): - yield token + Yields: + SSE-formatted strings + """ + streaming_start_time = datetime.now() + detected_language = getattr(request, "_detected_language", "en") - # Create and store bot_generator in stream context for guaranteed cleanup - bot_generator = bot_response_generator() - stream_ctx.bot_generator = bot_generator + # STEP 1: REFINE USER PROMPT (blocking) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] RAG Pipeline Step 1: Refining user prompt" + ) - # Wrap entire streaming logic in try/except for proper error handling - try: - # Track tokens and accumulated response in stream context - 
accumulated_response = [] # Track the full response for production storage - - if components["guardrails_adapter"]: - # Use NeMo's stream_with_guardrails helper method - # This properly integrates the external generator with NeMo's validation - chunk_count = 0 - - try: - async for validated_chunk in components[ - "guardrails_adapter" - ].stream_with_guardrails( - user_message=refined_output.original_question, - bot_message_generator=bot_generator, - ): - chunk_count += 1 - - # Estimate tokens (rough approximation: 4 characters = 1 token) - chunk_tokens = len(validated_chunk) // 4 - stream_ctx.token_count += chunk_tokens - - # Accumulate response for production storage - accumulated_response.append(validated_chunk) - - # Check token limit - if ( - stream_ctx.token_count - > StreamConfig.MAX_TOKENS_PER_STREAM - ): - logger.error( - f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded: " - f"{stream_ctx.token_count} > {StreamConfig.MAX_TOKENS_PER_STREAM}" - ) - # Send error message and end stream immediately - yield self._format_sse( - request.chatId, STREAM_TOKEN_LIMIT_MESSAGE - ) - yield self._format_sse(request.chatId, "END") - - # Extract usage and log costs - usage_info = get_lm_usage_since( - history_length_before - ) - costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - stream_ctx.mark_completed() - return # Stop immediately - cleanup happens in finally - - # Check for guardrail violations using blocked phrases - # Match the actual behavior of NeMo Guardrails adapter - is_guardrail_error = False - if isinstance(validated_chunk, str): - # Use the same blocked phrases as the guardrails adapter - blocked_phrases = GUARDRAILS_BLOCKED_PHRASES - chunk_lower = validated_chunk.strip().lower() - # Check if the chunk is primarily a blocked phrase - for phrase in blocked_phrases: - # More robust check: ensure the phrase is the main content - if ( - phrase.lower() in chunk_lower - and 
len(chunk_lower) - <= len(phrase.lower()) + 20 - ): - is_guardrail_error = True - break - - if is_guardrail_error: - logger.warning( - f"[{request.chatId}] [{stream_ctx.stream_id}] Guardrails violation detected" - ) - # Send the violation message and end stream - yield self._format_sse( - request.chatId, - OUTPUT_GUARDRAIL_VIOLATION_MESSAGE, - ) - yield self._format_sse(request.chatId, "END") - - # Log the violation - logger.warning( - f"[{request.chatId}] [{stream_ctx.stream_id}] Output blocked by guardrails: {validated_chunk}" - ) - - # Extract usage and log costs - usage_info = get_lm_usage_since( - history_length_before - ) - costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - stream_ctx.mark_completed() - return # Cleanup happens in finally - - # Log first few chunks for debugging - if ( - chunk_count - <= ResponseGenerationConstants.DEFAULT_MAX_BLOCKS - ): - logger.debug( - f"[{request.chatId}] [{stream_ctx.stream_id}] Validated chunk {chunk_count}: {repr(validated_chunk)}" - ) - - # Yield the validated chunk to client - yield self._format_sse(request.chatId, validated_chunk) - except GeneratorExit: - # Client disconnected - stream_ctx.mark_cancelled() - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Client disconnected during guardrails streaming" - ) - raise + start_time = time.time() + refined_output, refiner_usage = self._refine_user_prompt( + llm_manager=components["llm_manager"], + original_message=request.message, + conversation_history=request.conversationHistory, + ) + timing_dict["prompt_refiner"] = time.time() - start_time + costs_dict["prompt_refiner"] = refiner_usage - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Stream completed successfully " - f"({chunk_count} chunks streamed)" - ) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Prompt refinement complete" + ) - # Send document references before END token - doc_references = 
self._extract_document_references( - relevant_chunks - ) - if doc_references: - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Sending {len(doc_references)} document references before END" - ) - # Format references as markdown text - refs_text = "\n\n**References:**\n" + "\n".join( - f"{i + 1}. [{ref.document_url}]({ref.document_url})" - for i, ref in enumerate(doc_references) - ) - yield self._format_sse(request.chatId, refs_text) + # STEP 2: RETRIEVE CONTEXT CHUNKS (blocking) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] RAG Pipeline Step 2: Retrieving context chunks" + ) - yield self._format_sse(request.chatId, "END") + try: + start_time = time.time() + relevant_chunks = await self._safe_retrieve_contextual_chunks( + components["contextual_retriever"], refined_output, request + ) + timing_dict["contextual_retrieval"] = time.time() - start_time + except ( + ContextualRetrieverInitializationError, + ContextualRetrievalFailureError, + ) as e: + logger.warning( + f"[{request.chatId}] [{stream_ctx.stream_id}] Contextual retrieval failed: {str(e)}" + ) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Returning out-of-scope due to retrieval failure" + ) + localized_msg = get_localized_message( + OUT_OF_SCOPE_MESSAGES, detected_language + ) + yield self._format_sse(request.chatId, localized_msg) + yield self._format_sse(request.chatId, "END") + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + stream_ctx.mark_completed() + return - else: - # No guardrails - stream directly - logger.warning( - f"[{request.chatId}] [{stream_ctx.stream_id}] Streaming without guardrails validation" - ) - chunk_count = 0 - async for token in bot_generator: - chunk_count += 1 - - # Estimate tokens and check limit - token_estimate = len(token) // 4 - stream_ctx.token_count += token_estimate - - # Accumulate response for production storage - accumulated_response.append(token) - - if ( - stream_ctx.token_count - > 
StreamConfig.MAX_TOKENS_PER_STREAM - ): - logger.error( - f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded (no guardrails): " - f"{stream_ctx.token_count} > {StreamConfig.MAX_TOKENS_PER_STREAM}" - ) - yield self._format_sse( - request.chatId, STREAM_TOKEN_LIMIT_MESSAGE - ) - yield self._format_sse(request.chatId, "END") - stream_ctx.mark_completed() - return # Stop immediately - cleanup in finally - - yield self._format_sse(request.chatId, token) - - # Send document references before END token - doc_references = self._extract_document_references( - relevant_chunks - ) - if doc_references: - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Sending {len(doc_references)} document references before END" - ) - # Format references as markdown text - refs_text = "\n\n**References:**\n" + "\n".join( - f"{i + 1}. [{ref.document_url}]({ref.document_url})" - for i, ref in enumerate(doc_references) - ) - yield self._format_sse(request.chatId, refs_text) + if len(relevant_chunks) == 0: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] No relevant chunks - out of scope" + ) + localized_msg = get_localized_message( + OUT_OF_SCOPE_MESSAGES, detected_language + ) + yield self._format_sse(request.chatId, localized_msg) + yield self._format_sse(request.chatId, "END") + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + stream_ctx.mark_completed() + return - yield self._format_sse(request.chatId, "END") + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Retrieved {len(relevant_chunks)} chunks" + ) - # Extract usage information after streaming completes - usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info + # STEP 3: QUICK OUT-OF-SCOPE CHECK (blocking) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] RAG Pipeline Step 3: Checking if question is in scope" + ) - # Record streaming generation time - timing_dict["streaming_generation"] = ( - 
time.time() - streaming_step_start - ) - # Mark output guardrails as inline (not blocking) - timing_dict["output_guardrails"] = 0.0 # Inline during streaming + start_time = time.time() + is_out_of_scope = await components["response_generator"].check_scope_quick( + question=refined_output.original_question, + chunks=relevant_chunks, + max_blocks=ResponseGenerationConstants.DEFAULT_MAX_BLOCKS, + ) + timing_dict["scope_check"] = time.time() - start_time - # Calculate streaming duration - streaming_duration = ( - datetime.now() - streaming_start_time - ).total_seconds() - logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Streaming completed in {streaming_duration:.2f}s" - ) + if is_out_of_scope: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Question out of scope" + ) + localized_msg = get_localized_message( + OUT_OF_SCOPE_MESSAGES, detected_language + ) + yield self._format_sse(request.chatId, localized_msg) + yield self._format_sse(request.chatId, "END") + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + stream_ctx.mark_completed() + return - # Log costs and trace - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + logger.info(f"[{request.chatId}] [{stream_ctx.stream_id}] Question is in scope") - # Update budget for the LLM connection - self._update_connection_budget( - request.connection_id, costs_dict, request.environment - ) + # STEP 4: STREAM THROUGH NEMO GUARDRAILS (validation-first) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] RAG Pipeline Step 4: Starting streaming through NeMo Guardrails" + ) - if self.langfuse_config.langfuse_client: - langfuse = self.langfuse_config.langfuse_client - total_costs = calculate_total_costs(costs_dict) - - langfuse.update_current_generation( - model=components["llm_manager"] - .get_provider_info() - .get("model", "unknown"), - usage_details={ - "input": usage_info.get("total_prompt_tokens", 0), - "output": 
usage_info.get("total_completion_tokens", 0), - "total": usage_info.get("total_tokens", 0), - }, - cost_details={ - "total": total_costs.get("total_cost", 0.0), - }, - metadata={ - "streaming": True, - "streaming_duration_seconds": streaming_duration, - "chunks_streamed": chunk_count, - "cost_breakdown": costs_dict, - "chat_id": request.chatId, - "environment": request.environment, - "stream_id": stream_ctx.stream_id, - }, - ) - langfuse.flush() - - # Store inference data (for production and testing environments) - if request.environment in [ - PRODUCTION_DEPLOYMENT_ENVIRONMENT, - TEST_DEPLOYMENT_ENVIRONMENT, - ]: - try: - await self._store_production_inference_data_async( - request=request, - refined_output=refined_output, - relevant_chunks=relevant_chunks, - accumulated_response="".join(accumulated_response), - ) - except Exception as storage_error: - # Log storage error but don't fail the request + streaming_step_start = time.time() + + # Record history length before streaming + lm = dspy.settings.lm + history_length_before = len(lm.history) if lm and hasattr(lm, "history") else 0 + + async def bot_response_generator() -> AsyncIterator[str]: + """Generator that yields tokens from NATIVE DSPy LLM streaming.""" + async for token in stream_response_native( + agent=components["response_generator"], + question=refined_output.original_question, + chunks=relevant_chunks, + max_blocks=ResponseGenerationConstants.DEFAULT_MAX_BLOCKS, + ): + yield token + + # Create and store bot_generator in stream context for guaranteed cleanup + bot_generator = bot_response_generator() + stream_ctx.bot_generator = bot_generator + + # Wrap entire streaming logic in try/except for proper error handling + try: + # Track tokens and accumulated response in stream context + accumulated_response = [] # Track the full response for production storage + + if components["guardrails_adapter"]: + # Use NeMo's stream_with_guardrails helper method + chunk_count = 0 + + try: + async for validated_chunk 
in components[ + "guardrails_adapter" + ].stream_with_guardrails( + user_message=refined_output.original_question, + bot_message_generator=bot_generator, + ): + chunk_count += 1 + + # Estimate tokens (rough approximation: 4 characters = 1 token) + chunk_tokens = len(validated_chunk) // 4 + stream_ctx.token_count += chunk_tokens + + # Accumulate response for production storage + accumulated_response.append(validated_chunk) + + # Check token limit + if stream_ctx.token_count > StreamConfig.MAX_TOKENS_PER_STREAM: logger.error( - f"Storage failed for chat_id: {request.chatId}, environment: {request.environment} - {str(storage_error)}" + f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded: " + f"{stream_ctx.token_count} > {StreamConfig.MAX_TOKENS_PER_STREAM}" + ) + yield self._format_sse( + request.chatId, STREAM_TOKEN_LIMIT_MESSAGE + ) + yield self._format_sse(request.chatId, "END") + + usage_info = get_lm_usage_since(history_length_before) + costs_dict["streaming_generation"] = usage_info + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + stream_ctx.mark_completed() + return + + # Check for guardrail violations + is_guardrail_error = False + if isinstance(validated_chunk, str): + blocked_phrases = GUARDRAILS_BLOCKED_PHRASES + chunk_lower = validated_chunk.strip().lower() + for phrase in blocked_phrases: + if ( + phrase.lower() in chunk_lower + and len(chunk_lower) <= len(phrase.lower()) + 20 + ): + is_guardrail_error = True + break + + if is_guardrail_error: + logger.warning( + f"[{request.chatId}] [{stream_ctx.stream_id}] Guardrails violation detected" + ) + yield self._format_sse( + request.chatId, OUTPUT_GUARDRAIL_VIOLATION_MESSAGE ) + yield self._format_sse(request.chatId, "END") - # Mark stream as completed successfully - stream_ctx.mark_completed() + usage_info = get_lm_usage_since(history_length_before) + costs_dict["streaming_generation"] = usage_info + self._log_costs(costs_dict) + log_step_timings(timing_dict, 
request.chatId) + stream_ctx.mark_completed() + return + # Yield the validated chunk to client + yield self._format_sse(request.chatId, validated_chunk) except GeneratorExit: - # Client disconnected - mark as cancelled stream_ctx.mark_cancelled() logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Client disconnected" - ) - usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) - - # Update budget even on client disconnect - self._update_connection_budget( - request.connection_id, costs_dict, request.environment + f"[{request.chatId}] [{stream_ctx.stream_id}] Client disconnected during guardrails streaming" ) raise - except Exception as stream_error: - error_id = generate_error_id() - stream_ctx.mark_error(error_id) - log_error_with_context( - logger, - error_id, - "streaming_generation", - request.chatId, - stream_error, - ) - yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) - yield self._format_sse(request.chatId, "END") - usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Stream completed successfully ({chunk_count} chunks)" + ) - # Update budget even on streaming error - self._update_connection_budget( - request.connection_id, costs_dict, request.environment + # Send document references before END token + doc_references = self._extract_document_references(relevant_chunks) + if doc_references: + refs_text = "\n\n**References:**\n" + "\n".join( + f"{i + 1}. 
[{ref.document_url}]({ref.document_url})" + for i, ref in enumerate(doc_references) ) + yield self._format_sse(request.chatId, refs_text) - except Exception as e: - error_id = generate_error_id() - stream_ctx.mark_error(error_id) - log_error_with_context( - logger, error_id, "streaming_orchestration", request.chatId, e + yield self._format_sse(request.chatId, "END") + + else: + # No guardrails - stream directly + logger.warning( + f"[{request.chatId}] [{stream_ctx.stream_id}] Streaming without guardrails validation" ) + chunk_count = 0 + async for token in bot_generator: + chunk_count += 1 + + token_estimate = len(token) // 4 + stream_ctx.token_count += token_estimate + accumulated_response.append(token) + + if stream_ctx.token_count > StreamConfig.MAX_TOKENS_PER_STREAM: + logger.error( + f"[{request.chatId}] [{stream_ctx.stream_id}] Token limit exceeded (no guardrails)" + ) + yield self._format_sse( + request.chatId, STREAM_TOKEN_LIMIT_MESSAGE + ) + yield self._format_sse(request.chatId, "END") + stream_ctx.mark_completed() + return + + yield self._format_sse(request.chatId, token) + + # Send document references before END token + doc_references = self._extract_document_references(relevant_chunks) + if doc_references: + refs_text = "\n\n**References:**\n" + "\n".join( + f"{i + 1}. 
[{ref.document_url}]({ref.document_url})" + for i, ref in enumerate(doc_references) + ) + yield self._format_sse(request.chatId, refs_text) - yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) yield self._format_sse(request.chatId, "END") - self._log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + # Extract usage information after streaming completes + usage_info = get_lm_usage_since(history_length_before) + costs_dict["streaming_generation"] = usage_info - # Update budget even on outer exception - self._update_connection_budget( - request.connection_id, costs_dict, request.environment + # Record timings + timing_dict["streaming_generation"] = time.time() - streaming_step_start + timing_dict["output_guardrails"] = 0.0 # Inline during streaming + + # Calculate streaming duration + streaming_duration = (datetime.now() - streaming_start_time).total_seconds() + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Streaming completed in {streaming_duration:.2f}s" + ) + + # Log costs and trace + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + + # Update budget + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + + # Langfuse tracking + if self.langfuse_config.langfuse_client: + langfuse = self.langfuse_config.langfuse_client + total_costs = calculate_total_costs(costs_dict) + + langfuse.update_current_generation( + model=components["llm_manager"] + .get_provider_info() + .get("model", "unknown"), + usage_details={ + "input": usage_info.get("total_prompt_tokens", 0), + "output": usage_info.get("total_completion_tokens", 0), + "total": usage_info.get("total_tokens", 0), + }, + cost_details={"total": total_costs.get("total_cost", 0.0)}, + metadata={ + "streaming": True, + "streaming_duration_seconds": streaming_duration, + "chunks_streamed": chunk_count, + "cost_breakdown": costs_dict, + "chat_id": request.chatId, + "environment": request.environment, + 
"stream_id": stream_ctx.stream_id, + }, ) + langfuse.flush() - if self.langfuse_config.langfuse_client: - langfuse = self.langfuse_config.langfuse_client - langfuse.update_current_generation( - metadata={ - "error_id": error_id, - "error_type": type(e).__name__, - "streaming": True, - "streaming_failed": True, - "stream_id": stream_ctx.stream_id, - } + # Store inference data (for production and testing environments) + if request.environment in [ + PRODUCTION_DEPLOYMENT_ENVIRONMENT, + TEST_DEPLOYMENT_ENVIRONMENT, + ]: + try: + await self._store_production_inference_data_async( + request=request, + refined_output=refined_output, + relevant_chunks=relevant_chunks, + accumulated_response="".join(accumulated_response), ) - langfuse.flush() + except Exception as storage_error: + logger.error( + f"Storage failed for chat_id: {request.chatId}, environment: {request.environment} - {str(storage_error)}" + ) + + # Mark stream as completed successfully + stream_ctx.mark_completed() + + except GeneratorExit: + # Client disconnected - mark as cancelled + stream_ctx.mark_cancelled() + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Client disconnected" + ) + usage_info = get_lm_usage_since(history_length_before) + costs_dict["streaming_generation"] = usage_info + self._log_costs(costs_dict) + log_step_timings(timing_dict, request.chatId) + + # Update budget even on client disconnect + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + raise + except Exception as stream_error: + error_id = generate_error_id() + stream_ctx.mark_error(error_id) + log_error_with_context( + logger, + error_id, + "streaming_generation", + request.chatId, + stream_error, + ) + yield self._format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) + yield self._format_sse(request.chatId, "END") + + usage_info = get_lm_usage_since(history_length_before) + costs_dict["streaming_generation"] = usage_info + self._log_costs(costs_dict) + 
log_step_timings(timing_dict, request.chatId) + + # Update budget even on streaming error + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) def _format_sse(self, chat_id: str, content: str) -> str: """ @@ -992,7 +1108,7 @@ def _log_generator_status(self, components: Dict[str, Any]) -> None: logger.warning(f" Generator: Status check failed - {str(e)}") @observe(name="execute_orchestration_pipeline", as_type="span") - def _execute_orchestration_pipeline( + async def _execute_orchestration_pipeline( self, request: OrchestrationRequest, components: Dict[str, Any], @@ -1006,7 +1122,7 @@ def _execute_orchestration_pipeline( # Step 1: Input Guardrails Check if components["guardrails_adapter"]: start_time = time.time() - input_blocked_response = self.handle_input_guardrails( + input_blocked_response = await self.handle_input_guardrails( components["guardrails_adapter"], request, costs_dict ) timing_dict["input_guardrails_check"] = time.time() - start_time @@ -1026,7 +1142,7 @@ def _execute_orchestration_pipeline( # Step 3: Retrieve relevant chunks using contextual retrieval try: start_time = time.time() - relevant_chunks = self._safe_retrieve_contextual_chunks_sync( + relevant_chunks = await self._safe_retrieve_contextual_chunks( components["contextual_retriever"], refined_output, request ) timing_dict["contextual_retrieval"] = time.time() - start_time @@ -1057,7 +1173,7 @@ def _execute_orchestration_pipeline( # Step 5: Output Guardrails Check # Apply guardrails to all response types for consistent safety across all environments start_time = time.time() - output_guardrails_response = self.handle_output_guardrails( + output_guardrails_response = await self.handle_output_guardrails( components["guardrails_adapter"], generated_response, request, @@ -1132,14 +1248,14 @@ def _safe_initialize_response_generator( ) return None - def handle_input_guardrails( + async def handle_input_guardrails( self, guardrails_adapter: 
NeMoRailsAdapter, request: OrchestrationRequest, costs_dict: Dict[str, Dict[str, Any]], ) -> Union[OrchestrationResponse, TestOrchestrationResponse, None]: """Check input guardrails and return blocked response if needed.""" - input_check_result = self._check_input_guardrails( + input_check_result = await self._check_input_guardrails_async( guardrails_adapter=guardrails_adapter, user_message=request.message, costs_dict=costs_dict, @@ -1186,21 +1302,23 @@ def _safe_retrieve_contextual_chunks_sync( """Synchronous wrapper for _safe_retrieve_contextual_chunks for non-streaming pipeline.""" try: - # Safely execute the async method in the sync context + # Check if there's a running event loop try: asyncio.get_running_loop() - # If we get here, there's a running event loop; cannot block synchronously - raise RuntimeError( + # If we get here, there IS a running event loop; cannot use asyncio.run() + raise ContextualRetrievalFailureError( "Cannot call _safe_retrieve_contextual_chunks_sync from an async context with a running event loop. " "Please use the async version _safe_retrieve_contextual_chunks instead." 
) except RuntimeError: - # No running loop, safe to use asyncio.run() - return asyncio.run( - self._safe_retrieve_contextual_chunks( - contextual_retriever, refined_output, request - ) + # No running loop (get_running_loop raised RuntimeError), safe to use asyncio.run() + pass + + return asyncio.run( + self._safe_retrieve_contextual_chunks( + contextual_retriever, refined_output, request ) + ) except ( ContextualRetrieverInitializationError, ContextualRetrievalFailureError, @@ -1255,7 +1373,7 @@ async def _safe_retrieve_contextual_chunks( f"Contextual chunk retrieval failed: {str(retrieval_error)}" ) from retrieval_error - def handle_output_guardrails( + async def handle_output_guardrails( self, guardrails_adapter: Optional[NeMoRailsAdapter], generated_response: Union[OrchestrationResponse, TestOrchestrationResponse], @@ -1273,7 +1391,7 @@ def handle_output_guardrails( if should_check_guardrails: # Type assertion: should_check_guardrails guarantees guardrails_adapter is not None assert guardrails_adapter is not None - output_check_result = self._check_output_guardrails( + output_check_result = await self._check_output_guardrails( guardrails_adapter=guardrails_adapter, assistant_message=generated_response.content, costs_dict=costs_dict, @@ -1694,7 +1812,7 @@ def _check_input_guardrails( ) @observe(name="check_output_guardrails", as_type="span") - def _check_output_guardrails( + async def _check_output_guardrails( self, guardrails_adapter: NeMoRailsAdapter, assistant_message: str, @@ -1714,7 +1832,7 @@ def _check_output_guardrails( logger.info("Starting output guardrails check") try: - result = guardrails_adapter.check_output(assistant_message) + result = await guardrails_adapter.check_output_async(assistant_message) # Store guardrail costs costs_dict["output_guardrails"] = result.usage diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 8bdc80cc..2a929db0 100644 --- a/src/llm_orchestration_service_api.py +++ 
b/src/llm_orchestration_service_api.py @@ -225,7 +225,7 @@ def health_check(request: Request) -> dict[str, str]: summary="Process LLM orchestration request", description="Processes a user message through the LLM orchestration pipeline", ) -def orchestrate_llm_request( +async def orchestrate_llm_request( http_request: Request, request: OrchestrationRequest, ) -> OrchestrationResponse: @@ -262,7 +262,7 @@ def orchestrate_llm_request( ) # Process the request - response = orchestration_service.process_orchestration_request(request) + response = await orchestration_service.process_orchestration_request(request) logger.info(f"Successfully processed request for chatId: {request.chatId}") return response @@ -287,7 +287,7 @@ def orchestrate_llm_request( summary="Process test LLM orchestration request", description="Processes a simplified test message through the LLM orchestration pipeline", ) -def test_orchestrate_llm_request( +async def test_orchestrate_llm_request( http_request: Request, request: TestOrchestrationRequest, ) -> TestOrchestrationResponse: @@ -341,7 +341,9 @@ def test_orchestrate_llm_request( logger.info(f"This is full request constructed for testing: {full_request}") # Process the request using the same logic - response = orchestration_service.process_orchestration_request(full_request) + response = await orchestration_service.process_orchestration_request( + full_request + ) # If response is already TestOrchestrationResponse (when environment is testing), return it directly if isinstance(response, TestOrchestrationResponse): diff --git a/src/llm_orchestrator_config/feature_flags.py b/src/llm_orchestrator_config/feature_flags.py new file mode 100644 index 00000000..d0d3fff8 --- /dev/null +++ b/src/llm_orchestrator_config/feature_flags.py @@ -0,0 +1,82 @@ +"""Feature flags for tool classifier system.""" + +import os +from loguru import logger + + +class FeatureFlags: + """ + Feature flags for controlling tool classifier and workflow behavior. 
+ + These flags enable safe deployment and gradual rollout of the multi-workflow + system. They can be controlled via environment variables. + + Deployment Strategy: + 1. Start with TOOL_CLASSIFIER_ENABLED=false (use existing RAG only) + 2. Enable classifier with all workflows disabled for testing + 3. Enable workflows one at a time (SERVICE → CONTEXT → etc.) + 4. Monitor and rollback if issues occur + + Environment Variables: + - TOOL_CLASSIFIER_ENABLED: Master switch for classifier (default: false) + - SERVICE_WORKFLOW_ENABLED: Enable Layer 1 service workflow (default: true) + - CONTEXT_WORKFLOW_ENABLED: Enable Layer 2 context workflow (default: true) + """ + + # Master switch for tool classifier + # When False: Uses existing RAG-only pipeline (backward compatibility) + # When True: Routes through tool classifier + TOOL_CLASSIFIER_ENABLED = ( + os.getenv("TOOL_CLASSIFIER_ENABLED", "false").lower() == "true" + ) + + # Individual workflow toggles + # These only take effect when TOOL_CLASSIFIER_ENABLED=true + SERVICE_WORKFLOW_ENABLED = ( + os.getenv("SERVICE_WORKFLOW_ENABLED", "true").lower() == "true" + ) + CONTEXT_WORKFLOW_ENABLED = ( + os.getenv("CONTEXT_WORKFLOW_ENABLED", "true").lower() == "true" + ) + + # RAG and OOD workflows are always enabled (no flags) + # RAG is the core fallback, OOD is the final safety net + + # Safety: Fallback to RAG if tool classifier encounters errors + # This ensures service continues working even if classifier fails + FALLBACK_TO_RAG_ON_ERROR = True + + @classmethod + def log_configuration(cls): + """Log current feature flag configuration (useful for debugging).""" + logger.info("Tool Classifier Feature Flags:") + logger.info(f" TOOL_CLASSIFIER_ENABLED: {cls.TOOL_CLASSIFIER_ENABLED}") + if cls.TOOL_CLASSIFIER_ENABLED: + logger.info(f" SERVICE_WORKFLOW_ENABLED: {cls.SERVICE_WORKFLOW_ENABLED}") + logger.info(f" CONTEXT_WORKFLOW_ENABLED: {cls.CONTEXT_WORKFLOW_ENABLED}") + logger.info(f" FALLBACK_TO_RAG_ON_ERROR: 
{cls.FALLBACK_TO_RAG_ON_ERROR}") + else: + logger.info(" (Classifier disabled - using RAG-only pipeline)") + + @classmethod + def is_workflow_enabled(cls, workflow_name: str) -> bool: + """ + Check if a specific workflow is enabled. + + Args: + workflow_name: Name of workflow ("service", "context", "rag", "ood") + + Returns: + True if workflow is enabled and classifier is enabled + """ + if not cls.TOOL_CLASSIFIER_ENABLED: + return False + + workflow_flags = { + "service": cls.SERVICE_WORKFLOW_ENABLED, + "context": cls.CONTEXT_WORKFLOW_ENABLED, + "rag": True, # Always enabled + "ood": True, # Always enabled + } + + return workflow_flags.get(workflow_name.lower(), False) diff --git a/src/tool_classifier/__init__.py b/src/tool_classifier/__init__.py new file mode 100644 index 00000000..38b861d5 --- /dev/null +++ b/src/tool_classifier/__init__.py @@ -0,0 +1,20 @@ +""" +Tool Classifier Module - Multi-workflow routing system. + +This module implements a layer-wise workflow routing system that determines +whether a user query should be handled by: +- Layer 1: Service Workflow (external API calls) +- Layer 2: Context Workflow (conversation history/greetings) +- Layer 3: RAG Workflow (knowledge base retrieval) +- Layer 4: OOD Workflow (out-of-domain fallback) +""" + +from .classifier import ToolClassifier +from .enums import WorkflowType +from .models import ClassificationResult + +__all__ = [ + "ToolClassifier", + "WorkflowType", + "ClassificationResult", +] diff --git a/src/tool_classifier/base_workflow.py b/src/tool_classifier/base_workflow.py new file mode 100644 index 00000000..50faf7ad --- /dev/null +++ b/src/tool_classifier/base_workflow.py @@ -0,0 +1,118 @@ +"""Abstract base class for workflow executors.""" + +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Dict, Optional + +from models.request_models import OrchestrationRequest, OrchestrationResponse + + +class BaseWorkflow(ABC): + """ + Abstract base class for all workflow executors. 
+ + This class defines the contract that all workflow implementations must follow. + Each workflow must implement both streaming and non-streaming execution methods. + + Design Pattern: Strategy Pattern + - Each workflow is a concrete strategy for handling queries + - ToolClassifier acts as the context that selects the appropriate strategy + + Workflows: + - ServiceWorkflowExecutor: Handles external service/API calls + - ContextWorkflowExecutor: Handles conversation history and greetings + - RAGWorkflowExecutor: Handles knowledge base retrieval (existing) + - OODWorkflowExecutor: Handles out-of-domain queries + + Return None Pattern: + Workflows return None when they cannot handle a query, triggering + fallback to the next layer in the classification chain. + """ + + @abstractmethod + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """ + Execute workflow in non-streaming mode. + + This method is called for the /orchestrate and /orchestrate/test endpoints + which return complete responses in a single HTTP response. + + Args: + request: The orchestration request containing user query and context + context: Workflow-specific metadata from ClassificationResult.metadata + + Returns: + OrchestrationResponse if workflow can handle this query + None if workflow cannot handle (triggers fallback to next layer) + + Example: + # If Service workflow detects no matching service: + return None # Falls back to Context workflow + + # If Service workflow successfully executes: + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content="EUR/USD rate is 1.0850" + ) + """ + pass + + @abstractmethod + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """ + Execute workflow in streaming mode (Server-Sent Events). 
+ + This method is called for the /orchestrate/stream endpoint which yields + response chunks progressively to the client. + + Args: + request: The orchestration request containing user query and context + context: Workflow-specific metadata from ClassificationResult.metadata + + Returns: + AsyncIterator[str] yielding SSE-formatted strings if workflow can handle + None if workflow cannot handle (triggers fallback to next layer) + + SSE Format: + Each yielded string should be formatted as: + 'data: {"chatId": "...", "payload": {"content": "..."}, ...}\\n\\n' + + Streaming Types: + - Real streaming (RAG): LLM generates tokens progressively + - Simulated streaming (Service/Context): Complete response chunked for UX + + Example: + # If Context workflow cannot answer from history: + return None # Falls back to RAG workflow + + # If Context workflow can answer: + async def stream_response(): + # Validate complete response first + answer = "The rate I mentioned was 1.08" + is_safe = await validate_with_guardrails(answer) + + if not is_safe: + yield format_sse(chatId, VIOLATION_MESSAGE) + yield format_sse(chatId, "END") + return + + # Stream validated response token-by-token + for chunk in split_into_chunks(answer): + yield format_sse(chatId, chunk) + await asyncio.sleep(0.01) + + yield format_sse(chatId, "END") + + return stream_response() + """ + pass diff --git a/src/tool_classifier/classifier.py b/src/tool_classifier/classifier.py new file mode 100644 index 00000000..71a45922 --- /dev/null +++ b/src/tool_classifier/classifier.py @@ -0,0 +1,338 @@ +"""Main tool classifier for workflow routing.""" + +from typing import Any, AsyncIterator, Dict, List, Literal, Union, overload +from loguru import logger + +from models.request_models import ( + ConversationItem, + OrchestrationRequest, + OrchestrationResponse, +) +from tool_classifier.enums import WorkflowType, WORKFLOW_DISPLAY_NAMES +from tool_classifier.models import ClassificationResult +from tool_classifier.workflows 
import ( + ServiceWorkflowExecutor, + ContextWorkflowExecutor, + RAGWorkflowExecutor, + OODWorkflowExecutor, +) + + +class ToolClassifier: + """ + Main classifier that determines which workflow should handle user queries. + + Implements a layer-wise filtering approach: + Layer 1: Service Workflow → External API calls + Layer 2: Context Workflow → Conversation history/greetings + Layer 3: RAG Workflow → Knowledge base retrieval + Layer 4: OOD Workflow → Out-of-domain fallback + + Each layer is tried in sequence. If a layer cannot handle the query + (returns None), the classifier falls back to the next layer. + + Architecture: + - Strategy Pattern: Each workflow is a pluggable strategy + - Chain of Responsibility: Layers form a fallback chain + - Dependency Injection: LLM manager and connections injected from main service + """ + + def __init__( + self, + llm_manager: Any, + orchestration_service: Any, + ): + """ + Initialize tool classifier with required dependencies. + + Args: + llm_manager: LLM manager for making LLM calls (intent detection, context check) + orchestration_service: Reference to main orchestration service (for RAG workflow) + """ + self.llm_manager = llm_manager + self.orchestration_service = orchestration_service + + # Initialize workflow executors + self.service_workflow = ServiceWorkflowExecutor( + llm_manager=llm_manager, + ) + self.context_workflow = ContextWorkflowExecutor( + llm_manager=llm_manager, + ) + self.rag_workflow = RAGWorkflowExecutor( + orchestration_service=orchestration_service, + ) + self.ood_workflow = OODWorkflowExecutor() + + logger.info("Tool classifier initialized with all workflow executors") + + async def classify( + self, + query: str, + conversation_history: List[ConversationItem], + language: str, + ) -> ClassificationResult: + """ + Classify a user query to determine which workflow should handle it. + + Implements layer-wise classification logic: + 1. Check if SERVICE workflow can handle (intent detection) + 2. 
Check if CONTEXT workflow can handle (greeting/history check) + 3. Default to RAG workflow (knowledge retrieval) + + Args: + query: User's query string + conversation_history: List of previous conversation messages + language: Detected language code (e.g., 'en', 'et') + + Returns: + ClassificationResult indicating which workflow to use + + Note: + In this skeleton, always defaults to RAG. Full implementation + will add Layer 1 and Layer 2 logic in separate tasks. + """ + logger.info(f"Classifying query: {query[:100]}...") + + # TODO: LAYER 1 - SERVICE WORKFLOW DETECTION + # Implementation task: Service workflow implementation + # Logic: + # 1. Count active services in database + # 2. If count > 50: Use Qdrant semantic search for top 20 services + # 3. If count <= 50: Use all services + # 4. Call LLM to detect intent and extract entities + # 5. If intent detected and service valid: return SERVICE classification + # Example: + # service_check = await self._check_service_layer(query, language) + # if service_check.can_handle: + # return ClassificationResult( + # workflow=WorkflowType.SERVICE, + # confidence=service_check.confidence, + # metadata=service_check.metadata, + # reasoning="Service intent detected" + # ) + + # TODO: LAYER 2 - CONTEXT WORKFLOW DETECTION + # Implementation task: Context workflow implementation + # Logic: + # 1. Check if query is a greeting using LLM + # 2. If greeting: return CONTEXT classification + # 3. If conversation_history exists: Check if query references history + # 4. Call LLM to determine if history contains answer + # 5. 
If can answer from history: return CONTEXT classification + # Example: + # context_check = await self._check_context_layer( + # query, conversation_history, language + # ) + # if context_check.can_handle: + # return ClassificationResult( + # workflow=WorkflowType.CONTEXT, + # confidence=context_check.confidence, + # metadata=context_check.metadata, + # reasoning="Greeting or answerable from history" + # ) + + # LAYER 3 - RAG WORKFLOW (DEFAULT) + # Always defaults to RAG for now + # RAG workflow will handle the query or return OOD if no chunks found + logger.info("Defaulting to RAG workflow (Layers 1-2 not implemented)") + return ClassificationResult( + workflow=WorkflowType.RAG, + confidence=1.0, + metadata={}, + reasoning="Default to RAG workflow (service and context layers not implemented)", + ) + + @overload + async def route_to_workflow( + self, + classification: ClassificationResult, + request: OrchestrationRequest, + is_streaming: Literal[False] = False, + ) -> OrchestrationResponse: ... + + @overload + async def route_to_workflow( + self, + classification: ClassificationResult, + request: OrchestrationRequest, + is_streaming: Literal[True], + ) -> AsyncIterator[str]: ... + + async def route_to_workflow( + self, + classification: ClassificationResult, + request: OrchestrationRequest, + is_streaming: bool = False, + ) -> Union[OrchestrationResponse, AsyncIterator[str]]: + """ + Route request to appropriate workflow based on classification. + + Implements fallback chain: If a workflow returns None, tries the next layer. + This ensures queries always get handled, even if primary workflow fails. 
+ + Args: + classification: Classification result from classify() + request: Original orchestration request + is_streaming: Whether to use streaming mode (for /orchestrate/stream) + + Returns: + OrchestrationResponse for non-streaming mode + AsyncIterator[str] for streaming mode + + Fallback Chain: + SERVICE → CONTEXT → RAG → OOD + Each layer returns None if it cannot handle, triggering next layer. + """ + chat_id = request.chatId + workflow_name = WORKFLOW_DISPLAY_NAMES.get( + classification.workflow, classification.workflow.value + ) + + logger.info( + f"[{chat_id}] Routing to {workflow_name} " + f"(streaming: {is_streaming}, confidence: {classification.confidence:.2f})" + ) + + # Get the workflow executor + workflow = self._get_workflow_executor(classification.workflow) + + if is_streaming: + # STREAMING MODE: For /orchestrate/stream endpoint + # Return the async iterator directly + return self._execute_with_fallback_streaming( + workflow=workflow, + request=request, + context=classification.metadata, + start_layer=classification.workflow, + ) + else: + # NON-STREAMING MODE: For /orchestrate and /orchestrate/test endpoints + return await self._execute_with_fallback_async( + workflow=workflow, + request=request, + context=classification.metadata, + start_layer=classification.workflow, + ) + + def _get_workflow_executor(self, workflow_type: WorkflowType) -> Any: + """Get workflow executor instance for given workflow type.""" + workflow_map = { + WorkflowType.SERVICE: self.service_workflow, + WorkflowType.CONTEXT: self.context_workflow, + WorkflowType.RAG: self.rag_workflow, + WorkflowType.OOD: self.ood_workflow, + } + return workflow_map[workflow_type] + + async def _execute_with_fallback_async( + self, + workflow: Any, + request: OrchestrationRequest, + context: Dict[str, Any], + start_layer: WorkflowType, + ) -> OrchestrationResponse: + """ + Execute workflow with fallback to subsequent layers (non-streaming). 
+ + TODO: Implement full fallback chain logic + Currently just executes the primary workflow. + + Full implementation should: + 1. Try primary workflow + 2. If returns None, try next layer in WORKFLOW_LAYER_ORDER + 3. Continue until workflow returns non-None result + 4. OOD workflow always returns result (never None) + """ + chat_id = request.chatId + workflow_name = WORKFLOW_DISPLAY_NAMES.get(start_layer, start_layer.value) + + logger.info(f"[{chat_id}] Executing {workflow_name} (non-streaming)") + + try: + result = await workflow.execute_async(request, context) + + if result is not None: + logger.info(f"[{chat_id}] {workflow_name} handled successfully") + return result + + # TODO: Implement fallback to next layer + # For now, if workflow returns None, call RAG as fallback + logger.warning( + f"[{chat_id}] {workflow_name} returned None, " + f"falling back to RAG workflow" + ) + rag_result = await self.rag_workflow.execute_async(request, {}) + if rag_result is not None: + return rag_result + else: + # This should never happen since RAG always returns a result + # But handle gracefully + raise RuntimeError("RAG workflow returned None unexpectedly") + + except Exception as e: + logger.error(f"[{chat_id}] Error executing {workflow_name}: {e}") + # Fallback to RAG on error + logger.info(f"[{chat_id}] Falling back to RAG due to error") + rag_result = await self.rag_workflow.execute_async(request, {}) + if rag_result is not None: + return rag_result + else: + raise RuntimeError("RAG workflow returned None unexpectedly") + + async def _execute_with_fallback_streaming( + self, + workflow: Any, + request: OrchestrationRequest, + context: Dict[str, Any], + start_layer: WorkflowType, + ) -> AsyncIterator[str]: + """ + Execute workflow with fallback to subsequent layers (streaming). + + TODO: Implement full fallback chain logic + Currently just executes the primary workflow. + + Full implementation should: + 1. Try primary workflow + 2. 
If returns None, try next layer in WORKFLOW_LAYER_ORDER + 3. Stream from the first workflow that returns non-None + 4. OOD workflow always returns result (never None) + """ + chat_id = request.chatId + workflow_name = WORKFLOW_DISPLAY_NAMES.get(start_layer, start_layer.value) + + logger.info(f"[{chat_id}] Executing {workflow_name} (streaming)") + + try: + result = await workflow.execute_streaming(request, context) + + if result is not None: + logger.info(f"[{chat_id}] {workflow_name} streaming started") + async for chunk in result: + yield chunk + return + + # TODO: Implement fallback to next layer + # For now, if workflow returns None, call RAG as fallback + logger.warning( + f"[{chat_id}] {workflow_name} returned None, " + f"falling back to RAG workflow streaming" + ) + streaming_result = await self.rag_workflow.execute_streaming(request, {}) + if streaming_result is not None: + async for chunk in streaming_result: + yield chunk + else: + raise RuntimeError("RAG workflow returned None unexpectedly") + + except Exception as e: + logger.error(f"[{chat_id}] Error executing {workflow_name} streaming: {e}") + # Fallback to RAG on error + logger.info(f"[{chat_id}] Falling back to RAG streaming due to error") + streaming_result = await self.rag_workflow.execute_streaming(request, {}) + if streaming_result is not None: + async for chunk in streaming_result: + yield chunk + else: + raise RuntimeError("RAG workflow returned None unexpectedly") diff --git a/src/tool_classifier/enums.py b/src/tool_classifier/enums.py new file mode 100644 index 00000000..ce6c7859 --- /dev/null +++ b/src/tool_classifier/enums.py @@ -0,0 +1,39 @@ +"""Enumerations and constants for tool classifier system.""" + +from enum import Enum + + +class WorkflowType(Enum): + """ + Workflow types representing different query handling strategies. 
+ + The tool classifier uses a layer-wise approach to determine which + workflow should handle each user query: + + - SERVICE: External service/API calls (Layer 1) + - CONTEXT: Conversation history or greetings (Layer 2) + - RAG: Knowledge base retrieval (Layer 3) + - OOD: Out-of-domain fallback (Layer 4) + """ + + SERVICE = "service" + CONTEXT = "context" + RAG = "rag" + OOD = "ood" + + +# Layer configuration - defines the order of workflow evaluation +WORKFLOW_LAYER_ORDER = [ + WorkflowType.SERVICE, # Layer 1: Try service first + WorkflowType.CONTEXT, # Layer 2: Then context + WorkflowType.RAG, # Layer 3: Then RAG + WorkflowType.OOD, # Layer 4: Finally OOD (always succeeds) +] + +# Workflow display names for logging +WORKFLOW_DISPLAY_NAMES = { + WorkflowType.SERVICE: "Service Workflow", + WorkflowType.CONTEXT: "Context Workflow", + WorkflowType.RAG: "RAG Workflow", + WorkflowType.OOD: "Out-of-Domain Workflow", +} diff --git a/src/tool_classifier/models.py b/src/tool_classifier/models.py new file mode 100644 index 00000000..9929473b --- /dev/null +++ b/src/tool_classifier/models.py @@ -0,0 +1,81 @@ +"""Data models for tool classifier system.""" + +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field + +from tool_classifier.enums import WorkflowType + + +class ClassificationResult(BaseModel): + """ + Result of query classification by the tool classifier. + + This model encapsulates the decision of which workflow should handle + a user query, along with confidence score and metadata. 
+ + Attributes: + workflow: The workflow type that should handle this query + confidence: Confidence score (0.0-1.0) for this classification + metadata: Workflow-specific data (e.g., service_id, intent, entities) + reasoning: Human-readable explanation of why this workflow was chosen + """ + + workflow: WorkflowType = Field( + ..., description="Which workflow should handle this query" + ) + confidence: float = Field( + default=1.0, + ge=0.0, + le=1.0, + description="Confidence score for this classification", + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Workflow-specific data passed to executor" + ) + reasoning: Optional[str] = Field( + default=None, description="Explanation of classification decision" + ) + + +class ServiceWorkflowMetadata(BaseModel): + """ + Metadata specific to Service Workflow execution. + + TODO: Will be populated by service discovery logic with: + - service_id: Identified service to call + - intent: Detected user intent + - entities: Extracted parameters for service call + - confidence: Intent detection confidence + """ + + service_id: Optional[str] = Field( + default=None, description="ID of the service to execute" + ) + intent: Optional[str] = Field( + default=None, description="Detected user intent/service name" + ) + entities: Optional[Dict[str, Any]] = Field( + default=None, description="Extracted entities/parameters" + ) + + +class ContextWorkflowMetadata(BaseModel): + """ + Metadata specific to Context Workflow execution. + + TODO: Will be populated by context analysis logic with: + - is_greeting: Whether query is a greeting + - greeting_type: Type of greeting (hello, goodbye, thanks, etc.) 
+ - can_answer_from_history: Whether conversation history has answer + - relevant_history_indices: Indices of relevant history items + """ + + is_greeting: bool = Field( + default=False, description="Whether this is a greeting/conversational query" + ) + greeting_type: Optional[str] = Field( + default=None, description="Type of greeting: hello, goodbye, thanks, casual" + ) + can_answer_from_history: bool = Field( + default=False, description="Whether conversation history can answer this" + ) diff --git a/src/tool_classifier/workflows/__init__.py b/src/tool_classifier/workflows/__init__.py new file mode 100644 index 00000000..3d733d54 --- /dev/null +++ b/src/tool_classifier/workflows/__init__.py @@ -0,0 +1,13 @@ +"""Workflow executor implementations.""" + +from tool_classifier.workflows.service_workflow import ServiceWorkflowExecutor +from tool_classifier.workflows.context_workflow import ContextWorkflowExecutor +from tool_classifier.workflows.rag_workflow import RAGWorkflowExecutor +from tool_classifier.workflows.ood_workflow import OODWorkflowExecutor + +__all__ = [ + "ServiceWorkflowExecutor", + "ContextWorkflowExecutor", + "RAGWorkflowExecutor", + "OODWorkflowExecutor", +] diff --git a/src/tool_classifier/workflows/context_workflow.py b/src/tool_classifier/workflows/context_workflow.py new file mode 100644 index 00000000..88212efa --- /dev/null +++ b/src/tool_classifier/workflows/context_workflow.py @@ -0,0 +1,86 @@ +"""Context workflow executor - Layer 2: Conversation history and greetings.""" + +from typing import Any, AsyncIterator, Dict, Optional +from loguru import logger + +from models.request_models import OrchestrationRequest, OrchestrationResponse +from tool_classifier.base_workflow import BaseWorkflow + + +class ContextWorkflowExecutor(BaseWorkflow): + """ + Handles greetings and conversation history queries (Layer 2). + + Detects: + - Greetings: "Hello", "Thanks", "Goodbye" + - History references: "What did you say earlier?", "Can you repeat that?" 
+ + Uses LLM for semantic detection (multilingual), no regex patterns. + + Status: SKELETON - Returns None (fallback to RAG) + TODO: Implement greeting/context detection, answer extraction, guardrails + """ + + def __init__(self, llm_manager: Any): + """ + Initialize context workflow executor. + + Args: + llm_manager: LLM manager for context analysis + """ + self.llm_manager = llm_manager + logger.info("Context workflow executor initialized (skeleton)") + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """ + Execute context workflow in non-streaming mode. + + TODO: Check greeting (LLM) → generate response, OR check history (last 10 turns) + → extract answer → validate with guardrails. Return None if cannot answer. + + Args: + request: Orchestration request with user query and history + context: Metadata with is_greeting, can_answer_from_history flags + + Returns: + OrchestrationResponse with context-based answer or None to fallback + """ + logger.debug( + f"[{request.chatId}] Context workflow execute_async called " + f"(not implemented - returning None)" + ) + + # TODO: Implement context workflow logic here + # For now, return None to trigger fallback to next layer (RAG) + return None + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """ + Execute context workflow in streaming mode. + + TODO: Get answer (greeting/history) → validate BEFORE streaming → chunk and + yield as SSE. Return None if cannot answer. 
+ + Args: + request: Orchestration request with user query and history + context: Metadata with is_greeting, can_answer_from_history flags + + Returns: + AsyncIterator yielding SSE strings or None to fallback + """ + logger.debug( + f"[{request.chatId}] Context workflow execute_streaming called " + f"(not implemented - returning None)" + ) + + # TODO: Implement context streaming logic here + # For now, return None to trigger fallback to next layer (RAG) + return None diff --git a/src/tool_classifier/workflows/ood_workflow.py b/src/tool_classifier/workflows/ood_workflow.py new file mode 100644 index 00000000..fed467a5 --- /dev/null +++ b/src/tool_classifier/workflows/ood_workflow.py @@ -0,0 +1,131 @@ +"""OOD workflow executor - Layer 4: Out-of-domain fallback.""" + +from typing import Any, AsyncIterator, Dict, Optional +from loguru import logger + +from models.request_models import OrchestrationRequest, OrchestrationResponse +from tool_classifier.base_workflow import BaseWorkflow + + +class OODWorkflowExecutor(BaseWorkflow): + """ + Handles out-of-domain queries that no workflow can answer (Layer 4). + + This is the final fallback in the workflow chain. It returns a polite + "cannot answer" message when: + - No service matches (Layer 1 failed) + - No context match (Layer 2 failed) + - No relevant knowledge chunks (Layer 3 failed) + + Examples of OOD queries: + - "What's the weather today?" 
(not in scope) + - "Tell me a joke" (not government service) + - Questions with no relevant knowledge + + Implementation Status: SKELETON + Returns None (will implement to return OOD message) + + TODO - Implementation (Simple): + - Return localized OUT_OF_SCOPE_MESSAGE + - Set questionOutOfLLMScope flag to True + - For streaming: chunk message and stream for UX consistency + """ + + def __init__(self): + """Initialize OOD workflow executor.""" + logger.info("OOD workflow executor initialized (skeleton)") + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """ + Execute OOD workflow in non-streaming mode. + + TODO: Implement OOD response: + ```python + from src.llm_orchestrator_config.llm_ochestrator_constants import ( + get_localized_message, + OUT_OF_SCOPE_MESSAGES, + ) + + # Get detected language from request + detected_language = getattr(request, "_detected_language", "en") + + # Get localized message + ood_message = get_localized_message(OUT_OF_SCOPE_MESSAGES, detected_language) + + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=True, # Flag as out of scope + inputGuardFailed=False, + content=ood_message, + ) + ``` + + Args: + request: Orchestration request with user query + context: Unused (OOD doesn't need metadata) + + Returns: + OrchestrationResponse with OOD message + Never returns None (this is final fallback) + """ + logger.info( + f"[{request.chatId}] OOD workflow execute_async called " + f"(not implemented - returning None for now)" + ) + + # TODO: Implement OOD response logic here + # For now, return None (will be implemented as simple message return) + return None + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """ + Execute OOD workflow in streaming mode. 
+ + TODO: Implement OOD streaming: + ```python + from src.llm_orchestrator_config.llm_ochestrator_constants import ( + get_localized_message, + OUT_OF_SCOPE_MESSAGES, + ) + + # Get localized OOD message + detected_language = getattr(request, "_detected_language", "en") + ood_message = get_localized_message(OUT_OF_SCOPE_MESSAGES, detected_language) + + # Stream message for UX consistency (no guardrails needed - fixed message) + async def stream_ood_message(): + for chunk in split_into_tokens(ood_message, chunk_size=5): + yield self._format_sse(request.chatId, chunk) + await asyncio.sleep(0.01) + yield self._format_sse(request.chatId, "END") + + return stream_ood_message() + ``` + + Note: No output guardrails needed since this is a fixed, safe message. + + Args: + request: Orchestration request with user query + context: Unused (OOD doesn't need metadata) + + Returns: + AsyncIterator yielding SSE strings + Never returns None (this is final fallback) + """ + logger.info( + f"[{request.chatId}] OOD workflow execute_streaming called " + f"(not implemented - returning None for now)" + ) + + # TODO: Implement OOD streaming logic here + # For now, return None (will be implemented as simple message streaming) + return None diff --git a/src/tool_classifier/workflows/rag_workflow.py b/src/tool_classifier/workflows/rag_workflow.py new file mode 100644 index 00000000..d83080a7 --- /dev/null +++ b/src/tool_classifier/workflows/rag_workflow.py @@ -0,0 +1,172 @@ +"""RAG workflow executor - Layer 3: Knowledge base retrieval.""" + +from typing import Any, AsyncIterator, Dict, Optional +from loguru import logger + +from models.request_models import OrchestrationRequest, OrchestrationResponse +from tool_classifier.base_workflow import BaseWorkflow + + +class RAGWorkflowExecutor(BaseWorkflow): + """ + Wrapper for existing RAG (Retrieval-Augmented Generation) workflow (Layer 3). 
+ + This workflow handles queries that require searching the knowledge base + and generating responses based on retrieved chunks. It uses the existing + RAG pipeline: + 1. Prompt refinement + 2. Contextual retrieval (Qdrant + BM25) + 3. Rank fusion (RRF) + 4. Response generation + 5. Output guardrails (validation-first streaming) + + Examples of RAG queries: + - "What are digital signatures?" + - "How do I register a company?" + - "Explain tax regulations" + + Implementation Status: COMPLETE + This is a thin wrapper that delegates to existing LLMOrchestrationService methods. + + No TODO - Just wraps existing pipeline: + - Non-streaming: Calls _execute_orchestration_pipeline() + - Streaming: Calls existing streaming logic with NeMo guardrails + + Note: If no relevant chunks found, returns OOD response (not None) + """ + + def __init__(self, orchestration_service: Any): + """ + Initialize RAG workflow executor. + + Args: + orchestration_service: Reference to LLMOrchestrationService + for calling existing RAG pipeline + """ + self.orchestration_service = orchestration_service + logger.info("RAG workflow executor initialized (wrapper)") + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """ + Execute RAG workflow in non-streaming mode. 
+ + Delegates to existing LLMOrchestrationService._execute_orchestration_pipeline() + which handles: + - Prompt refinement + - Chunk retrieval (Qdrant + BM25) + - Response generation + - Output guardrails + + Args: + request: Orchestration request with user query + context: Unused (RAG doesn't need classification metadata) + + Returns: + OrchestrationResponse with RAG-generated answer + Never returns None (handles OOD internally) + """ + logger.info(f"[{request.chatId}] Executing RAG workflow (non-streaming)") + + # Initialize components needed for RAG pipeline + costs_dict: Dict[str, Any] = {} + timing_dict: Dict[str, float] = {} + + # Initialize service components + components = self.orchestration_service._initialize_service_components(request) + + # Call existing RAG pipeline + response = await self.orchestration_service._execute_orchestration_pipeline( + request=request, + components=components, + costs_dict=costs_dict, + timing_dict=timing_dict, + ) + + # Log costs and timings + self.orchestration_service._log_costs(costs_dict) + from src.utils.time_tracker import log_step_timings + + log_step_timings(timing_dict, request.chatId) + + return response + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """ + Execute RAG workflow in streaming mode. 
+ + Delegates to existing streaming pipeline which handles: + - Prompt refinement (blocking) + - Chunk retrieval (blocking) + - Streaming through NeMo guardrails (validation-first) + - Real-time token validation + + The existing implementation uses NeMo's stream_with_guardrails which: + - Buffers tokens (chunk_size=200) + - Validates each buffer before yielding + - Provides true validation-first streaming + + Args: + request: Orchestration request with user query + context: Unused (RAG doesn't need classification metadata) + + Returns: + AsyncIterator yielding SSE-formatted strings + Never returns None (handles OOD internally) + """ + logger.info(f"[{request.chatId}] Executing RAG workflow (streaming)") + + # Initialize tracking dictionaries + costs_dict: Dict[str, Any] = {} + timing_dict: Dict[str, float] = {} + + # Get components from context if provided, otherwise initialize + components = context.get("components") + if components is None: + components = self.orchestration_service._initialize_service_components( + request + ) + + # Get stream context from context if provided, otherwise create minimal tracking + stream_ctx = context.get("stream_ctx") + if stream_ctx is None: + # Create minimal stream context when called via tool classifier + # In production flow, this is provided by stream_orchestration_response + class MinimalStreamContext: + """Minimal stream context for RAG workflow when called directly.""" + + def __init__(self, chat_id: str) -> None: + self.stream_id = f"rag-{chat_id}" + self.token_count = 0 + self.bot_generator = None + + def mark_completed(self) -> None: + """No-op: Tracking handled by orchestration service.""" + pass + + def mark_cancelled(self) -> None: + """No-op: Tracking handled by orchestration service.""" + pass + + def mark_error(self, error_id: str) -> None: + """No-op: Tracking handled by orchestration service.""" + pass + + stream_ctx = MinimalStreamContext(request.chatId) + + # Delegate to core RAG pipeline (bypasses classifier 
to avoid recursion) + async for sse_chunk in self.orchestration_service._stream_rag_pipeline( + request=request, + components=components, + stream_ctx=stream_ctx, + costs_dict=costs_dict, + timing_dict=timing_dict, + ): + yield sse_chunk diff --git a/src/tool_classifier/workflows/service_workflow.py b/src/tool_classifier/workflows/service_workflow.py new file mode 100644 index 00000000..8a6889bc --- /dev/null +++ b/src/tool_classifier/workflows/service_workflow.py @@ -0,0 +1,137 @@ +"""Service workflow executor - Layer 1: External service/API calls.""" + +from typing import Any, AsyncIterator, Dict, Optional +from loguru import logger + +from models.request_models import OrchestrationRequest, OrchestrationResponse +from tool_classifier.base_workflow import BaseWorkflow + + +class ServiceWorkflowExecutor(BaseWorkflow): + """ + Executes external service calls via Ruuter endpoints (Layer 1). + + This workflow handles queries that require calling external government + services or APIs. It performs: + 1. Service discovery (semantic search if >50 services) + 2. Intent detection using LLM + 3. Entity extraction from query + 4. Service validation against database + 5. External API call via Ruuter + 6. Output guardrails validation + + Examples of Service queries: + - "What's the EUR to USD exchange rate?" + - "Check my document status" + - "Submit a tax declaration" + + Implementation Status: SKELETON + Returns None (triggers fallback to Context workflow) + + TODO - Full Implementation (Separate Task): + - Service discovery logic (Qdrant semantic search) + - Intent detection (LLM-based) + - Entity extraction and transformation + - Service validation (database lookup) + - Ruuter API integration + - Output guardrails for service responses + """ + + def __init__(self, llm_manager: Any): + """ + Initialize service workflow executor. 
+ + Args: + llm_manager: LLM manager for intent detection + """ + self.llm_manager = llm_manager + logger.info("Service workflow executor initialized (skeleton)") + + async def execute_async( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[OrchestrationResponse]: + """ + Execute service workflow in non-streaming mode. + + TODO: Implement service workflow logic: + 1. Extract service metadata from context (service_id, intent, entities) + 2. Validate service exists and is active in database + 3. Transform entities to array format for service call + 4. Call Ruuter endpoint: POST {RUUTER_BASE_URL}/services/active{ServiceName} + 5. Validate response with output guardrails + 6. Return OrchestrationResponse with service result + + Failure scenarios: + - No service_id in context → return None (fallback to Context) + - Service not found/inactive → return None (fallback to Context) + - Service call timeout → return error response + - Output guardrails blocked → return violation response or None + + Args: + request: Orchestration request with user query + context: Metadata with service_id, intent, entities + + Returns: + OrchestrationResponse with service result or None to fallback + """ + logger.debug( + f"[{request.chatId}] Service workflow execute_async called " + f"(not implemented - returning None)" + ) + + # TODO: Implement service workflow logic here + # For now, return None to trigger fallback to next layer + return None + + async def execute_streaming( + self, + request: OrchestrationRequest, + context: Dict[str, Any], + ) -> Optional[AsyncIterator[str]]: + """ + Execute service workflow in streaming mode. + + TODO: Implement service workflow streaming: + 1. Execute service call (same as non-streaming) + 2. Get complete service response + 3. Validate with output guardrails (validation-first) + 4. If blocked: yield violation message + END + 5. If allowed: chunk response and stream token-by-token + 6. 
Simulate streaming for consistent UX with RAG + + Streaming approach (validation-first): + ```python + # Get complete response + service_response = await call_service(...) + + # Validate BEFORE streaming + is_safe = await guardrails.check_output_async(service_response) + if not is_safe: + yield format_sse(chatId, VIOLATION_MESSAGE) + yield format_sse(chatId, "END") + return + + # Stream validated response + for chunk in split_into_tokens(service_response, chunk_size=5): + yield format_sse(chatId, chunk) + await asyncio.sleep(0.01) + yield format_sse(chatId, "END") + ``` + + Args: + request: Orchestration request with user query + context: Metadata with service_id, intent, entities + + Returns: + AsyncIterator yielding SSE strings or None to fallback + """ + logger.debug( + f"[{request.chatId}] Service workflow execute_streaming called " + f"(not implemented - returning None)" + ) + + # TODO: Implement service streaming logic here + # For now, return None to trigger fallback to next layer + return None