diff --git a/examples/chat/chat_history_analytics.py b/examples/chat/chat_history_analytics.py new file mode 100644 index 000000000..cc9cb57be --- /dev/null +++ b/examples/chat/chat_history_analytics.py @@ -0,0 +1,617 @@ +""" +Ragbits Chat Example: Advanced History Persistence with Analytics + +This example demonstrates advanced usage of SQLHistoryPersistence including: + +- Querying conversation history with custom filters +- Analyzing conversation patterns and metrics +- Exporting conversation data +- Managing conversation lifecycle (archiving, deletion) +- Working with PostgreSQL for production use cases + +To run the script, execute the following command: + + ```bash + uv run python examples/chat/chat_history_analytics.py + ``` + +Requirements: + - aiosqlite (for async SQLite support) + - ragbits-chat +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-chat", +# "aiosqlite>=0.21.0", +# "greenlet>=3.0.0", +# ] +# /// + +import asyncio +import json +from collections.abc import AsyncGenerator +from datetime import datetime, timedelta +from typing import Any + +import sqlalchemy +from sqlalchemy import and_, desc, func +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +from ragbits.chat.interface import ChatInterface +from ragbits.chat.interface.types import ChatContext, ChatResponse +from ragbits.chat.persistence.sql import SQLHistoryPersistence, SQLHistoryPersistenceOptions +from ragbits.core.prompt import ChatFormat + + +async def mock_llm_response(message: str) -> str: + """Mock LLM response for demonstration purposes.""" + responses = { + "What is Python?": ( + "Python is a high-level, interpreted programming language known for its simplicity and " + "readability. It's widely used in web development, data science, AI, and automation." + ), + "How do I install packages?": ( + "You can install Python packages using pip, the package installer for Python. 
Simply run " + "'pip install package-name' in your terminal." + ), + "Tell me about virtual environments": ( + "Virtual environments are isolated Python environments that allow you to manage project " + "dependencies separately. Use 'python -m venv myenv' to create one." + ), + "What is machine learning?": ( + "Machine learning is a subset of AI that enables systems to learn and improve from " + "experience without being explicitly programmed. It uses algorithms to find patterns in data." + ), + "Explain neural networks": ( + "Neural networks are computing systems inspired by biological neural networks. They consist " + "of layers of interconnected nodes that process information to learn patterns and make " + "predictions." + ), + "What is the capital of France?": ( + "The capital of France is Paris, known for its art, culture, and iconic landmarks like the " "Eiffel Tower." + ), + "Tell me about Paris": ( + "Paris is the capital of France, famous for its museums, architecture, cuisine, and " + "landmarks like the Louvre, Notre-Dame, and the Eiffel Tower." + ), + "How does async/await work in Python?": ( + "Async/await in Python allows you to write asynchronous code that can handle multiple tasks " + "concurrently. 'async def' defines a coroutine, and 'await' pauses execution until the " + "awaited task completes." + ), + "What are coroutines?": ( + "Coroutines are special functions in Python that can pause and resume their execution. " + "They're defined with 'async def' and are the building blocks of asynchronous programming." + ), + "Explain REST APIs": ( + "REST (Representational State Transfer) APIs are web services that use HTTP methods " + "(GET, POST, PUT, DELETE) to perform operations on resources. They're stateless and use " + "standard HTTP protocols." + ), + "What is GraphQL?": ( + "GraphQL is a query language for APIs that allows clients to request exactly the data they " + "need. Unlike REST, it uses a single endpoint and a flexible query syntax." 
+ ), + "Compare REST and GraphQL": ( + "REST uses multiple endpoints and fixed data structures, while GraphQL uses a single " + "endpoint and allows clients to specify exactly what data they need. GraphQL reduces " + "over-fetching but adds complexity." + ), + } + return responses.get(message, f"This is a mock response to: {message}") + + +class AnalyticsHistoryPersistence(SQLHistoryPersistence): + """ + Extended SQLHistoryPersistence with analytics and management features. + + This class adds advanced querying, analytics, and management capabilities + on top of the base SQLHistoryPersistence. + """ + + async def get_conversation_count(self) -> int: + """Get the total number of conversations.""" + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + result = await session.execute(sqlalchemy.select(func.count()).select_from(self.Conversation)) + return result.scalar() or 0 + + async def get_total_interactions_count(self) -> int: + """Get the total number of chat interactions across all conversations.""" + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + result = await session.execute(sqlalchemy.select(func.count()).select_from(self.ChatInteraction)) + return result.scalar() or 0 + + async def get_recent_conversations(self, limit: int = 10) -> list[dict[str, Any]]: + """ + Get the most recent conversations. 
+ + Args: + limit: Maximum number of conversations to retrieve + + Returns: + List of conversation dictionaries with metadata + """ + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + result = await session.execute( + sqlalchemy.select(self.Conversation).order_by(desc(self.Conversation.created_at)).limit(limit) + ) + conversations = result.scalars().all() + + conversation_data = [] + for conv in conversations: + # Get interaction count for this conversation + interaction_result = await session.execute( + sqlalchemy.select(func.count()) + .select_from(self.ChatInteraction) + .where(self.ChatInteraction.conversation_id == conv.id) + ) + interaction_count = interaction_result.scalar() or 0 + + conversation_data.append( + { + "id": conv.id, + "created_at": conv.created_at, + "interaction_count": interaction_count, + } + ) + + return conversation_data + + async def search_interactions( + self, + query: str, + search_in_messages: bool = True, + search_in_responses: bool = True, + limit: int = 20, + ) -> list[dict[str, Any]]: + """ + Search for interactions containing specific text. 
+ + Args: + query: Text to search for + search_in_messages: Whether to search in user messages + search_in_responses: Whether to search in assistant responses + limit: Maximum number of results + + Returns: + List of matching interactions + """ + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + filters = [] + if search_in_messages: + filters.append(self.ChatInteraction.message.contains(query)) + if search_in_responses: + filters.append(self.ChatInteraction.response.contains(query)) + + if not filters: + return [] + + result = await session.execute( + sqlalchemy.select(self.ChatInteraction) + .where(sqlalchemy.or_(*filters)) + .order_by(desc(self.ChatInteraction.timestamp)) + .limit(limit) + ) + interactions = result.scalars().all() + + return [ + { + "id": interaction.id, + "conversation_id": interaction.conversation_id, + "message_id": interaction.message_id, + "message": interaction.message, + "response": interaction.response, + "timestamp": interaction.timestamp, + } + for interaction in interactions + ] + + async def get_interactions_by_date_range( + self, + start_timestamp: float, + end_timestamp: float, + conversation_id: str | None = None, + ) -> list[dict[str, Any]]: + """ + Get interactions within a specific time range. 
+ + Args: + start_timestamp: Start of the time range (Unix timestamp) + end_timestamp: End of the time range (Unix timestamp) + conversation_id: Optional conversation ID to filter by + + Returns: + List of interactions in the time range + """ + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + query = sqlalchemy.select(self.ChatInteraction).where( + and_( + self.ChatInteraction.timestamp >= start_timestamp, + self.ChatInteraction.timestamp <= end_timestamp, + ) + ) + + if conversation_id: + query = query.where(self.ChatInteraction.conversation_id == conversation_id) + + query = query.order_by(self.ChatInteraction.timestamp) + + result = await session.execute(query) + interactions = result.scalars().all() + + return [ + { + "id": interaction.id, + "conversation_id": interaction.conversation_id, + "message_id": interaction.message_id, + "message": interaction.message, + "response": interaction.response, + "timestamp": interaction.timestamp, + "created_at": interaction.created_at, + } + for interaction in interactions + ] + + async def export_conversation( + self, + conversation_id: str, + include_metadata: bool = True, + ) -> dict[str, Any]: + """ + Export a complete conversation with all metadata. 
+ + Args: + conversation_id: The conversation to export + include_metadata: Whether to include extra metadata + + Returns: + Dictionary containing the complete conversation data + """ + interactions = await self.get_conversation_interactions(conversation_id) + + export_data = { + "conversation_id": conversation_id, + "export_timestamp": datetime.now().isoformat(), + "interaction_count": len(interactions), + "interactions": interactions + if include_metadata + else [ + { + "message": i["message"], + "response": i["response"], + "timestamp": i["timestamp"], + } + for i in interactions + ], + } + + return export_data + + async def delete_conversation(self, conversation_id: str) -> bool: + """ + Delete a conversation and all its interactions. + + Args: + conversation_id: The conversation to delete + + Returns: + True if the conversation was deleted, False if it didn't exist + """ + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session, session.begin(): + # Check if conversation exists + result = await session.execute(sqlalchemy.select(self.Conversation).filter_by(id=conversation_id)) + conversation = result.scalar_one_or_none() + + if not conversation: + return False + + # Delete the conversation (interactions will be cascade deleted) + await session.delete(conversation) + await session.commit() + return True + + async def get_conversation_statistics(self) -> dict[str, Any]: + """ + Get overall statistics about stored conversations. 
+ + Returns: + Dictionary containing various statistics + """ + await self._init_db() + + async with AsyncSession(self.sqlalchemy_engine) as session: + # Total counts + conversation_count = await self.get_conversation_count() + interaction_count = await self.get_total_interactions_count() + + # Average interactions per conversation + avg_interactions = interaction_count / conversation_count if conversation_count > 0 else 0 + + # Get timestamp range + timestamp_result = await session.execute( + sqlalchemy.select( + func.min(self.ChatInteraction.timestamp), + func.max(self.ChatInteraction.timestamp), + ) + ) + min_ts, max_ts = timestamp_result.one() + + # Calculate message length statistics + message_lengths_result = await session.execute( + sqlalchemy.select( + func.avg(func.length(self.ChatInteraction.message)), + func.avg(func.length(self.ChatInteraction.response)), + ) + ) + avg_message_length, avg_response_length = message_lengths_result.one() + + return { + "total_conversations": conversation_count, + "total_interactions": interaction_count, + "avg_interactions_per_conversation": round(avg_interactions, 2), + "first_interaction": datetime.fromtimestamp(min_ts).isoformat() if min_ts else None, + "last_interaction": datetime.fromtimestamp(max_ts).isoformat() if max_ts else None, + "avg_message_length": round(avg_message_length or 0, 2), + "avg_response_length": round(avg_response_length or 0, 2), + } + + +class ChatWithAnalytics(ChatInterface): + """Simple chat interface for demonstrating analytics.""" + + conversation_history = True + + def __init__(self, history_persistence: AnalyticsHistoryPersistence) -> None: + self.history_persistence = history_persistence + + async def chat( + self, + message: str, + history: ChatFormat, + context: ChatContext, + ) -> AsyncGenerator[ChatResponse, None]: + """Generate responses using the mock LLM.""" + # Generate mock response + response = await mock_llm_response(message) + + # Simulate streaming by yielding the response in 
chunks + chunk_size = 15 + for i in range(0, len(response), chunk_size): + chunk = response[i : i + chunk_size] + yield self.create_text_response(chunk) + await asyncio.sleep(0.03) # Simulate streaming delay + + +async def create_sample_conversations( + chat: ChatWithAnalytics, + num_conversations: int = 5, +) -> list[str]: + """ + Create sample conversations for demonstration. + + Args: + chat: The chat interface to use + num_conversations: Number of conversations to create + + Returns: + List of conversation IDs + """ + print("Creating sample conversations...") + print("-" * 80) + + conversation_ids = [] + sample_questions = [ + ["What is Python?", "How do I install packages?", "Tell me about virtual environments"], + ["What is machine learning?", "Explain neural networks"], + ["What is the capital of France?", "Tell me about Paris"], + ["How does async/await work in Python?", "What are coroutines?"], + ["Explain REST APIs", "What is GraphQL?", "Compare REST and GraphQL"], + ] + + for i in range(num_conversations): + context = ChatContext() + history: ChatFormat = [] + + questions = sample_questions[i % len(sample_questions)] + + for question in questions: + response_text = "" + async for response in chat.chat(question, history=history, context=context): + if text := response.as_text(): + response_text += text + elif (conv_id := response.as_conversation_id()) and conv_id not in conversation_ids: + conversation_ids.append(conv_id) + + history.append({"role": "user", "content": question}) + history.append({"role": "assistant", "content": response_text}) + + # Small delay between messages + await asyncio.sleep(0.1) + + print(f" Created conversation {i + 1}/{num_conversations}") + + print(f"✓ Created {len(conversation_ids)} conversations") + print() + return conversation_ids + + +async def _setup_database() -> tuple[Any, AnalyticsHistoryPersistence, ChatWithAnalytics, str]: + """Setup database and chat components.""" + database_url = 
"sqlite+aiosqlite:///./chat_analytics.db" + engine = create_async_engine(database_url, echo=False) + + persistence = AnalyticsHistoryPersistence( + sqlalchemy_engine=engine, + options=SQLHistoryPersistenceOptions( + conversations_table="analytics_conversations", + interactions_table="analytics_interactions", + ), + ) + + chat = ChatWithAnalytics(history_persistence=persistence) + return engine, persistence, chat, database_url + + +async def _demonstrate_overall_statistics(persistence: AnalyticsHistoryPersistence) -> None: + """Display overall statistics.""" + print("Overall Statistics") + print("-" * 80) + stats = await persistence.get_conversation_statistics() + + print(f"Total Conversations: {stats['total_conversations']}") + print(f"Total Interactions: {stats['total_interactions']}") + print(f"Avg Interactions per Conversation: {stats['avg_interactions_per_conversation']}") + print(f"First Interaction: {stats['first_interaction']}") + print(f"Last Interaction: {stats['last_interaction']}") + print(f"Avg Message Length: {stats['avg_message_length']} characters") + print(f"Avg Response Length: {stats['avg_response_length']} characters") + print() + + +async def _demonstrate_recent_conversations(persistence: AnalyticsHistoryPersistence) -> None: + """Display recent conversations.""" + print("Recent Conversations") + print("-" * 80) + recent = await persistence.get_recent_conversations(limit=3) + + for i, conv in enumerate(recent, 1): + print(f"{i}. 
Conversation ID: {conv['id'][:8]}...") + print(f" Created: {conv['created_at']}") + print(f" Interactions: {conv['interaction_count']}") + print() + + +async def _demonstrate_search(persistence: AnalyticsHistoryPersistence) -> None: + """Demonstrate search functionality.""" + print("Search Example: Finding interactions about 'Python'") + print("-" * 80) + search_results = await persistence.search_interactions( + query="Python", + search_in_messages=True, + search_in_responses=True, + limit=3, + ) + + for i, result in enumerate(search_results, 1): + print(f"{i}. Message: {result['message'][:60]}...") + print(f" Response: {result['response'][:60]}...") + print(f" Conversation: {result['conversation_id'][:8]}...") + print() + + +async def _demonstrate_date_range(persistence: AnalyticsHistoryPersistence) -> None: + """Demonstrate date range query.""" + print("Date Range Example: Interactions from last hour") + print("-" * 80) + now = datetime.now().timestamp() + one_hour_ago = (datetime.now() - timedelta(hours=1)).timestamp() + + recent_interactions = await persistence.get_interactions_by_date_range( + start_timestamp=one_hour_ago, + end_timestamp=now, + ) + print(f"Found {len(recent_interactions)} interactions in the last hour") + print() + + +async def _demonstrate_export(persistence: AnalyticsHistoryPersistence, conversation_ids: list[str]) -> None: + """Demonstrate conversation export.""" + print("Export Example: Exporting a conversation") + print("-" * 80) + if conversation_ids: + export_data = await persistence.export_conversation( + conversation_ids[0], + include_metadata=True, + ) + + # Save to file + export_filename = "conversation_export.json" + with open(export_filename, "w") as f: + json.dump(export_data, f, indent=2, default=str) + + print(f"✓ Exported conversation to {export_filename}") + print(f" Conversation ID: {export_data['conversation_id'][:8]}...") + print(f" Interactions: {export_data['interaction_count']}") + print() + + +async def 
_demonstrate_deletion(persistence: AnalyticsHistoryPersistence, conversation_ids: list[str]) -> None: + """Demonstrate conversation deletion.""" + print("Management Example: Deleting a conversation") + print("-" * 80) + if len(conversation_ids) > 1: + conversation_to_delete = conversation_ids[-1] + print(f"Deleting conversation: {conversation_to_delete[:8]}...") + + deleted = await persistence.delete_conversation(conversation_to_delete) + if deleted: + print("✓ Conversation deleted successfully") + + # Verify deletion + new_count = await persistence.get_conversation_count() + print(f" Remaining conversations: {new_count}") + else: + print("✗ Conversation not found") + print() + + +def _print_summary(database_url: str) -> None: + """Print final summary.""" + print("=" * 80) + print("Analytics Features Demonstrated:") + print("=" * 80) + print("✓ Overall statistics and metrics") + print("✓ Recent conversations listing") + print("✓ Full-text search across interactions") + print("✓ Date range queries") + print("✓ Conversation export to JSON") + print("✓ Conversation deletion and management") + print() + print(f"Database: {database_url}") + print() + + +async def demonstrate_analytics() -> None: + """Demonstrate the analytics features.""" + print("=" * 80) + print("Chat History Analytics Example") + print("=" * 80) + print() + + # Setup database + engine, persistence, chat, database_url = await _setup_database() + + # Create sample data + conversation_ids = await create_sample_conversations(chat, num_conversations=5) + + # Demonstrate various features + await _demonstrate_overall_statistics(persistence) + await _demonstrate_recent_conversations(persistence) + await _demonstrate_search(persistence) + await _demonstrate_date_range(persistence) + await _demonstrate_export(persistence, conversation_ids) + await _demonstrate_deletion(persistence, conversation_ids) + + # Print summary + _print_summary(database_url) + + # Cleanup + await engine.dispose() + + +if __name__ == 
"__main__": + asyncio.run(demonstrate_analytics()) diff --git a/examples/chat/chat_with_history_persistence.py b/examples/chat/chat_with_history_persistence.py new file mode 100644 index 000000000..c5a96e62f --- /dev/null +++ b/examples/chat/chat_with_history_persistence.py @@ -0,0 +1,394 @@ +""" +Ragbits Chat Example: Chat Interface with SQL History Persistence + +This example demonstrates how to use the `ChatInterface` with `SQLHistoryPersistence` +to persist chat interactions to a SQL database. It showcases: + +- Setting up SQLHistoryPersistence with SQLite (aiosqlite) +- Saving chat interactions automatically through the ChatInterface +- Retrieving conversation history from the database +- Resuming conversations using persisted history + +To run the script, execute the following command: + + ```bash + uv run python examples/chat/chat_with_history_persistence.py + ``` + +Requirements: + - aiosqlite (for async SQLite support) + - ragbits-chat +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-chat", +# "aiosqlite>=0.21.0", +# "greenlet>=3.0.0", +# ] +# /// + +import asyncio +from collections.abc import AsyncGenerator + +from sqlalchemy.ext.asyncio import create_async_engine + +from ragbits.chat.interface import ChatInterface +from ragbits.chat.interface.forms import FeedbackConfig +from ragbits.chat.interface.types import ChatContext, ChatResponse, ChatResponseType +from ragbits.chat.persistence.sql import SQLHistoryPersistence, SQLHistoryPersistenceOptions +from ragbits.core.prompt import ChatFormat + + +async def mock_llm_response(message: str) -> str: + """ + Mock LLM response for demonstration purposes (avoids needing API keys). + + In a real application, replace this with actual LLM calls. + """ + responses = { + "What is the capital of France?": ( + "The capital of France is Paris. Paris is known for its iconic landmarks like the Eiffel " + "Tower, the Louvre Museum, and Notre-Dame Cathedral." 
+ ), + "What about Germany?": ( + "The capital of Germany is Berlin. Berlin is a vibrant city known for its history, " + "culture, and nightlife. It was divided during the Cold War but is now reunited." + ), + "And what about Italy?": ( + "The capital of Italy is Rome. Rome is an ancient city with a rich history spanning over " + "2,500 years. It's home to the Colosseum, Vatican City, and the Trevi Fountain." + ), + "Tell me a fact about number 1": ( + "The number 1 is the first and smallest positive integer. It is the multiplicative " + "identity, meaning any number multiplied by 1 equals itself." + ), + "Tell me a fact about number 2": ( + "The number 2 is the smallest and only even prime number. It is also the base of the " + "binary numeral system used in computing." + ), + } + return responses.get(message, f"This is a mock response to: {message}") + + +class SimpleChatWithPersistence(ChatInterface): + """ + A simple chat interface that demonstrates SQL history persistence. + + This interface automatically saves all chat interactions to a SQLite database, + including messages, responses, and metadata like conversation IDs. + """ + + feedback_config = FeedbackConfig( + like_enabled=True, + dislike_enabled=True, + ) + + # Enable conversation history to show previous messages + conversation_history = True + show_usage = True + + def __init__(self, history_persistence: SQLHistoryPersistence) -> None: + """ + Initialize the chat interface with history persistence. + + Args: + history_persistence: The SQLHistoryPersistence instance to use for storing interactions + """ + self.history_persistence = history_persistence + + async def chat( + self, + message: str, + history: ChatFormat, + context: ChatContext, + ) -> AsyncGenerator[ChatResponse, None]: + """ + Process a chat message and yield responses. + + All interactions are automatically saved to the database via the + @with_chat_metadata decorator. 
+ + Args: + message: The current user message + history: List of previous messages in the conversation + context: Context containing conversation metadata + + Yields: + ChatResponse objects containing text chunks and usage information + """ + # Add a reference to show the conversation ID being used + yield self.create_reference( + title="Conversation Info", + content=f"Conversation ID: {context.conversation_id}\nMessage ID: {context.message_id}", + url=None, + ) + + # Generate mock response (in real usage, use an actual LLM) + response = await mock_llm_response(message) + + # Simulate streaming by yielding the response in chunks + chunk_size = 10 + for i in range(0, len(response), chunk_size): + chunk = response[i : i + chunk_size] + yield self.create_text_response(chunk) + await asyncio.sleep(0.05) # Simulate streaming delay + + +async def _setup_persistence() -> tuple: + """Setup database and persistence components.""" + database_url = "sqlite+aiosqlite:///./chat_history.db" + engine = create_async_engine(database_url, echo=False) + + persistence_options = SQLHistoryPersistenceOptions( + conversations_table="my_conversations", + interactions_table="my_chat_interactions", + ) + + persistence = SQLHistoryPersistence( + sqlalchemy_engine=engine, + options=persistence_options, + ) + + chat = SimpleChatWithPersistence(history_persistence=persistence) + return engine, persistence, persistence_options, chat, database_url + + +async def _run_first_message(chat: SimpleChatWithPersistence) -> tuple: + """Execute first message in a new conversation.""" + conversation_id = None + message_ids = [] + history: ChatFormat = [] + context = ChatContext() + + user_message_1 = "What is the capital of France?" 
+ print(f"User: {user_message_1}") + print("Assistant: ", end="", flush=True) + + response_text_1 = "" + async for response in chat.chat(user_message_1, history=history, context=context): + if text := response.as_text(): + print(text, end="", flush=True) + response_text_1 += text + elif conv_id := response.as_conversation_id(): + conversation_id = conv_id + elif response.type == ChatResponseType.MESSAGE_ID: + message_ids.append(str(response.content)) + + print() + print() + + history.append({"role": "user", "content": user_message_1}) + history.append({"role": "assistant", "content": response_text_1}) + return conversation_id, message_ids, history, context + + +async def _run_second_message( + chat: SimpleChatWithPersistence, + history: ChatFormat, + context: ChatContext, + message_ids: list, +) -> ChatFormat: + """Execute second message in the conversation.""" + user_message_2 = "What about Germany?" + print(f"User: {user_message_2}") + print("Assistant: ", end="", flush=True) + + response_text_2 = "" + async for response in chat.chat(user_message_2, history=history, context=context): + if text := response.as_text(): + print(text, end="", flush=True) + response_text_2 += text + elif response.type == ChatResponseType.MESSAGE_ID: + message_ids.append(str(response.content)) + + print() + print() + + history.append({"role": "user", "content": user_message_2}) + history.append({"role": "assistant", "content": response_text_2}) + return history + + +async def _retrieve_and_display_history(persistence: SQLHistoryPersistence, conversation_id: str | None) -> None: + """Retrieve and display conversation history from database.""" + print("Part 2: Retrieving conversation history from database...") + print("-" * 80) + + if conversation_id: + interactions = await persistence.get_conversation_interactions(conversation_id) + print(f"Found {len(interactions)} interactions in the database:") + print() + + for i, interaction in enumerate(interactions, 1): + print(f"Interaction 
#{i}") + print(f" Message ID: {interaction['message_id']}") + print(f" Timestamp: {interaction['timestamp']}") + print(f" User: {interaction['message'][:80]}...") + print(f" Assistant: {interaction['response'][:80]}...") + print(f" Extra responses: {len(interaction['extra_responses'])} items") + print() + + +async def _resume_conversation( + chat: SimpleChatWithPersistence, + persistence: SQLHistoryPersistence, + conversation_id: str | None, +) -> None: + """Resume conversation with loaded history.""" + print("Part 3: Resuming conversation with loaded history...") + print("-" * 80) + + if conversation_id: + loaded_interactions = await persistence.get_conversation_interactions(conversation_id) + + reconstructed_history: ChatFormat = [] + for interaction in loaded_interactions: + reconstructed_history.append({"role": "user", "content": interaction["message"]}) + reconstructed_history.append({"role": "assistant", "content": interaction["response"]}) + + resume_context = ChatContext(conversation_id=conversation_id) + + user_message_3 = "And what about Italy?" 
+ print(f"User: {user_message_3}") + print("Assistant: ", end="", flush=True) + + async for response in chat.chat(user_message_3, history=reconstructed_history, context=resume_context): + if text := response.as_text(): + print(text, end="", flush=True) + + print() + print() + + updated_interactions = await persistence.get_conversation_interactions(conversation_id) + print(f"Total interactions after resume: {len(updated_interactions)}") + print() + + +def _print_summary( + conversation_id: str | None, + message_ids: list, + database_url: str, + persistence_options: SQLHistoryPersistenceOptions, +) -> None: + """Print summary of the example.""" + print("=" * 80) + print("Summary") + print("=" * 80) + print(f"✓ Created conversation with ID: {conversation_id}") + print(f"✓ Saved {len(message_ids)} messages to the database") + print("✓ Successfully retrieved conversation history") + print("✓ Resumed conversation and added new message") + print() + print(f"Database location: {database_url}") + print("Tables created:") + print(f" - {persistence_options.conversations_table}") + print(f" - {persistence_options.interactions_table}") + print() + + +async def run_conversation_example() -> None: + """ + Demonstrates a complete conversation lifecycle with SQL history persistence. + + This example shows: + 1. Creating a new conversation + 2. Saving multiple interactions + 3. Retrieving conversation history from the database + 4. 
Resuming a conversation with loaded history + """ + engine, persistence, persistence_options, chat, database_url = await _setup_persistence() + + print("=" * 80) + print("Chat Example with SQL History Persistence") + print("=" * 80) + print() + + print("Part 1: Starting a new conversation...") + print("-" * 80) + + conversation_id, message_ids, history, context = await _run_first_message(chat) + history = await _run_second_message(chat, history, context, message_ids) + + print(f"Conversation ID: {conversation_id}") + print(f"Messages saved: {len(message_ids)}") + print() + + await _retrieve_and_display_history(persistence, conversation_id) + await _resume_conversation(chat, persistence, conversation_id) + + _print_summary(conversation_id, message_ids, database_url, persistence_options) + + await engine.dispose() + + +async def run_multi_conversation_example() -> None: + """ + Demonstrates managing multiple conversations in the same database. + + This example shows how different conversations are isolated from each other. 
+ """ + print("\n") + print("=" * 80) + print("Multi-Conversation Example") + print("=" * 80) + print() + + # Setup database + database_url = "sqlite+aiosqlite:///./chat_history.db" + engine = create_async_engine(database_url, echo=False) + + persistence = SQLHistoryPersistence( + sqlalchemy_engine=engine, + options=SQLHistoryPersistenceOptions( + conversations_table="my_conversations", + interactions_table="my_chat_interactions", + ), + ) + + chat = SimpleChatWithPersistence(history_persistence=persistence) + + # Start two different conversations + conversations = [] + + for i in range(1, 3): + print(f"Starting conversation #{i}...") + context = ChatContext() + history: ChatFormat = [] + + message = f"Tell me a fact about number {i}" + print(f"User: {message}") + print("Assistant: ", end="", flush=True) + + conversation_id = None + async for response in chat.chat(message, history=history, context=context): + if text := response.as_text(): + print(text, end="", flush=True) + elif conv_id := response.as_conversation_id(): + conversation_id = conv_id + + print() + print(f"Conversation ID: {conversation_id}") + print() + + conversations.append(conversation_id) + + # Verify each conversation has its own history + print("Verifying isolated conversation histories...") + print("-" * 80) + for i, conv_id in enumerate(conversations, 1): + if conv_id: + interactions = await persistence.get_conversation_interactions(conv_id) + print(f"Conversation {i} ({conv_id}): {len(interactions)} interactions") + + print() + await engine.dispose() + + +if __name__ == "__main__": + # Run the main example + asyncio.run(run_conversation_example()) + + # Run the multi-conversation example + asyncio.run(run_multi_conversation_example()) diff --git a/examples/chat/chat_with_postgresql.py b/examples/chat/chat_with_postgresql.py new file mode 100644 index 000000000..c629b72e1 --- /dev/null +++ b/examples/chat/chat_with_postgresql.py @@ -0,0 +1,477 @@ +""" +Ragbits Chat Example: Production 
PostgreSQL Setup

