Skip to content

Commit 292e794

Browse files
author
erangi-ar
committed
Merge branch 'wip' of https://github.com/rootcodelabs/RAG-Module into streaming-response-formatting
2 parents 800db34 + f6c6dc7 commit 292e794

File tree

7 files changed

+549
-12
lines changed

7 files changed

+549
-12
lines changed

src/llm_orchestration_service.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
INPUT_GUARDRAIL_VIOLATION_MESSAGES,
3535
OUTPUT_GUARDRAIL_VIOLATION_MESSAGE,
3636
OUTPUT_GUARDRAIL_VIOLATION_MESSAGES,
37+
QUERY_VALIDATION_FAILED_MESSAGES,
3738
get_localized_message,
3839
GUARDRAILS_BLOCKED_PHRASES,
3940
TEST_DEPLOYMENT_ENVIRONMENT,
@@ -52,6 +53,7 @@
5253
from src.utils.production_store import get_production_store
5354
from src.utils.language_detector import detect_language, get_language_name
5455
from src.utils.prompt_config_loader import PromptConfigurationLoader
56+
from src.utils.query_validator import validate_query_basic
5557
from src.guardrails import NeMoRailsAdapter, GuardrailCheckResult
5658
from src.contextual_retrieval import ContextualRetriever
5759
from src.llm_orchestrator_config.exceptions import (
@@ -170,7 +172,36 @@ def process_orchestration_request(
170172
# Using setattr for type safety - adds dynamic attribute to Pydantic model instance
171173
setattr(request, "_detected_language", detected_language)
172174

173-
# Initialize all service components
175+
# STEP 0.5: Basic Query Validation (before expensive component initialization)
176+
validation_result = validate_query_basic(request.message)
177+
if not validation_result.is_valid:
178+
logger.info(
179+
f"[{request.chatId}] Query validation failed: {validation_result.rejection_reason}"
180+
)
181+
# Get localized message
182+
validation_msg = get_localized_message(
183+
QUERY_VALIDATION_FAILED_MESSAGES, detected_language
184+
)
185+
186+
# Return appropriate response type without initializing components
187+
if request.environment == TEST_DEPLOYMENT_ENVIRONMENT:
188+
return TestOrchestrationResponse(
189+
llmServiceActive=True,
190+
questionOutOfLLMScope=False,
191+
inputGuardFailed=False,
192+
content=validation_msg,
193+
chunks=None,
194+
)
195+
else:
196+
return OrchestrationResponse(
197+
chatId=request.chatId,
198+
llmServiceActive=True,
199+
questionOutOfLLMScope=False,
200+
inputGuardFailed=False,
201+
content=validation_msg,
202+
)
203+
204+
# Initialize all service components (only for valid queries)
174205
components = self._initialize_service_components(request)
175206

176207
# Execute the orchestration pipeline
@@ -299,6 +330,22 @@ async def stream_orchestration_response(
299330
# Using setattr for type safety - adds dynamic attribute to Pydantic model instance
300331
setattr(request, "_detected_language", detected_language)
301332

333+
# Step 0.5: Basic Query Validation (before guardrails)
334+
validation_result = validate_query_basic(request.message)
335+
if not validation_result.is_valid:
336+
logger.info(
337+
f"[{request.chatId}] Streaming - Query validation failed: {validation_result.rejection_reason}"
338+
)
339+
# Get localized message
340+
validation_msg = get_localized_message(
341+
QUERY_VALIDATION_FAILED_MESSAGES, detected_language
342+
)
343+
344+
# Yield SSE format error + END marker
345+
yield self._format_sse(request.chatId, validation_msg)
346+
yield self._format_sse(request.chatId, "END")
347+
return # Stop processing
348+
302349
# Use StreamManager for centralized tracking and guaranteed cleanup
303350
async with stream_manager.managed_stream(
304351
chat_id=request.chatId, author_id=request.authorId
@@ -953,6 +1000,9 @@ def _execute_orchestration_pipeline(
9531000
timing_dict: Dict[str, float],
9541001
) -> Union[OrchestrationResponse, TestOrchestrationResponse]:
9551002
"""Execute the main orchestration pipeline with all components."""
1003+
# Note: Query validation now happens in process_orchestration_request()
1004+
# before component initialization for true early rejection
1005+
9561006
# Step 1: Input Guardrails Check
9571007
if components["guardrails_adapter"]:
9581008
start_time = time.time()

src/llm_orchestration_service_api.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,60 @@ def refresh_prompt_config(http_request: Request) -> Dict[str, Any]:
839839
},
840840
) from e
841841

842+
try:
843+
success = orchestration_service.prompt_config_loader.force_refresh()
844+
845+
if success:
846+
# Get prompt metadata without exposing content (security)
847+
custom_instructions = (
848+
orchestration_service.prompt_config_loader.get_custom_instructions()
849+
)
850+
prompt_length = len(custom_instructions)
851+
852+
# Generate hash for verification purposes (without exposing content)
853+
import hashlib
854+
855+
prompt_hash = hashlib.sha256(custom_instructions.encode()).hexdigest()[:16]
856+
857+
logger.info(
858+
f"Prompt configuration cache refreshed successfully ({prompt_length} chars)"
859+
)
860+
861+
return {
862+
"refreshed": True,
863+
"message": "Prompt configuration refreshed successfully",
864+
"prompt_length": prompt_length,
865+
"content_hash": prompt_hash, # Safe: hash instead of preview
866+
}
867+
else:
868+
# No fresh data loaded - could be fetch failure or truly not found
869+
error_id = generate_error_id()
870+
logger.warning(
871+
f"[{error_id}] Prompt configuration refresh returned empty result"
872+
)
873+
raise HTTPException(
874+
status_code=status.HTTP_404_NOT_FOUND,
875+
detail={
876+
"error": "No prompt configuration found in database",
877+
"error_id": error_id,
878+
},
879+
)
880+
881+
except HTTPException:
882+
# Re-raise HTTP exceptions as-is
883+
raise
884+
except Exception as e:
885+
# Unexpected errors during refresh
886+
error_id = generate_error_id()
887+
logger.error(f"[{error_id}] Failed to refresh prompt configuration: {e}")
888+
raise HTTPException(
889+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
890+
detail={
891+
"error": "Failed to refresh prompt configuration",
892+
"error_id": error_id,
893+
},
894+
) from e
895+
842896

843897
if __name__ == "__main__":
844898
logger.info("Starting LLM Orchestration Service API server on port 8100")

src/llm_orchestrator_config/llm_ochestrator_constants.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
"en": "I apologize, but I'm unable to provide a response as it may violate our usage policies.",
2424
}
2525

26+
# Query validation messages - single generic message for all rejection types
27+
# (empty queries, special characters only, too short, repetitive characters)
28+
QUERY_VALIDATION_FAILED_MESSAGES = {
29+
"et": "Palun esitage kehtiv küsimus või sõnum, et ma saaksin teid aidata."
30+
}
31+
2632
# Legacy constants for backward compatibility (English defaults)
2733
OUT_OF_SCOPE_MESSAGE = OUT_OF_SCOPE_MESSAGES["en"]
2834
TECHNICAL_ISSUE_MESSAGE = TECHNICAL_ISSUE_MESSAGES["en"]
@@ -106,9 +112,9 @@
106112

107113

108114
# Helper function to get localized messages
109-
def get_localized_message(message_dict: dict, language_code: str = "en") -> str:
115+
def get_localized_message(message_dict: dict, language_code: str = "et") -> str:
110116
"""
111-
Get message in the specified language, fallback to English.
117+
Get message in the specified language, fallback to Estonian.
112118
113119
Args:
114120
message_dict: Dictionary with language codes as keys
@@ -117,7 +123,7 @@ def get_localized_message(message_dict: dict, language_code: str = "en") -> str:
117123
Returns:
118124
Localized message string
119125
"""
120-
return message_dict.get(language_code, message_dict.get("en", ""))
126+
return message_dict.get(language_code, message_dict.get("et", ""))
121127

122128

123129
# Service endpoints

src/models/request_models.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,20 @@ class OrchestrationRequest(BaseModel):
6666
def validate_and_sanitize_message(cls, v: str) -> str:
6767
"""Sanitize and validate user message.
6868
69-
Note: Content safety checks (prompt injection, PII, harmful content)
69+
Note: This validator only handles security/format concerns:
70+
- XSS/HTML sanitization
71+
- Maximum length enforcement
72+
73+
Query quality validation (empty messages, special chars, etc.) is handled
74+
by the business logic layer (query_validator) with localized error messages.
75+
76+
Content safety checks (prompt injection, PII, harmful content)
7077
are handled by NeMo Guardrails after this validation layer.
7178
"""
7279
# Sanitize HTML/XSS and normalize whitespace
7380
v = InputSanitizer.sanitize_message(v)
7481

75-
# Check if message is empty after sanitization
76-
if not v or len(v.strip()) < 3:
77-
raise ValueError(
78-
"Message must contain at least 3 characters after sanitization"
79-
)
80-
81-
# Check length after sanitization
82+
# Check length after sanitization (resource protection)
8283
if len(v) > StreamConfig.MAX_MESSAGE_LENGTH:
8384
raise ValueError(
8485
f"Message exceeds maximum length of {StreamConfig.MAX_MESSAGE_LENGTH} characters"

src/utils/query_validator.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Basic query validation for empty/meaningless inputs.
2+
3+
This module provides lightweight, rule-based validation to reject syntactically
4+
invalid queries before they reach expensive LLM-based processing stages.
5+
6+
Validation checks (all syntactic, NO semantic):
7+
- Empty or whitespace-only messages
8+
- Messages containing only special characters/punctuation (including unicode)
9+
- Messages with too few meaningful characters (< 2)
10+
- Messages with only repetitive characters (e.g., "aaaa", "????")
11+
- Emoji-only messages
12+
13+
Out of scope for this module:
14+
- Semantic validation (greetings, chitchat, intent detection)
15+
- Language quality checks
16+
- Content policy checks (handled by guardrails)
17+
18+
Design decisions:
19+
- Numbers are considered valid (e.g., "123" passes validation)
20+
- Mixed alphanumeric with punctuation is valid (e.g., "ab!" passes)
21+
- Unicode punctuation is treated same as ASCII punctuation
22+
- Emojis are not considered meaningful characters
23+
"""
24+
25+
import re
26+
from typing import Optional
27+
from pydantic import BaseModel
28+
29+
30+
class QueryValidationResult(BaseModel):
31+
"""Result of basic query validation.
32+
33+
Attributes:
34+
is_valid: True if query passes all validation checks
35+
rejection_reason: Optional reason code if validation fails
36+
(empty, special_chars_only, too_short, repetitive)
37+
"""
38+
39+
is_valid: bool
40+
rejection_reason: Optional[str] = None
41+
42+
43+
def validate_query_basic(query: str) -> QueryValidationResult:
44+
"""
45+
Validate query for basic syntactic issues (NOT semantic).
46+
47+
This is a fast, rule-based check that runs before expensive operations
48+
like guardrails or prompt refinement. It only catches obvious syntactic
49+
issues, not semantic problems.
50+
51+
Args:
52+
query: User's input message to validate
53+
54+
Returns:
55+
QueryValidationResult with is_valid flag and optional rejection_reason
56+
57+
Examples:
58+
Valid queries:
59+
>>> validate_query_basic("How to apply for benefits?")
60+
QueryValidationResult(is_valid=True, rejection_reason=None)
61+
>>> validate_query_basic("hi")
62+
QueryValidationResult(is_valid=True, rejection_reason=None)
63+
>>> validate_query_basic("123")
64+
QueryValidationResult(is_valid=True, rejection_reason=None)
65+
>>> validate_query_basic("ab!")
66+
QueryValidationResult(is_valid=True, rejection_reason=None)
67+
68+
Invalid queries:
69+
>>> validate_query_basic("...")
70+
QueryValidationResult(is_valid=False, rejection_reason='special_chars_only')
71+
>>> validate_query_basic("")
72+
QueryValidationResult(is_valid=False, rejection_reason='empty')
73+
>>> validate_query_basic("????")
74+
QueryValidationResult(is_valid=False, rejection_reason='repetitive')
75+
>>> validate_query_basic("a")
76+
QueryValidationResult(is_valid=False, rejection_reason='too_short')
77+
>>> validate_query_basic("😀😀😀")
78+
QueryValidationResult(is_valid=False, rejection_reason='special_chars_only')
79+
"""
80+
# Trim whitespace
81+
query = query.strip()
82+
83+
# Check 1: Empty query
84+
if not query:
85+
return QueryValidationResult(is_valid=False, rejection_reason="empty")
86+
87+
# Check 2: Only special characters/punctuation (including unicode and emojis)
88+
# Remove all alphanumeric characters (letters and numbers in any language)
89+
# If nothing remains or only punctuation/symbols/emojis, reject
90+
alphanumeric_pattern = re.compile(r"[\w]", re.UNICODE)
91+
has_alphanumeric = bool(alphanumeric_pattern.search(query))
92+
93+
if not has_alphanumeric:
94+
# No letters or numbers found - only punctuation/symbols/emojis
95+
return QueryValidationResult(
96+
is_valid=False, rejection_reason="special_chars_only"
97+
)
98+
99+
# Check 3: Too short (< 2 meaningful characters)
100+
# Extract alphanumeric characters (letters + numbers, unicode-aware)
101+
meaningful_chars = alphanumeric_pattern.findall(query)
102+
if len(meaningful_chars) < 2:
103+
return QueryValidationResult(is_valid=False, rejection_reason="too_short")
104+
105+
# Check 4: Only repetitive characters (e.g., "aaaa", "????", "111")
106+
# If all meaningful characters are the same (case-insensitive), likely spam
107+
unique_chars = {c.lower() for c in meaningful_chars}
108+
if len(unique_chars) == 1:
109+
return QueryValidationResult(is_valid=False, rejection_reason="repetitive")
110+
111+
# Passed all checks - query is syntactically valid
112+
return QueryValidationResult(is_valid=True)

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Pytest configuration for test discovery and imports."""
2+
3+
import sys
4+
from pathlib import Path
5+
6+
# Add the project root to Python path so tests can import from src
7+
project_root = Path(__file__).parent.parent
8+
sys.path.insert(0, str(project_root))

0 commit comments

Comments
 (0)