This example demonstrates production-ready setup of SQLHistoryPersistence with PostgreSQL,
including:

- Environment-based configuration
- Connection pooling and optimization
- Error handling and retry logic
- Schema migration patterns
- Production best practices

Prerequisites:
    - PostgreSQL server running locally or remotely
    - asyncpg driver installed

Setup PostgreSQL (using Docker):
    ```bash
    docker run --name ragbits-postgres \
        -e POSTGRES_PASSWORD=ragbits2024 \
        -e POSTGRES_USER=ragbits \
        -e POSTGRES_DB=chatdb \
        -p 5432:5432 \
        -d postgres:16
    ```

Run the example:
    ```bash
    # Set the database URL
    export DATABASE_URL="postgresql+asyncpg://ragbits:ragbits2024@localhost:5432/chatdb"

    # Run the script
    uv run python examples/chat/chat_with_postgresql.py
    ```

Requirements:
    - asyncpg (for PostgreSQL async support)
    - ragbits-chat
"""

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "ragbits-chat",
#     "asyncpg>=0.30.0",
#     "greenlet>=3.0.0",
# ]
# ///

import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from ragbits.chat.interface import ChatInterface
from ragbits.chat.interface.types import ChatContext, ChatResponse
from ragbits.chat.persistence.sql import SQLHistoryPersistence, SQLHistoryPersistenceOptions
from ragbits.core.prompt import ChatFormat


async def mock_llm_response(message: str) -> str:
    """Mock LLM response for demonstration purposes."""
    # Canned answers keyed by the exact user message; any other message
    # falls through to the generic echo response below.
    responses = {
        "What are the best practices for database connection pooling?": (
            "Database connection pooling best practices include: 1) Set appropriate pool size based "
            "on your application's concurrency needs, 2) Configure timeout values to prevent resource "
            "exhaustion, 3) Use connection validation (pool_pre_ping), 4) Recycle connections "
            "periodically, and 5) Monitor pool statistics to optimize settings."
        ),
        "How does PostgreSQL handle concurrent connections?": (
            "PostgreSQL handles concurrent connections through a process-based architecture. Each "
            "connection spawns a new backend process, which provides strong isolation. PostgreSQL uses "
            "MVCC (Multi-Version Concurrency Control) to manage concurrent access to data without "
            "locking, allowing high throughput for read operations."
        ),
        "Tell me about ACID properties in PostgreSQL": (
            "PostgreSQL fully supports ACID properties: Atomicity ensures transactions are "
            "all-or-nothing, Consistency maintains database rules, Isolation prevents transaction "
            "interference, and Durability guarantees committed data persists. PostgreSQL uses WAL "
            "(Write-Ahead Logging) for durability and offers multiple isolation levels."
        ),
    }
    return responses.get(message, f"This is a mock response about: {message}")


class ProductionConfig:
    """Configuration for production database setup.

    All values are read from environment variables at construction time,
    with local-development defaults matching the Docker setup above.
    """

    def __init__(self) -> None:
        """Initialize configuration from environment variables."""
        self.database_url = os.getenv("DATABASE_URL", "postgresql+asyncpg://ragbits:ragbits2024@localhost:5432/chatdb")

        # Connection pool settings
        self.pool_size = int(os.getenv("DB_POOL_SIZE", "20"))
        self.max_overflow = int(os.getenv("DB_MAX_OVERFLOW", "10"))
        self.pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30"))
        self.pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))

        # Retry settings (used by create_persistence's init-retry loop)
        self.max_retries = int(os.getenv("DB_MAX_RETRIES", "3"))
        self.retry_delay = float(os.getenv("DB_RETRY_DELAY", "1.0"))

        # Table names
        self.conversations_table = os.getenv("CONVERSATIONS_TABLE", "production_conversations")
        self.interactions_table = os.getenv("INTERACTIONS_TABLE", "production_interactions")


class ProductionChatInterface(ChatInterface):
    """Production-ready chat interface with PostgreSQL persistence."""
    # Flags consumed by the ChatInterface base class configuration.
    conversation_history = True
    show_usage = True

    def __init__(self, history_persistence: SQLHistoryPersistence) -> None:
        """
        Initialize the production chat interface.

        Args:
            history_persistence: The SQLHistoryPersistence instance
        """
        self.history_persistence = history_persistence

    async def chat(
        self,
        message: str,
        history: ChatFormat,
        context: ChatContext,
    ) -> AsyncGenerator[ChatResponse, None]:
        """
        Process a chat message with production-grade error handling.

        Args:
            message: The current user message
            history: List of previous messages in the conversation
            context: Context containing conversation metadata

        Yields:
            ChatResponse objects
        """
        try:
            # Generate mock response
            response = await mock_llm_response(message)

            # Simulate streaming by yielding the response in chunks
            chunk_size = 20
            for i in range(0, len(response), chunk_size):
                chunk = response[i : i + chunk_size]
                yield self.create_text_response(chunk)
                await asyncio.sleep(0.02)  # Simulate streaming delay

        except Exception as e:
            # In production, log the error and provide a graceful fallback
            print(f"Error generating response: {e}")
            yield self.create_text_response(
                "I apologize, but I encountered an error processing your request. Please try again."
            )


@asynccontextmanager
async def create_database_engine(config: ProductionConfig) -> AsyncGenerator[AsyncEngine, None]:
    """
    Create and configure a production database engine with connection pooling.

    Args:
        config: Production configuration

    Yields:
        Configured AsyncEngine instance
    """
    engine = create_async_engine(
        config.database_url,
        # Connection pool settings
        pool_size=config.pool_size,
        max_overflow=config.max_overflow,
        pool_timeout=config.pool_timeout,
        pool_recycle=config.pool_recycle,
        # Pre-ping ensures connections are valid before use
        pool_pre_ping=True,
        # Echo SQL queries (disable in production)
        echo=False,
        # PostgreSQL-specific optimizations
        # NOTE(review): connect_args assumes the asyncpg driver — confirm if
        # the DATABASE_URL can point at a different dialect.
        connect_args={
            "server_settings": {
                "application_name": "ragbits_chat",
            },
        },
    )

    try:
        yield engine
    finally:
        # Ensure proper cleanup
        await engine.dispose()


async def create_persistence(
    engine: AsyncEngine,
    config: ProductionConfig,
) -> SQLHistoryPersistence:
    """
    Create and initialize the persistence layer with retry logic.

    Args:
        engine: SQLAlchemy async engine
        config: Production configuration

    Returns:
        Initialized SQLHistoryPersistence instance
    """
    options = SQLHistoryPersistenceOptions(
        conversations_table=config.conversations_table,
        interactions_table=config.interactions_table,
    )

    persistence = SQLHistoryPersistence(
        sqlalchemy_engine=engine,
        options=options,
    )

    # Initialize database with retry logic.
    # Backoff grows linearly: retry_delay * 1, * 2, ... between attempts.
    # NOTE(review): _init_db() is a private method of SQLHistoryPersistence —
    # verify there is no public initialization entry point to call instead.
    for attempt in range(config.max_retries):
        try:
            await persistence._init_db()
            print("✓ Database initialized successfully")
            return persistence
        except OperationalError as e:
            if attempt < config.max_retries - 1:
                print(f"Database connection failed (attempt {attempt + 1}/{config.max_retries}), retrying...")
                await asyncio.sleep(config.retry_delay * (attempt + 1))
            else:
                print(f"✗ Failed to connect to database after {config.max_retries} attempts")
                raise e

    # Unreachable in practice (the loop either returns or raises), kept as a
    # defensive fallback.
    return persistence


def _print_config(config: ProductionConfig) -> None:
    """Print production configuration (credentials stripped from the URL)."""
    print("Configuration:")
    print("-" * 80)
    print(f"Database URL: {config.database_url.split('@')[1] if '@' in config.database_url else 'localhost'}")
    print(f"Pool Size: {config.pool_size}")
    print(f"Max Overflow: {config.max_overflow}")
    print(f"Pool Timeout: {config.pool_timeout}s")
    print(f"Pool Recycle: {config.pool_recycle}s")
    print(f"Conversations Table: {config.conversations_table}")
    print(f"Interactions Table: {config.interactions_table}")
    print()


async def _run_example_conversation(
    chat: ProductionChatInterface,
    persistence: SQLHistoryPersistence,
) -> str | None:
    """Run example conversation and return conversation ID.

    Args:
        chat: Chat interface to drive.
        persistence: Persistence layer.
            NOTE(review): accepted but not used in this body — confirm whether
            it is kept for interface symmetry or can be dropped.

    Returns:
        The conversation ID reported by the chat stream, or None if none was
        emitted.
    """
    print("Running Example Conversation:")
    print("-" * 80)

    context = ChatContext()
    history: ChatFormat = []

    messages = [
        "What are the best practices for database connection pooling?",
        "How does PostgreSQL handle concurrent connections?",
        "Tell me about ACID properties in PostgreSQL",
    ]

    conversation_id = None

    for i, user_message in enumerate(messages, 1):
        print(f"\n[{i}/{len(messages)}] User: {user_message}")
        print("Assistant: ", end="", flush=True)

        response_text = ""
        try:
            async for response in chat.chat(user_message, history=history, context=context):
                if text := response.as_text():
                    print(text, end="", flush=True)
                    response_text += text
                elif conv_id := response.as_conversation_id():
                    conversation_id = conv_id

            # Grow the history so each turn sees the previous exchanges.
            history.append({"role": "user", "content": user_message})
            history.append({"role": "assistant", "content": response_text})
            print()

        except Exception as e:
            print(f"\n✗ Error: {e}")
            break

    print()
    print("=" * 80)
    return conversation_id


async def _verify_persistence(persistence: SQLHistoryPersistence, conversation_id: str | None) -> None:
    """Verify data persistence in PostgreSQL."""
    if conversation_id:
        print("\nVerifying Persistence:")
        print("-" * 80)

        interactions = await persistence.get_conversation_interactions(conversation_id)
        print(f"✓ Successfully retrieved {len(interactions)} interactions from PostgreSQL")

        if interactions:
            # Show a sample of the stored data to prove the round-trip worked.
            first = interactions[0]
            print("\nFirst Interaction:")
            print(f"  Message ID: {first['message_id']}")
            print(f"  Message: {first['message'][:50]}...")
            print(f"  Response: {first['response'][:50]}...")
            print(f"  Timestamp: {first['timestamp']}")

        print()


def _print_features() -> None:
    """Print demonstrated production features."""
    print("=" * 80)
    print("Production Features Demonstrated:")
    print("=" * 80)
    print("✓ Environment-based configuration")
    print("✓ Connection pooling with optimized settings")
    print("✓ Retry logic for database initialization")
    print("✓ Proper resource cleanup (connection disposal)")
    print("✓ Error handling and graceful degradation")
    print("✓ Production-grade database settings")
    print()


async def demonstrate_production_setup() -> None:
    """
    Demonstrate production-ready PostgreSQL setup with best practices.
    """
    print("=" * 80)
    print("Production PostgreSQL Setup Example")
    print("=" * 80)
    print()

    config = ProductionConfig()
    _print_config(config)

    # The context manager guarantees engine.dispose() on exit.
    async with create_database_engine(config) as engine:
        persistence = await create_persistence(engine, config)
        chat = ProductionChatInterface(history_persistence=persistence)

        conversation_id = await _run_example_conversation(chat, persistence)
        await _verify_persistence(persistence, conversation_id)

    _print_features()


async def demonstrate_migration_pattern() -> None:
    """
    Demonstrate schema migration patterns for production.

    Note: This is a simplified example. In production, use a proper migration
    tool like Alembic for schema changes.
+ """ + print("=" * 80) + print("Schema Migration Pattern Example") + print("=" * 80) + print() + + config = ProductionConfig() + + async with create_database_engine(config) as engine: + from sqlalchemy import inspect + + await create_persistence(engine, config) + + print("Schema Information:") + print("-" * 80) + + # Inspect schema + async with engine.begin() as conn: + + def inspect_tables(connection: object) -> list[str]: + inspector = inspect(connection) + tables = inspector.get_table_names() + return tables + + tables = await conn.run_sync(inspect_tables) + print(f"Tables in database: {', '.join(tables)}") + + # Example: Check if tables exist + if config.conversations_table in tables: + print(f"✓ {config.conversations_table} exists") + if config.interactions_table in tables: + print(f"✓ {config.interactions_table} exists") + + print() + + # Note: For real migrations, use Alembic + print("For production schema migrations, use Alembic:") + print(" 1. pip install alembic") + print(" 2. alembic init migrations") + print(" 3. alembic revision --autogenerate -m 'description'") + print(" 4. alembic upgrade head") + print() + + +async def demonstrate_performance_monitoring() -> None: + """ + Demonstrate connection pool monitoring and performance tracking. 
+ """ + print("=" * 80) + print("Performance Monitoring Example") + print("=" * 80) + print() + + config = ProductionConfig() + + async with create_database_engine(config) as engine: + print("Connection Pool Status:") + print("-" * 80) + + # Get pool statistics + pool = engine.pool + print(f"Pool Size: {pool.size()}") + print(f"Checked Out Connections: {pool.checkedout()}") + print(f"Overflow: {pool.overflow()}") + print(f"Checked In: {pool.checkedin()}") + print() + + # Create persistence and run some operations + persistence = await create_persistence(engine, config) + chat = ProductionChatInterface(history_persistence=persistence) + + # Simulate concurrent operations + print("Simulating concurrent operations...") + print("-" * 80) + + async def create_conversation(msg: str) -> str: + """Create a single conversation.""" + context = ChatContext() + async for response in chat.chat(msg, history=[], context=context): + if conv_id := response.as_conversation_id(): + return conv_id + return "" + + # Create multiple conversations concurrently + tasks = [create_conversation(f"Question {i}") for i in range(5)] + + conversation_ids = await asyncio.gather(*tasks) + print(f"✓ Created {len(conversation_ids)} conversations concurrently") + + # Check pool status after operations + print("\nPool Status After Operations:") + print(f" Checked Out: {pool.checkedout()}") + print(f" Overflow: {pool.overflow()}") + print() + + +if __name__ == "__main__": + print("Make sure PostgreSQL is running and DATABASE_URL is set correctly!") + print() + + try: + # Run main demonstration + asyncio.run(demonstrate_production_setup()) + + # Run migration pattern demonstration + asyncio.run(demonstrate_migration_pattern()) + + # Run performance monitoring demonstration + asyncio.run(demonstrate_performance_monitoring()) + + except Exception as e: + print(f"\n✗ Error: {e}") + print("\nTroubleshooting:") + print("1. Ensure PostgreSQL is running") + print("2. 
Check DATABASE_URL environment variable") + print("3. Verify database credentials") + print("4. Ensure asyncpg is installed: pip install asyncpg")