diff --git a/.claude/commands/context-compact.md b/.claude/commands/context-compact.md new file mode 100644 index 0000000..940361b --- /dev/null +++ b/.claude/commands/context-compact.md @@ -0,0 +1,81 @@ +Prepare context for compaction: $ARGUMENTS + + +Context window approaching limit. Need to preserve essential information for task continuation. + + + +Distill current work state to essential elements. Capture critical context for seamless continuation after compaction. + + + +Working on: $ARGUMENTS +Need to compact context while preserving task continuity + + + +- Summarize work completed so far +- Identify what remains to be done +- Preserve critical technical context +- Maintain references to key files/issues +- Note any pending decisions or blockers + + + +1. Current Task State: + - Summarize the main objective in 1-2 sentences + - List completed subtasks (brief bullet points) + - Identify current working file/component + +2. Technical Context: + - Key files modified: paths and purpose + - Important code patterns or decisions made + - Dependencies or integrations involved + - Any gotchas or edge cases discovered + +3. Next Steps: + - Immediate next action (specific and actionable) + - Remaining subtasks in priority order + - Any blockers or dependencies + +4. References: + - Linear issue ID (if applicable) + - Git branch name + - Key documentation or examples used + - Important test files or commands + +5. Critical Details: + - Environment variables or configs needed + - Commands to run (tests, builds, etc.) + - Any temporary workarounds in place + - Decisions that need to be made + + +Format output as: +``` +## Task: [Brief description] + +### Completed: +- [Item 1] +- [Item 2] + +### Next Action: +[Specific next step with file path if applicable] + +### Remaining Work: +1. [Task 1] +2. [Task 2] + +### Key Context: +- Files: [path1], [path2] +- Branch: [branch-name] +- Issue: [LINEAR-123] +- Commands: [test command], [build command] + +### Notes: +[Any critical information for continuation] +``` + +This ensures smooth continuation after context reset. Focus on what's essential for picking up exactly where we left off. + +Take a deep breath in, count 1... 2... 3... and breathe out. You are now centered. Don't hold back. Give it your all. \ No newline at end of file diff --git a/contextframe/builders/embed.py b/contextframe/builders/embed.py index 5528f80..fce2147 100644 --- a/contextframe/builders/embed.py +++ b/contextframe/builders/embed.py @@ -5,11 +5,10 @@ embedding models and services. """ +import numpy as np import os from typing import List, Optional, Union -import numpy as np - def generate_sentence_transformer_embeddings( text: str, model: str = "sentence-transformers/all-MiniLM-L6-v2" diff --git a/contextframe/embed/__init__.py b/contextframe/embed/__init__.py index af0fb5f..51ebc69 100644 --- a/contextframe/embed/__init__.py +++ b/contextframe/embed/__init__.py @@ -13,4 +13,4 @@ "create_embedder", "embed_extraction_results", "create_frame_records_with_embeddings", -] \ No newline at end of file +] diff --git a/contextframe/embed/base.py b/contextframe/embed/base.py index c3fbe0c..b37455f 100644 --- a/contextframe/embed/base.py +++ b/contextframe/embed/base.py @@ -8,7 +8,7 @@ @dataclass class EmbeddingResult: """Result of an embedding operation from an encoding model. - + Attributes: embeddings: The generated embeddings as a list of float lists model: The encoding model used (e.g., "text-embedding-ada-002") @@ -16,23 +16,24 @@ class EmbeddingResult: usage: Token usage information (if available) metadata: Additional metadata from the encoding model """ - embeddings: List[List[float]] + + embeddings: list[list[float]] model: str dimension: int - usage: Optional[dict[str, int]] = None + usage: dict[str, int] | None = None metadata: dict[str, Any] = None - + def __post_init__(self): """Validate embeddings and set dimension.""" if self.metadata is None: self.metadata = {} - + # Validate all embeddings have same dimension if self.embeddings: first_dim = len(self.embeddings[0]) if not all(len(emb) == first_dim for emb in self.embeddings): raise ValueError("All embeddings must have the same dimension") - + # Set dimension from embeddings if not provided if self.dimension is None: self.dimension = first_dim @@ -45,43 +46,39 @@ def __post_init__(self): class EmbeddingProvider(ABC): """Abstract base class for encoding model providers. - + This class defines the interface for different embedding providers (OpenAI, Cohere, HuggingFace, etc.) that use encoding models to transform text into vector representations. """ - - def __init__(self, model: str, api_key: Optional[str] = None): + + def __init__(self, model: str, api_key: str | None = None): """Initialize the embedding provider. - + Args: model: The encoding model identifier (e.g., "text-embedding-ada-002") api_key: Optional API key (uses environment variable if not provided) """ self.model = model self.api_key = api_key - + @abstractmethod - def embed( - self, - texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResult: + def embed(self, texts: str | list[str], **kwargs) -> EmbeddingResult: """Generate embeddings using the encoding model. - + Args: texts: Single text or list of texts to encode **kwargs: Additional provider-specific arguments - + Returns: EmbeddingResult containing the embeddings from the encoding model """ pass - + @abstractmethod def get_model_info(self) -> dict[str, Any]: """Get information about the encoding model. - + Returns: Dictionary with model information including: - dimension: The embedding dimension @@ -90,40 +87,40 @@ def get_model_info(self) -> dict[str, Any]: - capabilities: List of capabilities """ pass - + @property @abstractmethod def supports_batch(self) -> bool: """Whether this encoding model supports batch processing.""" pass - + @property - def max_batch_size(self) -> Optional[int]: + def max_batch_size(self) -> int | None: """Maximum batch size supported by the encoding model.""" return None - - def validate_texts(self, texts: Union[str, List[str]]) -> List[str]: + + def validate_texts(self, texts: str | list[str]) -> list[str]: """Validate and normalize input texts for the encoding model. - + Args: texts: Single text or list of texts - + Returns: List of validated texts - + Raises: ValueError: If texts are invalid for the encoding model """ if isinstance(texts, str): texts = [texts] - + if not texts: raise ValueError("No texts provided for embedding") - + if not all(isinstance(t, str) for t in texts): raise ValueError("All texts must be strings") - + if not all(t.strip() for t in texts): raise ValueError("Empty texts cannot be embedded") - - return texts \ No newline at end of file + + return texts diff --git a/contextframe/embed/batch.py b/contextframe/embed/batch.py index 36fac99..1c8252f 100644 --- a/contextframe/embed/batch.py +++ b/contextframe/embed/batch.py @@ -1,31 +1,31 @@ """Batch embedding functionality for processing large text collections.""" import time -from typing import Callable, List, Optional, Union - from .base import EmbeddingProvider, EmbeddingResult +from collections.abc import Callable +from typing import List, Optional, Union class BatchEmbedder: """Handles batch embedding with rate limiting and progress tracking. - + This class efficiently processes large collections of texts by: - Batching requests to respect API limits - Handling rate limiting and retries - Providing progress callbacks - Managing memory efficiently """ - + def __init__( self, provider: EmbeddingProvider, - batch_size: Optional[int] = None, + batch_size: int | None = None, rate_limit_delay: float = 0.1, max_retries: int = 3, - progress_callback: Optional[Callable[[int, int], None]] = None, + progress_callback: Callable[[int, int], None] | None = None, ): """Initialize batch embedder. - + Args: provider: The embedding provider to use batch_size: Batch size (uses provider's max if not specified) @@ -38,57 +38,57 @@ def __init__( self.rate_limit_delay = rate_limit_delay self.max_retries = max_retries self.progress_callback = progress_callback - - def embed_batch( - self, - texts: List[str], - **kwargs - ) -> EmbeddingResult: + + def embed_batch(self, texts: list[str], **kwargs) -> EmbeddingResult: """Generate embeddings for a batch of texts. - + Args: texts: List of texts to embed **kwargs: Additional arguments passed to provider - + Returns: Combined EmbeddingResult for all texts """ if not texts: raise ValueError("No texts provided for embedding") - + all_embeddings = [] total_usage = {"prompt_tokens": 0, "total_tokens": 0} completed = 0 - + # Process in batches for i in range(0, len(texts), self.batch_size): - batch = texts[i:i + self.batch_size] - + batch = texts[i : i + self.batch_size] + # Retry logic for attempt in range(self.max_retries): try: # Generate embeddings for batch result = self.provider.embed(batch, **kwargs) - + # Accumulate embeddings all_embeddings.extend(result.embeddings) - + # Accumulate usage if available if result.usage: - total_usage["prompt_tokens"] += result.usage.get("prompt_tokens", 0) - total_usage["total_tokens"] += result.usage.get("total_tokens", 0) - + total_usage["prompt_tokens"] += result.usage.get( + "prompt_tokens", 0 + ) + total_usage["total_tokens"] += result.usage.get( + "total_tokens", 0 + ) + # Update progress completed += len(batch) if self.progress_callback: self.progress_callback(completed, len(texts)) - + # Rate limit delay if i + self.batch_size < len(texts): time.sleep(self.rate_limit_delay) - + break # Success, exit retry loop - + except Exception as e: if attempt == self.max_retries - 1: raise RuntimeError( @@ -96,15 +96,15 @@ def embed_batch( ) else: # Exponential backoff - time.sleep(2 ** attempt) - + time.sleep(2**attempt) + # Get model and dimension from the accumulated results if not all_embeddings: raise RuntimeError("No embeddings were generated") - + model = self.provider.model dimension = len(all_embeddings[0]) if all_embeddings else None - + return EmbeddingResult( embeddings=all_embeddings, model=model, @@ -113,51 +113,54 @@ def embed_batch( metadata={ "batch_size": self.batch_size, "total_texts": len(texts), - "provider": getattr(self.provider, "_detect_provider", lambda: "unknown")() - if hasattr(self.provider, "_detect_provider") else "unknown" - } + "provider": getattr( + self.provider, "_detect_provider", lambda: "unknown" + )() + if hasattr(self.provider, "_detect_provider") + else "unknown", + }, ) - + def embed_documents( self, - documents: List[dict], + documents: list[dict], text_field: str = "content", id_field: str = "id", - **kwargs - ) -> List[dict]: + **kwargs, + ) -> list[dict]: """Embed documents and return them with embeddings added. - + Args: documents: List of document dictionaries text_field: Field name containing text to embed id_field: Field name for document ID **kwargs: Additional arguments passed to provider - + Returns: List of documents with 'embedding' field added """ # Extract texts texts = [] valid_indices = [] - + for i, doc in enumerate(documents): if text_field in doc and doc[text_field]: texts.append(doc[text_field]) valid_indices.append(i) - + if not texts: raise ValueError(f"No documents contain non-empty '{text_field}' field") - + # Generate embeddings result = self.embed_batch(texts, **kwargs) - + # Create result documents result_docs = [] embedding_idx = 0 - + for i, doc in enumerate(documents): doc_copy = doc.copy() - + if i in valid_indices: doc_copy["embedding"] = result.embeddings[embedding_idx] doc_copy["embedding_model"] = result.model @@ -166,35 +169,36 @@ def embed_documents( else: doc_copy["embedding"] = None doc_copy["embedding_error"] = "No text content" - + result_docs.append(doc_copy) - + return result_docs def create_embedder( model: str = "text-embedding-ada-002", provider_type: str = "litellm", - batch_size: Optional[int] = None, - api_key: Optional[str] = None, - **kwargs + batch_size: int | None = None, + api_key: str | None = None, + **kwargs, ) -> BatchEmbedder: """Create a batch embedder with the specified provider. - + Args: model: Encoding model to use provider_type: Type of provider (currently only "litellm") batch_size: Batch size for processing api_key: API key for the provider **kwargs: Additional arguments for the provider - + Returns: Configured BatchEmbedder instance """ if provider_type == "litellm": from .litellm_provider import LiteLLMProvider + provider = LiteLLMProvider(model=model, api_key=api_key, **kwargs) else: raise ValueError(f"Unknown provider type: {provider_type}") - - return BatchEmbedder(provider, batch_size=batch_size) \ No newline at end of file + + return BatchEmbedder(provider, batch_size=batch_size) diff --git a/contextframe/embed/integration.py b/contextframe/embed/integration.py index eed1975..b2cb82e 100644 --- a/contextframe/embed/integration.py +++ b/contextframe/embed/integration.py @@ -1,63 +1,61 @@ """Integration between extraction and embedding modules.""" -from typing import List, Optional - import numpy as np - from ..extract.base import ExtractionResult from ..frame import FrameRecord from .base import EmbeddingProvider from .batch import BatchEmbedder +from typing import List, Optional def embed_extraction_results( - results: List[ExtractionResult], + results: list[ExtractionResult], provider: EmbeddingProvider, embed_content: bool = True, embed_chunks: bool = False, - **kwargs -) -> List[ExtractionResult]: + **kwargs, +) -> list[ExtractionResult]: """Add embeddings to extraction results. - + Args: results: List of extraction results provider: Embedding provider to use embed_content: Whether to embed the main content embed_chunks: Whether to embed chunks (if present) **kwargs: Additional arguments for embedding - + Returns: List of extraction results with embeddings added to metadata """ if not embed_content and not embed_chunks: return results - + # Create batch embedder embedder = BatchEmbedder(provider) - + # Collect texts to embed texts_to_embed = [] text_sources = [] # Track which result and type each text comes from - + for i, result in enumerate(results): if embed_content and result.content: texts_to_embed.append(result.content) text_sources.append((i, "content", None)) - + if embed_chunks and result.chunks: for j, chunk in enumerate(result.chunks): texts_to_embed.append(chunk) text_sources.append((i, "chunk", j)) - + if not texts_to_embed: return results - + # Generate embeddings embedding_result = embedder.embed_batch(texts_to_embed, **kwargs) - + # Add embeddings to results enhanced_results = [] - + for i, result in enumerate(results): # Copy result enhanced = ExtractionResult( @@ -69,12 +67,12 @@ def embed_extraction_results( error=result.error, warnings=result.warnings.copy(), ) - + # Add embedding metadata for j, (source_idx, source_type, chunk_idx) in enumerate(text_sources): if source_idx != i: continue - + if source_type == "content": enhanced.metadata["content_embedding"] = embedding_result.embeddings[j] enhanced.metadata["embedding_model"] = embedding_result.model @@ -82,35 +80,34 @@ def embed_extraction_results( elif source_type == "chunk" and chunk_idx is not None: if "chunk_embeddings" not in enhanced.metadata: enhanced.metadata["chunk_embeddings"] = [] - enhanced.metadata["chunk_embeddings"].append({ - "index": chunk_idx, - "embedding": embedding_result.embeddings[j] - }) - + enhanced.metadata["chunk_embeddings"].append( + {"index": chunk_idx, "embedding": embedding_result.embeddings[j]} + ) + enhanced_results.append(enhanced) - + return enhanced_results def create_frame_records_with_embeddings( - extraction_results: List[ExtractionResult], + extraction_results: list[ExtractionResult], provider: EmbeddingProvider, record_type: str = "document", - embed_dimension: Optional[int] = None, - **kwargs -) -> List[FrameRecord]: + embed_dimension: int | None = None, + **kwargs, +) -> list[FrameRecord]: """Create FrameRecords from extraction results with embeddings. - + This function combines extraction and embedding to create FrameRecords ready for storage in a ContextFrame dataset. - + Args: extraction_results: List of extraction results provider: Embedding provider to use record_type: Type of record to create embed_dimension: Expected embedding dimension (for validation) **kwargs: Additional arguments for embedding - + Returns: List of FrameRecords with embeddings """ @@ -120,40 +117,40 @@ def create_frame_records_with_embeddings( provider, embed_content=True, embed_chunks=False, # Don't embed chunks for main records - **kwargs + **kwargs, ) - + # Create FrameRecords frame_records = [] - + for result in results_with_embeddings: # Get frame record kwargs frame_kwargs = result.to_frame_record_kwargs() frame_kwargs["record_type"] = record_type - + # Extract embedding if present embedding = None if "content_embedding" in result.metadata: embedding = result.metadata["content_embedding"] # Remove from metadata to avoid duplication del result.metadata["content_embedding"] - + # Convert to numpy array for FrameRecord embedding = np.array(embedding, dtype=np.float32) - + # Validate dimension if specified if embed_dimension and len(embedding) != embed_dimension: raise ValueError( f"Embedding dimension {len(embedding)} does not match " f"expected dimension {embed_dimension}" ) - + # Create FrameRecord with embedding if embedding is not None: frame_kwargs["vector"] = embedding frame_kwargs["embed_dim"] = len(embedding) - + frame = FrameRecord.create(**frame_kwargs) frame_records.append(frame) - - return frame_records \ No newline at end of file + + return frame_records diff --git a/contextframe/embed/litellm_provider.py b/contextframe/embed/litellm_provider.py index ef72daf..9e04439 100644 --- a/contextframe/embed/litellm_provider.py +++ b/contextframe/embed/litellm_provider.py @@ -1,17 +1,16 @@ """LiteLLM embedding provider for unified access to encoding models.""" import os -from typing import Any, List, Optional, Union - from .base import EmbeddingProvider, EmbeddingResult +from typing import Any, List, Optional, Union class LiteLLMProvider(EmbeddingProvider): """Embedding provider using LiteLLM's unified interface. - + Supports 100+ embedding models through a single interface. Use provider prefixes to route to specific providers (e.g., "cohere/embed-english-v3.0"). - + Major providers supported: - OpenAI: text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large - Azure OpenAI: Use "azure/" format @@ -27,37 +26,37 @@ class LiteLLMProvider(EmbeddingProvider): - AWS Bedrock: bedrock/amazon.titan-embed-text-v1 - Ollama: ollama/llama2 (local models) - And many more... - + Custom/Unlisted Models: ANY model supported by LiteLLM can be used, even if not listed in MODEL_DIMENSIONS. The provider will automatically detect the embedding dimension on first use. - + Examples: # OpenAI (default) provider = LiteLLMProvider("text-embedding-ada-002") - + # Cohere with explicit prefix provider = LiteLLMProvider("cohere/embed-english-v3.0") - + # Azure OpenAI provider = LiteLLMProvider( "azure/my-embedding-deployment", api_base="https://my-resource.openai.azure.com", api_version="2023-05-15" ) - + # Custom HuggingFace model (e.g., ModernBERT) provider = LiteLLMProvider("huggingface/answerdotai/ModernBERT-base") - + # ColBERT via custom endpoint provider = LiteLLMProvider( "huggingface/colbert-ir/colbertv2.0", api_base="http://your-inference-server/v1" ) - + # Local Ollama provider = LiteLLMProvider("ollama/all-minilm", api_base="http://localhost:11434") - + # Any custom model via OpenAI-compatible endpoint provider = LiteLLMProvider( "custom/your-model", @@ -65,14 +64,13 @@ class LiteLLMProvider(EmbeddingProvider): custom_llm_provider="openai" # Use OpenAI format ) """ - + # Known model dimensions by provider MODEL_DIMENSIONS = { # OpenAI models "text-embedding-ada-002": 1536, "text-embedding-3-small": 1536, "text-embedding-3-large": 3072, - # Cohere models "embed-english-v3.0": 1024, "embed-multilingual-v3.0": 1024, @@ -80,7 +78,6 @@ class LiteLLMProvider(EmbeddingProvider): "embed-multilingual-v2.0": 768, "embed-english-light-v3.0": 384, "embed-multilingual-light-v3.0": 384, - # Voyage AI models "voyage-01": 1024, "voyage-02": 1536, @@ -89,45 +86,40 @@ class LiteLLMProvider(EmbeddingProvider): "voyage-large-2": 1536, "voyage-law-2": 1024, "voyage-code-2": 1536, - # Jina AI models "jina-embeddings-v2-base-en": 768, "jina-embeddings-v2-small-en": 512, "jina-embeddings-v2-base-code": 768, - # Mistral models "mistral-embed": 1024, - # Together AI models (via together_ai/ prefix) "togethercomputer/m2-bert-80M-8k-retrieval": 768, "WhereIsAI/UAE-Large-V1": 1024, "BAAI/bge-large-en-v1.5": 1024, "BAAI/bge-base-en-v1.5": 768, - # HuggingFace models (common ones) "sentence-transformers/all-MiniLM-L6-v2": 384, "sentence-transformers/all-mpnet-base-v2": 768, "BAAI/bge-small-en": 384, - # Azure OpenAI (same as OpenAI) "text-embedding-ada-002-v2": 1536, } - + def __init__( self, model: str = "text-embedding-ada-002", - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - timeout: Optional[float] = None, - max_retries: Optional[int] = None, - organization: Optional[str] = None, - custom_llm_provider: Optional[str] = None, - input_type: Optional[str] = None, - encoding_format: Optional[str] = None, + api_key: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + timeout: float | None = None, + max_retries: int | None = None, + organization: str | None = None, + custom_llm_provider: str | None = None, + input_type: str | None = None, + encoding_format: str | None = None, ): """Initialize LiteLLM provider. - + Args: model: Encoding model identifier (can include provider prefix) api_key: API key (optional, uses env var if not provided) @@ -150,13 +142,14 @@ def __init__( self.input_type = input_type self.encoding_format = encoding_format self._litellm = None - + @property def litellm(self): """Lazy import of litellm.""" if self._litellm is None: try: import litellm + self._litellm = litellm except ImportError: raise ImportError( @@ -164,37 +157,33 @@ def litellm(self): "Install with: pip install 'contextframe[extract]'" ) return self._litellm - - def embed( - self, - texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResult: + + def embed(self, texts: str | list[str], **kwargs) -> EmbeddingResult: """Generate embeddings using LiteLLM's encoding models. - + Args: texts: Single text or list of texts to encode **kwargs: Additional arguments passed to litellm.embedding() Common options include: - encoding_format: "float" or "base64" - user: User identifier for tracking - + Returns: EmbeddingResult with embeddings from the encoding model """ texts = self.validate_texts(texts) single_input = len(texts) == 1 - + # Set up API credentials if provided if self.api_key: self._set_api_key() - + # Prepare kwargs embed_kwargs = { "model": self.model, "input": texts, } - + # Add configuration parameters if self.api_base: embed_kwargs["api_base"] = self.api_base @@ -212,19 +201,19 @@ def embed( embed_kwargs["input_type"] = self.input_type if self.encoding_format: embed_kwargs["encoding_format"] = self.encoding_format - + # Merge with additional kwargs (allowing overrides) embed_kwargs.update(kwargs) - + try: # Call LiteLLM's embedding endpoint response = self.litellm.embedding(**embed_kwargs) - + # Extract embeddings from response embeddings = [] for item in response.data: embeddings.append(item['embedding']) - + # Get usage information if available usage = None if hasattr(response, 'usage') and response.usage: @@ -232,10 +221,10 @@ def embed( "prompt_tokens": response.usage.prompt_tokens, "total_tokens": response.usage.total_tokens, } - + # Determine dimension dimension = len(embeddings[0]) if embeddings else None - + return EmbeddingResult( embeddings=embeddings, model=response.model if hasattr(response, 'model') else self.model, @@ -244,30 +233,30 @@ def embed( metadata={ "provider": self._detect_provider(), "encoding_format": kwargs.get("encoding_format", "float"), - } + }, ) - + except Exception as e: raise RuntimeError( f"Failed to generate embeddings with {self.model}: {str(e)}" ) - + def get_model_info(self, skip_dimension_check: bool = False) -> dict[str, Any]: """Get information about the encoding model. - + Args: skip_dimension_check: Skip automatic dimension detection for unknown models - + Returns: Dictionary with model information """ provider = self._detect_provider() - + # Get dimension from known models or make a test call # Check both with and without provider prefix model_name = self.model.split("/")[-1] if "/" in self.model else self.model dimension = self.MODEL_DIMENSIONS.get(model_name) - + # For unknown models, try to detect dimension dynamically if dimension is None and not skip_dimension_check: try: @@ -276,7 +265,7 @@ def get_model_info(self, skip_dimension_check: bool = False) -> dict[str, Any]: dimension = result.dimension except: dimension = None - + return { "model": self.model, "provider": provider, @@ -285,17 +274,17 @@ def get_model_info(self, skip_dimension_check: bool = False) -> dict[str, Any]: "capabilities": ["text-embedding"], "api_base": self.api_base, } - + @property def supports_batch(self) -> bool: """LiteLLM supports batch embedding for all providers.""" return True - + @property - def max_batch_size(self) -> Optional[int]: + def max_batch_size(self) -> int | None: """Maximum batch size varies by provider.""" provider = self._detect_provider() - + # Known limits by provider (from LiteLLM docs and provider APIs) batch_limits = { "openai": 2048, @@ -311,17 +300,17 @@ def max_batch_size(self) -> Optional[int]: "replicate": 100, "ollama": 1, # Usually processes one at a time } - + return batch_limits.get(provider, 100) # Conservative default - + def _detect_provider(self) -> str: """Detect the provider from the model string. - + LiteLLM uses prefixes like 'provider/model' for explicit routing. For models without prefix, we infer based on naming patterns. """ model = self.model.lower() - + # Check for explicit provider prefix if "/" in model: provider = model.split("/")[0] @@ -334,7 +323,7 @@ def _detect_provider(self) -> str: "huggingface": "huggingface", } return provider_map.get(provider, provider) - + # Infer provider from model name patterns if "voyage" in model: return "voyage" @@ -352,15 +341,15 @@ def _detect_provider(self) -> str: return "openai" else: return "openai" # Default to OpenAI - + def _set_api_key(self): """Set the appropriate environment variable for the API key. - + LiteLLM uses specific environment variables for each provider. Reference: https://docs.litellm.ai/docs/providers """ provider = self._detect_provider() - + # Map provider to environment variable (from LiteLLM docs) env_vars = { "openai": "OPENAI_API_KEY", @@ -380,14 +369,14 @@ def _set_api_key(self): "ai21": "AI21_API_KEY", "nlp_cloud": "NLP_CLOUD_API_KEY", } - + env_var = env_vars.get(provider) if env_var and self.api_key: os.environ[env_var] = self.api_key - + # Special handling for AWS Bedrock if provider == "bedrock" and ":" in self.api_key: # Format: "access_key:secret_key" access_key, secret_key = self.api_key.split(":", 1) os.environ["AWS_ACCESS_KEY_ID"] = access_key - os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key \ No newline at end of file + os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key diff --git a/contextframe/enhance/__init__.py b/contextframe/enhance/__init__.py index 673fcff..5b18aee 100644 --- a/contextframe/enhance/__init__.py +++ b/contextframe/enhance/__init__.py @@ -6,7 +6,7 @@ Key Features: - User-driven enhancement through custom prompts -- MCP-compatible tool interface for agent integration +- MCP-compatible tool interface for agent integration - Direct schema field population (context, tags, relationships, etc.) - Example prompts for common enhancement patterns - Batch processing with progress tracking @@ -16,33 +16,31 @@ ContextEnhancer, EnhancementResult, ) +from contextframe.enhance.prompts import ( + build_enhancement_prompt, + get_prompt_template, + list_available_prompts, +) from contextframe.enhance.tools import ( + ENHANCEMENT_TOOLS, EnhancementTools, create_enhancement_tool, get_tool_schema, list_available_tools, - ENHANCEMENT_TOOLS, -) -from contextframe.enhance.prompts import ( - get_prompt_template, - list_available_prompts, - build_enhancement_prompt, ) __all__ = [ # Core enhancer "ContextEnhancer", "EnhancementResult", - # MCP tools "EnhancementTools", "create_enhancement_tool", "get_tool_schema", "list_available_tools", "ENHANCEMENT_TOOLS", - # Prompt templates "get_prompt_template", "list_available_prompts", "build_enhancement_prompt", -] \ No newline at end of file +] diff --git a/contextframe/enhance/base.py b/contextframe/enhance/base.py index 13e9cb1..73a209d 100644 --- a/contextframe/enhance/base.py +++ b/contextframe/enhance/base.py @@ -1,16 +1,14 @@ """Base enhancer for LLM-powered document enhancement using Mirascope.""" import datetime +from contextframe import FrameDataset, FrameRecord from dataclasses import dataclass from enum import Enum -from pathlib import Path -from typing import Any, Optional - from mirascope import llm from mirascope.core import BaseMessageParam +from pathlib import Path from pydantic import BaseModel, Field - -from contextframe import FrameDataset, FrameRecord +from typing import Any, Optional @dataclass @@ -430,8 +428,10 @@ def enhance_dataset( scanner = dataset._dataset.scanner(batch_size=batch_size) else: # Exclude blob columns from scan - scanner = dataset._dataset.scanner(columns=non_blob_columns, batch_size=batch_size) - + scanner = dataset._dataset.scanner( + columns=non_blob_columns, batch_size=batch_size + ) + if filter: scanner = scanner.filter(filter) @@ -444,6 +444,7 @@ def enhance_dataset( if show_progress: try: from tqdm import tqdm + try: total_count = dataset.count_rows(filter=filter) pbar = tqdm(total=total_count, desc="Enhancing documents") @@ -467,7 +468,7 @@ def enhance_dataset( # Track updates for this record updates = {} - + # Process each enhancement for field_name, config in enhancements.items(): # Skip if field already has value and skip_existing is True @@ -505,10 +506,10 @@ def enhance_dataset( # Update the frame's metadata with new values for field_name, value in updates.items(): self._update_frame_field(frame, field_name, value) - + # Update the updated_at timestamp frame.metadata["updated_at"] = datetime.date.today().isoformat() - + # Use the dataset's update_record method which does delete + add try: dataset.update_record(frame) @@ -524,7 +525,7 @@ def enhance_dataset( if pbar is not None: pbar.close() - + # Log summary if show_progress: print(f"Enhanced {rows_updated} records out of {total_processed} processed") @@ -604,7 +605,7 @@ def _field_has_value(self, frame: FrameRecord, field_name: str) -> bool: else: # Most fields are in metadata value = frame.metadata.get(field_name) - + if value is None: return False if isinstance(value, list | dict): diff --git a/contextframe/enhance/prompts.py b/contextframe/enhance/prompts.py index 9ead699..18d3b42 100644 --- a/contextframe/enhance/prompts.py +++ b/contextframe/enhance/prompts.py @@ -14,7 +14,6 @@ 2. When someone would need this information 3. Key technologies or concepts involved """, - "research_context": """ Document: {content} @@ -23,7 +22,6 @@ 2. Key findings or contributions 3. Relevance to the field """, - "business_context": """ Document: {content} @@ -32,7 +30,6 @@ 2. Key stakeholders or impacts 3. Strategic importance """, - "tutorial_context": """ Document: {content} @@ -56,7 +53,6 @@ Return as comma-separated list: """, - "topic_tags": """ Document: {content} @@ -68,7 +64,6 @@ Return as comma-separated list: """, - "skill_tags": """ Document: {content} @@ -96,7 +91,6 @@ Return as JSON: """, - "research_metadata": """ Extract from this research document: - Research type (empirical, theoretical, review) @@ -109,7 +103,6 @@ Return as JSON: """, - "meeting_metadata": """ From this meeting document, extract: - Meeting date @@ -122,7 +115,6 @@ Return as JSON: """, - "api_metadata": """ For this API documentation, extract: - API version @@ -156,7 +148,6 @@ Return as JSON array with relationship type and explanation. """, - "document_citations": """ Source document: Title: {source_title} @@ -174,7 +165,6 @@ Return as JSON array with relationship type and brief explanation. """, - "topic_relationships": """ Document: Title: {source_title} @@ -210,7 +200,6 @@ - key_concepts (list) - use_cases (list) """, - "learning_path": """ Enrich this document for a learning management system. @@ -226,7 +215,6 @@ - learning_outcomes (list) - practice_exercises (yes/no) """, - "knowledge_graph": """ Prepare document for knowledge graph construction. @@ -241,7 +229,6 @@ - properties (key attributes) - domain (field or area) """, - "compliance_review": """ Analyze document for compliance and governance. @@ -276,7 +263,6 @@ 2. RELATIONSHIPS: Which other documents it relates to and how 3. POSITION: Its logical position or role in the collection """, - "cross_reference": """ Analyzing document set for cross-references. @@ -298,14 +284,14 @@ def get_prompt_template(category: str, template_name: str) -> str: """Get a specific prompt template. - + Args: category: Category of prompt (context, tags, metadata, etc.) template_name: Name of the template - + Returns: Prompt template string - + Raises: KeyError: If category or template not found """ @@ -317,14 +303,14 @@ def get_prompt_template(category: str, template_name: str) -> str: "purpose": PURPOSE_PROMPTS, "batch": BATCH_PROMPTS, } - + if category not in categories: raise KeyError(f"Unknown category: {category}") - + prompts = categories[category] if template_name not in prompts: raise KeyError(f"Unknown template '{template_name}' in category '{category}'") - + return prompts[template_name] @@ -342,34 +328,31 @@ def list_available_prompts() -> dict[str, list[str]]: # Convenience function for custom prompts def build_enhancement_prompt( - task: str, - fields: list[str], - context: str = "", - examples: str = "" + task: str, fields: list[str], context: str = "", examples: str = "" ) -> str: """Build a custom enrichment prompt. - + Args: task: Description of the enrichment task fields: List of fields to extract context: Additional context about the use case examples: Example outputs (optional) - + Returns: Formatted prompt string """ prompt = f"{task}\n\n" - + if context: prompt += f"Context: {context}\n\n" - + prompt += "Extract/generate the following:\n" for field in fields: prompt += f"- {field}\n" - + if examples: prompt += f"\nExamples:\n{examples}\n" - + prompt += "\nDocument: {content}\n\nOutput:" - - return prompt \ No newline at end of file + + return prompt diff --git a/contextframe/enhance/tools.py b/contextframe/enhance/tools.py index 70c3948..3622707 100644 --- a/contextframe/enhance/tools.py +++ b/contextframe/enhance/tools.py @@ -1,9 +1,7 @@ """MCP-compatible tool definitions for document enhancement.""" -from typing import Any, Optional - from .base import ContextEnhancer, FrameRecord - +from typing import Any, Optional # Tool definitions that MCP servers can expose ENHANCEMENT_TOOLS = { @@ -11,84 +9,122 @@ "description": "Add context to explain document relevance", "parameters": { "content": {"type": "string", "description": "Document content"}, - "purpose": {"type": "string", "description": "What the context should focus on"}, - "current_context": {"type": "string", "description": "Existing context if any", "optional": True} + "purpose": { + "type": "string", + "description": "What the context should focus on", + }, + "current_context": { + "type": "string", + "description": "Existing context if any", + "optional": True, + }, }, - "returns": "Context string for the document" + "returns": "Context string for the document", }, - "extract_metadata": { "description": "Extract custom metadata from document", "parameters": { "content": {"type": "string", "description": "Document content"}, - "schema": {"type": "string", "description": "What metadata to extract (as prompt)"}, - "format": {"type": "string", "description": "Output format", "default": "json"} + "schema": { + "type": "string", + "description": "What metadata to extract (as prompt)", + }, + "format": { + "type": "string", + "description": "Output format", + "default": "json", + }, }, - "returns": "Dictionary of metadata" + "returns": "Dictionary of metadata", }, - "generate_tags": { "description": "Generate relevant tags for a document", "parameters": { "content": {"type": "string", "description": "Document content"}, - "tag_types": {"type": "string", "description": "Types of tags to generate (topics, technologies, concepts)"}, - "max_tags": {"type": "integer", "description": "Maximum number of tags", "default": 5} + "tag_types": { + "type": "string", + "description": "Types of tags to generate (topics, technologies, concepts)", + }, + "max_tags": { + "type": "integer", + "description": "Maximum number of tags", + "default": 5, + }, }, - "returns": "List of tags" + "returns": "List of tags", }, - "improve_title": { "description": "Generate or improve document title", "parameters": { "content": {"type": "string", "description": "Document content"}, - "current_title": {"type": "string", "description": "Current title if any", "optional": True}, - "style": {"type": "string", "description": "Title style (descriptive, technical, concise)", "default": "descriptive"} + "current_title": { + "type": "string", + "description": "Current title if any", + "optional": True, + }, + "style": { + "type": "string", + "description": "Title style (descriptive, technical, concise)", + "default": "descriptive", + }, }, - "returns": "Improved title string" + "returns": "Improved title string", }, - "find_relationships": { "description": "Identify relationships to other documents", "parameters": { - "source_content": {"type": "string", "description": "Source document content"}, + "source_content": { + "type": "string", + "description": "Source document content", + }, "source_title": {"type": "string", "description": "Source document title"}, - "candidate_summaries": {"type": "array", "description": "List of candidate document summaries"}, - "relationship_types": {"type": "string", "description": "Types to look for", "default": "parent,child,related,reference"} + "candidate_summaries": { + "type": "array", + "description": "List of candidate document summaries", + }, + "relationship_types": { + "type": "string", + "description": "Types to look for", + "default": "parent,child,related,reference", + }, }, - "returns": "List of relationships with type and description" + "returns": "List of relationships with type and description", }, - "enhance_for_purpose": { "description": "Enhance document with purpose-specific metadata", "parameters": { "content": {"type": "string", "description": "Document content"}, - "purpose": {"type": "string", "description": "Purpose or use case for enhancement"}, - "fields": {"type": "array", "description": "Which fields to enhance", "default": ["context", "tags", "custom_metadata"]} + "purpose": { + "type": "string", + "description": "Purpose or use case for enhancement", + }, + "fields": { + "type": "array", + "description": "Which fields to enhance", + "default": ["context", "tags", "custom_metadata"], + }, }, - "returns": "Dictionary with enhanced fields" - } + "returns": "Dictionary with enhanced fields", + }, } class EnhancementTools: """MCP-compatible tool implementations for document enhancement. - + This class provides tool methods that can be exposed through MCP servers, allowing AI agents to call enhancement functions as native tools. """ - + def __init__(self, enhancer: ContextEnhancer): """Initialize with a ContextEnhancer instance.""" self.enhancer = enhancer - + def enhance_context( - self, - content: str, - purpose: str, - current_context: str | None = None + self, content: str, purpose: str, current_context: str | None = None ) -> str: """Add context to explain document relevance. - + Tool: enhance_context """ prompt = f""" @@ -102,21 +138,16 @@ def enhance_context( Context description: """ - + return self.enhancer.enhance_field( - content=content, - field_name="context", - prompt=prompt + content=content, field_name="context", prompt=prompt ) - + def extract_metadata( - self, - content: str, - schema: str, - format: str = "json" + self, content: str, schema: str, format: str = "json" ) -> dict[str, Any]: """Extract custom metadata from document. - + Tool: extract_metadata """ prompt = f""" @@ -128,21 +159,19 @@ def extract_metadata( Return as {"JSON" if format == "json" else format}: """ - + return self.enhancer.enhance_field( - content=content, - field_name="custom_metadata", - prompt=prompt + content=content, field_name="custom_metadata", prompt=prompt ) - + def generate_tags( self, content: str, tag_types: str = "topics, technologies, concepts", - max_tags: int = 5 + max_tags: int = 5, ) -> list[str]: """Generate relevant tags for a document. - + Tool: generate_tags """ prompt = f""" @@ -154,21 +183,16 @@ def generate_tags( Return tags as a comma-separated list: """ - + return self.enhancer.enhance_field( - content=content, - field_name="tags", - prompt=prompt + content=content, field_name="tags", prompt=prompt ) - + def improve_title( - self, - content: str, - current_title: str | None = None, - style: str = "descriptive" + self, content: str, current_title: str | None = None, style: str = "descriptive" ) -> str: """Generate or improve document title. - + Tool: improve_title """ prompt = f""" @@ -182,29 +206,29 @@ def improve_title( New title: """ - + return self.enhancer.enhance_field( - content=content, - field_name="title", - prompt=prompt + content=content, field_name="title", prompt=prompt ) - + def find_relationships( self, source_content: str, source_title: str, candidate_summaries: list[dict[str, str]], - relationship_types: str = "parent,child,related,reference" + relationship_types: str = "parent,child,related,reference", ) -> list[dict[str, Any]]: """Identify relationships to other documents. - + Tool: find_relationships """ - candidates_text = "\n".join([ - f"{i+1}. {c.get('title', 'Untitled')}: {c.get('summary', '')}" - for i, c in enumerate(candidate_summaries) - ]) - + candidates_text = "\n".join( + [ + f"{i + 1}. {c.get('title', 'Untitled')}: {c.get('summary', '')}" + for i, c in enumerate(candidate_summaries) + ] + ) + prompt = f""" Find relationships between the source document and candidates. Relationship types to consider: {relationship_types} @@ -227,37 +251,29 @@ def find_relationships( Only include clear relationships, not vague similarities. """ - + return self.enhancer.enhance_field( - content=source_content, - field_name="relationships", - prompt=prompt + content=source_content, field_name="relationships", prompt=prompt ) - + def enhance_for_purpose( - self, - content: str, - purpose: str, - fields: list[str] | None = None + self, content: str, purpose: str, fields: list[str] | None = None ) -> dict[str, Any]: """Enhance document with purpose-specific metadata. - + Tool: enhance_for_purpose """ if fields is None: fields = ["context", "tags", "custom_metadata"] - + results = {} - + if "context" in fields: results["context"] = self.enhance_context(content, purpose) - + if "tags" in fields: - results["tags"] = self.generate_tags( - content, - f"tags relevant to {purpose}" - ) - + results["tags"] = self.generate_tags(content, f"tags relevant to {purpose}") + if "custom_metadata" in fields: prompt = f""" Extract metadata relevant to: {purpose} @@ -271,51 +287,49 @@ def enhance_for_purpose( Return as JSON: """ - + results["custom_metadata"] = self.enhancer.enhance_field( - content=content, - field_name="custom_metadata", - prompt=prompt + content=content, field_name="custom_metadata", prompt=prompt ) - + return results def create_enhancement_tool(enhancer: ContextEnhancer, tool_name: str) -> callable: """Create MCP-compatible tool from enhancer method. - + Args: enhancer: ContextEnhancer instance tool_name: Name of the tool from ENHANCEMENT_TOOLS - + Returns: Callable tool function - + Raises: ValueError: If tool_name is not recognized """ if tool_name not in ENHANCEMENT_TOOLS: raise ValueError(f"Unknown tool: {tool_name}") - + tools = EnhancementTools(enhancer) return getattr(tools, tool_name) def get_tool_schema(tool_name: str) -> dict[str, Any]: """Get MCP schema for a tool. - + Args: tool_name: Name of the tool - + Returns: Tool schema dictionary """ if tool_name not in ENHANCEMENT_TOOLS: raise ValueError(f"Unknown tool: {tool_name}") - + return ENHANCEMENT_TOOLS[tool_name] def list_available_tools() -> list[str]: """List all available enhancement tools.""" - return list(ENHANCEMENT_TOOLS.keys()) \ No newline at end of file + return list(ENHANCEMENT_TOOLS.keys()) diff --git a/contextframe/extract/__init__.py b/contextframe/extract/__init__.py index 8513a96..fa85b9d 100644 --- a/contextframe/extract/__init__.py +++ b/contextframe/extract/__init__.py @@ -1,6 +1,6 @@ """Lightweight document extraction module for ContextFrame. -This module provides utilities for extracting content and metadata from +This module provides utilities for extracting content and metadata from lightweight text-based formats. For heavy document processing (PDFs, images), see the documentation for recommended external tools and integration patterns. """ @@ -24,4 +24,4 @@ "TextFileExtractor", "YAMLExtractor", "CSVExtractor", -] \ No newline at end of file +] diff --git a/contextframe/extract/base.py b/contextframe/extract/base.py index a97ef04..907fa93 100644 --- a/contextframe/extract/base.py +++ b/contextframe/extract/base.py @@ -9,7 +9,7 @@ @dataclass class ExtractionResult: """Result of a document extraction operation. - + Attributes: content: The extracted text content metadata: Extracted or inferred metadata @@ -19,6 +19,7 @@ class ExtractionResult: error: Error message if extraction failed warnings: List of non-fatal warnings during extraction """ + content: str metadata: dict[str, Any] = field(default_factory=dict) source: str | Path | None = None @@ -26,20 +27,20 @@ class ExtractionResult: chunks: list[str] | None = None error: str | None = None warnings: list[str] = field(default_factory=list) - + @property def success(self) -> bool: """Check if extraction was successful.""" return self.error is None and bool(self.content) - + def to_frame_record_kwargs(self) -> dict[str, Any]: """Convert extraction result to kwargs for FrameRecord creation. - + Maps extraction metadata to proper ContextFrame schema fields: - source -> source_file or source_url (based on URI scheme) - format -> source_type - Other custom fields -> custom_metadata object - + Returns: Dictionary suitable for FrameRecord(**kwargs) """ @@ -50,13 +51,13 @@ def to_frame_record_kwargs(self) -> dict[str, Any]: title = Path(str(self.source)).stem if not title: title = "Untitled Document" - + # Prepare kwargs with required fields kwargs = { "title": title, "content": self.content, } - + # Map extraction fields to schema fields if self.source: source_str = str(self.source) @@ -64,22 +65,35 @@ def to_frame_record_kwargs(self) -> dict[str, Any]: metadata["source_url"] = source_str else: metadata["source_file"] = source_str - + if self.format: metadata["source_type"] = self.format - + # Identify standard schema fields standard_fields = { - "source_file", "source_type", "source_url", "created_at", - "modified_at", "author", "description", "custom_metadata", - "tags", "context", "parent_id", "related_ids", "reference_ids", - "member_of", "version", "language", "revision" + "source_file", + "source_type", + "source_url", + "created_at", + "modified_at", + "author", + "description", + "custom_metadata", + "tags", + "context", + "parent_id", + "related_ids", + "reference_ids", + "member_of", + "version", + "language", + "revision", } - + # Move non-standard fields to custom_metadata custom_metadata = metadata.get("custom_metadata", {}) fields_to_move = [] - + for key, value in metadata.items(): if key not in standard_fields and not key.startswith("x_"): fields_to_move.append(key) @@ -89,134 +103,125 @@ def to_frame_record_kwargs(self) -> dict[str, Any]: elif isinstance(value, (list, dict)): # For complex types, convert to JSON string import json + custom_metadata[key] = json.dumps(value) else: custom_metadata[key] = str(value) - + # Remove moved fields from top-level metadata for key in fields_to_move: del metadata[key] - + # Only add custom_metadata if it has content if custom_metadata: metadata["custom_metadata"] = custom_metadata - + # Add all metadata to kwargs kwargs.update(metadata) - + return kwargs class TextExtractor(ABC): """Abstract base class for text extractors. - + Each extractor implementation should handle a specific file format or content type. """ - + def __init__(self): """Initialize the extractor.""" self.supported_extensions: list[str] = [] self.format_name: str = "unknown" - + @abstractmethod def can_extract(self, source: str | Path) -> bool: """Check if this extractor can handle the given source. - + Args: source: File path or content identifier - + Returns: True if this extractor can handle the source """ pass - + @abstractmethod def extract( - self, - source: str | Path, - encoding: str = "utf-8", - **kwargs + self, source: str | Path, encoding: str = "utf-8", **kwargs ) -> ExtractionResult: """Extract content and metadata from the source. - + Args: source: File path or content identifier encoding: Text encoding to use **kwargs: Additional extractor-specific options - + Returns: ExtractionResult containing the extracted data """ pass - + def extract_from_string( - self, - content: str, - source: str | Path | None = None, - **kwargs + self, content: str, source: str | Path | None = None, **kwargs ) -> ExtractionResult: """Extract from string content instead of file. - + Args: content: The text content to process source: Optional source identifier **kwargs: Additional extractor-specific options - + Returns: ExtractionResult containing the extracted data """ # Default implementation - subclasses can override for format-specific parsing - return ExtractionResult( - content=content, - source=source, - format=self.format_name - ) - + return ExtractionResult(content=content, source=source, format=self.format_name) + def validate_source(self, source: str | Path) -> Path: """Validate and convert source to Path object. - + Args: source: File path or content identifier - + Returns: Path object - + Raises: ValueError: If source is invalid or doesn't exist """ path = Path(source) if not isinstance(source, Path) else source - + if not path.exists(): raise ValueError(f"Source file does not exist: {path}") - + if not path.is_file(): raise ValueError(f"Source is not a file: {path}") - + return path class ExtractorRegistry: """Registry for managing available extractors.""" - + def __init__(self): """Initialize the registry.""" self._extractors: list[TextExtractor] = [] - + def register(self, extractor: TextExtractor) -> None: """Register an extractor. - + Args: extractor: The extractor instance to register """ self._extractors.append(extractor) - + def find_extractor(self, source: str | Path) -> TextExtractor | None: """Find an appropriate extractor for the given source. - + Args: source: File path or content identifier - + Returns: The first extractor that can handle the source, or None """ @@ -224,10 +229,10 @@ def find_extractor(self, source: str | Path) -> TextExtractor | None: if extractor.can_extract(source): return extractor return None - + def get_supported_formats(self) -> dict[str, list[str]]: """Get all supported formats and their extensions. - + Returns: Dictionary mapping format names to lists of extensions """ @@ -238,4 +243,4 @@ def get_supported_formats(self) -> dict[str, list[str]]: # Global registry instance -registry = ExtractorRegistry() \ No newline at end of file +registry = ExtractorRegistry() diff --git a/contextframe/extract/batch.py b/contextframe/extract/batch.py index 1a390f1..33f59c0 100644 --- a/contextframe/extract/batch.py +++ b/contextframe/extract/batch.py @@ -10,15 +10,15 @@ class BatchExtractor: """Batch processor for extracting multiple documents efficiently.""" - + def __init__( self, max_workers: int | None = None, use_process_pool: bool = False, - progress_callback: Callable[[int, int, str], None] | None = None + progress_callback: Callable[[int, int, str], None] | None = None, ): """Initialize the batch extractor. - + Args: max_workers: Maximum number of workers for parallel processing. If None, uses CPU count. @@ -31,52 +31,49 @@ def __init__( self.use_process_pool = use_process_pool self.progress_callback = progress_callback self._extractors = {} - + def extract_files( self, file_paths: Iterable[str | Path], encoding: str = "utf-8", skip_errors: bool = True, - **extractor_kwargs + **extractor_kwargs, ) -> list[ExtractionResult]: """Extract content from multiple files. - + Args: file_paths: Iterable of file paths to process encoding: Default text encoding skip_errors: Whether to skip files that fail extraction **extractor_kwargs: Additional arguments passed to extractors - + Returns: List of ExtractionResult objects """ file_paths = list(file_paths) total_count = len(file_paths) - + if self.use_process_pool: executor_class = ProcessPoolExecutor else: executor_class = ThreadPoolExecutor - + results = [] - + with executor_class(max_workers=self.max_workers) as executor: # Submit all tasks futures = [] for i, file_path in enumerate(file_paths): future = executor.submit( - self._extract_single, - file_path, - encoding, - extractor_kwargs + self._extract_single, file_path, encoding, extractor_kwargs ) futures.append((i, file_path, future)) - + # Collect results as they complete for i, file_path, future in futures: if self.progress_callback: self.progress_callback(i + 1, total_count, str(file_path)) - + try: result = future.result() results.append(result) @@ -87,12 +84,12 @@ def extract_files( error_result = ExtractionResult( content="", source=file_path, - error=f"Extraction failed: {str(e)}" + error=f"Extraction failed: {str(e)}", ) results.append(error_result) - + return results - + def extract_directory( self, directory: str | Path, @@ -100,10 +97,10 @@ def extract_directory( recursive: bool = True, encoding: str = "utf-8", skip_errors: bool = True, - **extractor_kwargs + **extractor_kwargs, ) -> list[ExtractionResult]: """Extract content from all matching files in a directory. - + Args: directory: Directory path to scan pattern: Glob pattern for file matching (e.g., "*.md", "*.json") @@ -111,63 +108,56 @@ def extract_directory( encoding: Default text encoding skip_errors: Whether to skip files that fail extraction **extractor_kwargs: Additional arguments passed to extractors - + Returns: List of ExtractionResult objects """ directory = Path(directory) - + if recursive: file_paths = directory.rglob(pattern) else: file_paths = directory.glob(pattern) - + # Filter to only files (not directories) file_paths = [p for p in file_paths if p.is_file()] - + return self.extract_files( - file_paths, - encoding=encoding, - skip_errors=skip_errors, - **extractor_kwargs + file_paths, encoding=encoding, skip_errors=skip_errors, **extractor_kwargs ) - + async def extract_files_async( self, file_paths: Iterable[str | Path], encoding: str = "utf-8", skip_errors: bool = True, - **extractor_kwargs + **extractor_kwargs, ) -> list[ExtractionResult]: """Asynchronously extract content from multiple files. - + Args: file_paths: Iterable of file paths to process encoding: Default text encoding skip_errors: Whether to skip files that fail extraction **extractor_kwargs: Additional arguments passed to extractors - + Returns: List of ExtractionResult objects """ file_paths = list(file_paths) total_count = len(file_paths) - + # Create tasks tasks = [] for i, file_path in enumerate(file_paths): task = self._extract_single_async( - i, - total_count, - file_path, - encoding, - extractor_kwargs + i, total_count, file_path, encoding, extractor_kwargs ) tasks.append(task) - + # Run tasks concurrently results = await asyncio.gather(*tasks, return_exceptions=skip_errors) - + # Process results final_results = [] for file_path, result in zip(file_paths, results, strict=False): @@ -178,66 +168,59 @@ async def extract_files_async( error_result = ExtractionResult( content="", source=file_path, - error=f"Extraction failed: {str(result)}" + error=f"Extraction failed: {str(result)}", ) final_results.append(error_result) else: final_results.append(result) - + return final_results - + def _extract_single( - self, - file_path: str | Path, - encoding: str, - extractor_kwargs: dict + self, file_path: str | Path, encoding: str, extractor_kwargs: dict ) -> ExtractionResult: """Extract a single file.""" # Find appropriate extractor extractor = registry.find_extractor(file_path) - + if not extractor: return ExtractionResult( content="", source=file_path, - error=f"No extractor found for file: {file_path}" + error=f"No extractor found for file: {file_path}", ) - + # Extract content return extractor.extract(file_path, encoding=encoding, **extractor_kwargs) - + async def _extract_single_async( self, index: int, total: int, file_path: str | Path, encoding: str, - extractor_kwargs: dict + extractor_kwargs: dict, ) -> ExtractionResult: """Extract a single file asynchronously.""" if self.progress_callback: self.progress_callback(index + 1, total, str(file_path)) - + # Run extraction in thread pool to avoid blocking loop = asyncio.get_event_loop() return await loop.run_in_executor( - None, - self._extract_single, - file_path, - encoding, - extractor_kwargs + None, self._extract_single, file_path, encoding, extractor_kwargs ) - + def extract_with_custom_extractors( self, file_paths: Iterable[str | Path], extractors: dict[str, TextExtractor], encoding: str = "utf-8", skip_errors: bool = True, - **extractor_kwargs + **extractor_kwargs, ) -> list[ExtractionResult]: """Extract files using custom extractors for specific extensions. - + Args: file_paths: Iterable of file paths to process extractors: Dictionary mapping file extensions to extractor instances @@ -245,48 +228,48 @@ def extract_with_custom_extractors( encoding: Default text encoding skip_errors: Whether to skip files that fail extraction **extractor_kwargs: Additional arguments passed to extractors - + Returns: List of ExtractionResult objects """ results = [] file_paths = list(file_paths) total_count = len(file_paths) - + for i, file_path in enumerate(file_paths): if self.progress_callback: self.progress_callback(i + 1, total_count, str(file_path)) - + path = Path(file_path) - + # Find extractor by extension extractor = extractors.get(path.suffix.lower()) if not extractor: # Fall back to registry extractor = registry.find_extractor(file_path) - + if not extractor: if not skip_errors: raise ValueError(f"No extractor found for file: {file_path}") error_result = ExtractionResult( content="", source=file_path, - error=f"No extractor found for file: {file_path}" + error=f"No extractor found for file: {file_path}", ) results.append(error_result) continue - + try: - result = extractor.extract(file_path, encoding=encoding, **extractor_kwargs) + result = extractor.extract( + file_path, encoding=encoding, **extractor_kwargs + ) results.append(result) except Exception as e: if not skip_errors: raise error_result = ExtractionResult( - content="", - source=file_path, - error=f"Extraction failed: {str(e)}" + content="", source=file_path, error=f"Extraction failed: {str(e)}" ) results.append(error_result) - - return results \ No newline at end of file + + return results diff --git a/contextframe/extract/chunking.py b/contextframe/extract/chunking.py index 899473c..b744450 100644 --- a/contextframe/extract/chunking.py +++ b/contextframe/extract/chunking.py @@ -2,7 +2,7 @@ from .base import ExtractionResult from collections.abc import Callable -from typing import List, Optional, Tuple, Union, Literal +from typing import List, Literal, Optional, Tuple, Union def semantic_splitter( @@ -10,15 +10,15 @@ def semantic_splitter( chunk_size: int = 512, chunk_overlap: int | None = None, splitter_type: Literal["text", "markdown", "code"] = "text", - tokenizer_model: Optional[str] = None, - language: Optional[str] = None, + tokenizer_model: str | None = None, + language: str | None = None, ) -> list[tuple[int, str]]: """Split texts using semantic-text-splitter. - + This function provides high-performance text splitting using the Rust-based semantic-text-splitter library. It supports character, token, and semantic splitting for text, markdown, and code. - + Args: texts: List of text strings to split chunk_size: Maximum size of each chunk (default: 512) @@ -30,36 +30,36 @@ def semantic_splitter( - HuggingFace models: "bert-base-uncased", etc. - None for character-based splitting language: Required for code splitting (e.g., "python", "javascript") - + Returns: List of tuples (text_index, chunk_content) where text_index indicates which input text the chunk came from - + Raises: ImportError: If semantic-text-splitter is not installed ValueError: If invalid parameters are provided """ try: - from semantic_text_splitter import TextSplitter, MarkdownSplitter, CodeSplitter + from semantic_text_splitter import CodeSplitter, MarkdownSplitter, TextSplitter except ImportError: raise ImportError( "semantic-text-splitter is required for text splitting. " "Install with: pip install semantic-text-splitter" ) - + if chunk_overlap is None: chunk_overlap = 0 - + # Create appropriate splitter based on type if splitter_type == "code": if not language: raise ValueError("Language parameter is required for code splitting") - + # Import appropriate tree-sitter language # Map common file extensions to language names if needed language_map = { "py": "python", - "js": "javascript", + "js": "javascript", "ts": "typescript", "rs": "rust", "go": "go", @@ -78,10 +78,10 @@ def semantic_splitter( "sh": "bash", "bash": "bash", } - + # Normalize language name lang_name = language_map.get(language.lower(), language.lower()) - + try: lang_module = __import__(f"tree_sitter_{lang_name}") splitter = CodeSplitter(lang_module.language(), chunk_size) @@ -92,19 +92,26 @@ def semantic_splitter( ) else: # Choose between TextSplitter and MarkdownSplitter - SplitterClass = MarkdownSplitter if splitter_type == "markdown" else TextSplitter - + SplitterClass = ( + MarkdownSplitter if splitter_type == "markdown" else TextSplitter + ) + # Create splitter with appropriate sizing strategy if tokenizer_model: if tokenizer_model.startswith(("gpt", "claude", "text-embedding")): # OpenAI-style models using tiktoken - splitter = SplitterClass.from_tiktoken_model(tokenizer_model, chunk_size) + splitter = SplitterClass.from_tiktoken_model( + tokenizer_model, chunk_size + ) else: # HuggingFace tokenizer try: from tokenizers import Tokenizer + tokenizer = Tokenizer.from_pretrained(tokenizer_model) - splitter = SplitterClass.from_huggingface_tokenizer(tokenizer, chunk_size) + splitter = SplitterClass.from_huggingface_tokenizer( + tokenizer, chunk_size + ) except ImportError: raise ImportError( f"tokenizers package is required for HuggingFace tokenizer '{tokenizer_model}'. " @@ -113,6 +120,7 @@ def semantic_splitter( except Exception as e: # Fallback to character-based if model not found import warnings + warnings.warn( f"Failed to load tokenizer '{tokenizer_model}': {e}. " "Falling back to character-based splitting." @@ -121,18 +129,18 @@ def semantic_splitter( else: # Character-based splitting splitter = SplitterClass(chunk_size) - + chunks = [] - + # Process each text for idx, text in enumerate(texts): # Get chunks with indices for potential overlap support text_chunks = splitter.chunks(text) - + # Add chunks with source index for chunk_text in text_chunks: chunks.append((idx, chunk_text)) - + return chunks @@ -142,10 +150,10 @@ def split_extraction_results( chunk_overlap: int | None = None, splitter_fn: Callable | None = None, splitter_type: Literal["text", "markdown", "code"] = "text", - tokenizer_model: Optional[str] = None, + tokenizer_model: str | None = None, ) -> list[ExtractionResult]: """Split extraction results into smaller chunks. - + Args: results: List of ExtractionResult objects to split chunk_size: Maximum size of each chunk @@ -155,59 +163,61 @@ def split_extraction_results( return List[Tuple[text_index, chunk_content]] splitter_type: Type of splitter - "text", "markdown", or "code" tokenizer_model: Optional tokenizer model name for token-based splitting - + Returns: List of new ExtractionResult objects, one per chunk """ if splitter_fn is None: splitter_fn = semantic_splitter - + # Extract texts and track sources texts = [] source_results = [] - + for result in results: if result.success and result.content: texts.append(result.content) source_results.append(result) - + if not texts: return results - + # Split texts chunks = splitter_fn( - texts, - chunk_size=chunk_size, + texts, + chunk_size=chunk_size, chunk_overlap=chunk_overlap, splitter_type=splitter_type, tokenizer_model=tokenizer_model, ) - + # Create new ExtractionResult objects for chunks chunked_results = [] - + # Group chunks by source chunk_groups = {} for text_idx, chunk_content in chunks: if text_idx not in chunk_groups: chunk_groups[text_idx] = [] chunk_groups[text_idx].append(chunk_content) - + # Create results maintaining source metadata for text_idx, chunk_list in chunk_groups.items(): source_result = source_results[text_idx] - + for i, chunk_content in enumerate(chunk_list): # Create new metadata including chunk info chunk_metadata = source_result.metadata.copy() - chunk_metadata.update({ - "chunk_index": i, - "chunk_count": len(chunk_list), - "chunk_size": chunk_size, - "chunk_overlap": chunk_overlap or 0, - "original_content_length": len(source_result.content), - }) - + chunk_metadata.update( + { + "chunk_index": i, + "chunk_count": len(chunk_list), + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap or 0, + "original_content_length": len(source_result.content), + } + ) + # Create new result for chunk chunk_result = ExtractionResult( content=chunk_content, @@ -216,26 +226,28 @@ def split_extraction_results( format=source_result.format, chunks=None, # Don't propagate chunks to avoid confusion error=None, - warnings=source_result.warnings.copy() if source_result.warnings else [], + warnings=source_result.warnings.copy() + if source_result.warnings + else [], ) - + chunked_results.append(chunk_result) - + # Include any failed results unchanged for result in results: if not result.success: chunked_results.append(result) - + return chunked_results class ChunkingMixin: """Mixin class to add chunking capability to extractors. - + This can be mixed into any TextExtractor subclass to add automatic chunking support using semantic-text-splitter. """ - + def extract_with_chunking( self, source, @@ -243,11 +255,11 @@ def extract_with_chunking( chunk_overlap: int | None = None, encoding: str = "utf-8", splitter_type: Literal["text", "markdown", "code"] = "text", - tokenizer_model: Optional[str] = None, - **kwargs + tokenizer_model: str | None = None, + **kwargs, ) -> ExtractionResult: """Extract and automatically chunk the content. - + Args: source: File path or content identifier chunk_size: Maximum size of each chunk @@ -256,24 +268,33 @@ def extract_with_chunking( splitter_type: Type of splitter - "text", "markdown", or "code" tokenizer_model: Optional tokenizer model name for token-based splitting **kwargs: Additional extractor-specific options - + Returns: ExtractionResult with chunks field populated """ # First extract normally result = self.extract(source, encoding=encoding, **kwargs) - + if not result.success or not result.content: return result - + try: # Detect format for splitter type if not specified if splitter_type == "text" and result.format: if result.format.lower() in ["markdown", "md"]: splitter_type = "markdown" - elif result.format.lower() in ["py", "js", "ts", "java", "cpp", "c", "go", "rust"]: + elif result.format.lower() in [ + "py", + "js", + "ts", + "java", + "cpp", + "c", + "go", + "rust", + ]: splitter_type = "code" - + # Split the content chunks = semantic_splitter( [result.content], @@ -283,10 +304,10 @@ def extract_with_chunking( tokenizer_model=tokenizer_model, language=result.format.lower() if splitter_type == "code" else None, ) - + # Extract just the chunk texts chunk_texts = [chunk_text for _, chunk_text in chunks] - + # Update the result result.chunks = chunk_texts result.metadata["chunk_size"] = chunk_size @@ -295,31 +316,31 @@ def extract_with_chunking( result.metadata["splitter_type"] = splitter_type if tokenizer_model: result.metadata["tokenizer_model"] = tokenizer_model - + except ImportError as e: result.warnings.append(f"Chunking unavailable: {str(e)}") except Exception as e: result.warnings.append(f"Chunking failed: {str(e)}") - + return result - + @staticmethod def chunk_text( text: str, chunk_size: int = 512, chunk_overlap: int | None = None, splitter_type: Literal["text", "markdown"] = "text", - tokenizer_model: Optional[str] = None, - ) -> List[str]: + tokenizer_model: str | None = None, + ) -> list[str]: """Convenience method to chunk a single text string. - + Args: text: Text to chunk chunk_size: Maximum size of each chunk chunk_overlap: Number of overlapping characters between chunks splitter_type: Type of splitter - "text" or "markdown" tokenizer_model: Optional tokenizer model name - + Returns: List of chunk strings """ @@ -330,4 +351,4 @@ def chunk_text( splitter_type=splitter_type, tokenizer_model=tokenizer_model, ) - return [chunk_text for _, chunk_text in chunks] \ No newline at end of file + return [chunk_text for _, chunk_text in chunks] diff --git a/contextframe/extract/extractors.py b/contextframe/extract/extractors.py index 1a1ee99..213a8ce 100644 --- a/contextframe/extract/extractors.py +++ b/contextframe/extract/extractors.py @@ -12,13 +12,13 @@ class TextFileExtractor(TextExtractor): """Extractor for plain text files.""" - + def __init__(self): """Initialize the text file extractor.""" super().__init__() self.supported_extensions = [".txt", ".text", ".log"] self.format_name = "text" - + def can_extract(self, source: str | Path) -> bool: """Check if this is a plain text file.""" try: @@ -26,45 +26,39 @@ def can_extract(self, source: str | Path) -> bool: return path.suffix.lower() in self.supported_extensions except Exception: return False - + def extract( - self, - source: str | Path, - encoding: str = "utf-8", - **kwargs + self, source: str | Path, encoding: str = "utf-8", **kwargs ) -> ExtractionResult: """Extract content from a plain text file.""" try: path = self.validate_source(source) - + with open(path, encoding=encoding) as f: content = f.read() - + metadata = { "filename": path.name, "size": path.stat().st_size, "encoding": encoding, } - + return ExtractionResult( - content=content, - metadata=metadata, - source=path, - format=self.format_name + content=content, metadata=metadata, source=path, format=self.format_name ) - + except Exception as e: return ExtractionResult( content="", error=f"Failed to extract text file: {str(e)}", source=source, - format=self.format_name + format=self.format_name, ) class MarkdownExtractor(TextExtractor): """Extractor for Markdown files with frontmatter support.""" - + def __init__(self): """Initialize the markdown extractor.""" super().__init__() @@ -72,10 +66,9 @@ def __init__(self): self.format_name = "markdown" # Pattern to match YAML frontmatter self.frontmatter_pattern = re.compile( - r"^---\s*\n(.*?)\n---\s*\n", - re.DOTALL | re.MULTILINE + r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL | re.MULTILINE ) - + def can_extract(self, source: str | Path) -> bool: """Check if this is a markdown file.""" try: @@ -83,29 +76,29 @@ def can_extract(self, source: str | Path) -> bool: return path.suffix.lower() in self.supported_extensions except Exception: return False - + def extract( - self, - source: str | Path, + self, + source: str | Path, encoding: str = "utf-8", extract_frontmatter: bool = True, - **kwargs + **kwargs, ) -> ExtractionResult: """Extract content and frontmatter from a markdown file.""" try: path = self.validate_source(source) - + with open(path, encoding=encoding) as f: raw_content = f.read() - + metadata = { "filename": path.name, "size": path.stat().st_size, "encoding": encoding, } - + content = raw_content - + # Extract frontmatter if present and requested if extract_frontmatter: match = self.frontmatter_pattern.match(raw_content) @@ -115,7 +108,7 @@ def extract( frontmatter_data = yaml.safe_load(frontmatter_text) or {} metadata.update(frontmatter_data) # Remove frontmatter from content - content = raw_content[match.end():] + content = raw_content[match.end() :] except yaml.YAMLError as e: # Add warning but continue warnings = [f"Failed to parse frontmatter: {str(e)}"] @@ -124,35 +117,35 @@ def extract( metadata=metadata, source=path, format=self.format_name, - warnings=warnings + warnings=warnings, ) - + return ExtractionResult( content=content.strip(), metadata=metadata, source=path, - format=self.format_name + format=self.format_name, ) - + except Exception as e: return ExtractionResult( content="", error=f"Failed to extract markdown file: {str(e)}", source=source, - format=self.format_name + format=self.format_name, ) - + def extract_from_string( - self, + self, content: str, source: str | Path | None = None, extract_frontmatter: bool = True, - **kwargs + **kwargs, ) -> ExtractionResult: """Extract from markdown string content.""" metadata = {} warnings = [] - + # Extract frontmatter if requested if extract_frontmatter: match = self.frontmatter_pattern.match(content) @@ -162,28 +155,28 @@ def extract_from_string( frontmatter_data = yaml.safe_load(frontmatter_text) or {} metadata.update(frontmatter_data) # Remove frontmatter from content - content = content[match.end():] + content = content[match.end() :] except yaml.YAMLError as e: warnings.append(f"Failed to parse frontmatter: {str(e)}") - + return ExtractionResult( content=content.strip(), metadata=metadata, source=source, format=self.format_name, - warnings=warnings + warnings=warnings, ) class JSONExtractor(TextExtractor): """Extractor for JSON and JSONL files.""" - + def __init__(self): """Initialize the JSON extractor.""" super().__init__() self.supported_extensions = [".json", ".jsonl", ".ndjson"] self.format_name = "json" - + def can_extract(self, source: str | Path) -> bool: """Check if this is a JSON file.""" try: @@ -191,16 +184,16 @@ def can_extract(self, source: str | Path) -> bool: return path.suffix.lower() in self.supported_extensions except Exception: return False - + def extract( - self, - source: str | Path, + self, + source: str | Path, encoding: str = "utf-8", extract_text_fields: list[str] | None = None, - **kwargs + **kwargs, ) -> ExtractionResult: """Extract content from JSON files. - + Args: source: File path encoding: Text encoding @@ -209,7 +202,7 @@ def extract( """ try: path = self.validate_source(source) - + with open(path, encoding=encoding) as f: if path.suffix.lower() in [".jsonl", ".ndjson"]: # Handle JSON Lines format @@ -227,7 +220,7 @@ def extract( content="", error=f"Invalid JSON on line {line_num}: {str(e)}", source=path, - format=self.format_name + format=self.format_name, ) else: # Regular JSON file @@ -239,19 +232,19 @@ def extract( content="", error=f"Invalid JSON: {str(e)}", source=path, - format=self.format_name + format=self.format_name, ) - + metadata = { "filename": path.name, "size": path.stat().st_size, "encoding": encoding, } - + # Extract text content if extract_text_fields: extracted_texts = [] - + def extract_fields(obj, fields): """Recursively extract text from specified fields.""" if isinstance(obj, dict): @@ -263,41 +256,38 @@ def extract_fields(obj, fields): elif isinstance(obj, list): for item in obj: extract_fields(item, fields) - + extract_fields(data, extract_text_fields) content = "\n\n".join(extracted_texts) else: # Pretty-print the JSON as content content = json.dumps(data, indent=2, ensure_ascii=False) - + # Store the structured data in metadata metadata["json_data"] = data - + return ExtractionResult( - content=content, - metadata=metadata, - source=path, - format=self.format_name + content=content, metadata=metadata, source=path, format=self.format_name ) - + except Exception as e: return ExtractionResult( content="", error=f"Failed to extract JSON file: {str(e)}", source=source, - format=self.format_name + format=self.format_name, ) class YAMLExtractor(TextExtractor): """Extractor for YAML files.""" - + def __init__(self): """Initialize the YAML extractor.""" super().__init__() self.supported_extensions = [".yaml", ".yml"] self.format_name = "yaml" - + def can_extract(self, source: str | Path) -> bool: """Check if this is a YAML file.""" try: @@ -305,18 +295,18 @@ def can_extract(self, source: str | Path) -> bool: return path.suffix.lower() in self.supported_extensions except Exception: return False - + def extract( - self, - source: str | Path, + self, + source: str | Path, encoding: str = "utf-8", extract_text_fields: list[str] | None = None, - **kwargs + **kwargs, ) -> ExtractionResult: """Extract content from YAML files.""" try: path = self.validate_source(source) - + with open(path, encoding=encoding) as f: content = f.read() try: @@ -326,19 +316,19 @@ def extract( content="", error=f"Invalid YAML: {str(e)}", source=path, - format=self.format_name + format=self.format_name, ) - + metadata = { "filename": path.name, "size": path.stat().st_size, "encoding": encoding, } - + # Extract text content if extract_text_fields and isinstance(data, dict): extracted_texts = [] - + def extract_fields(obj, fields): """Recursively extract text from specified fields.""" if isinstance(obj, dict): @@ -350,41 +340,38 @@ def extract_fields(obj, fields): elif isinstance(obj, list): for item in obj: extract_fields(item, fields) - + extract_fields(data, extract_text_fields) content = "\n\n".join(extracted_texts) else: # Use original YAML content pass - + # Store the structured data in metadata metadata["yaml_data"] = data - + return ExtractionResult( - content=content, - metadata=metadata, - source=path, - format=self.format_name + content=content, metadata=metadata, source=path, format=self.format_name ) - + except Exception as e: return ExtractionResult( content="", error=f"Failed to extract YAML file: {str(e)}", source=source, - format=self.format_name + format=self.format_name, ) class CSVExtractor(TextExtractor): """Extractor for CSV/TSV files.""" - + def __init__(self): """Initialize the CSV extractor.""" super().__init__() self.supported_extensions = [".csv", ".tsv"] self.format_name = "csv" - + def can_extract(self, source: str | Path) -> bool: """Check if this is a CSV/TSV file.""" try: @@ -392,17 +379,17 @@ def can_extract(self, source: str | Path) -> bool: return path.suffix.lower() in self.supported_extensions except Exception: return False - + def extract( - self, - source: str | Path, + self, + source: str | Path, encoding: str = "utf-8", text_columns: list[str | int] | None = None, include_headers: bool = True, - **kwargs + **kwargs, ) -> ExtractionResult: """Extract content from CSV/TSV files. - + Args: source: File path encoding: Text encoding @@ -411,32 +398,32 @@ def extract( """ try: path = self.validate_source(source) - + # Determine delimiter delimiter = "\t" if path.suffix.lower() == ".tsv" else "," - + rows = [] headers = [] - + with open(path, encoding=encoding, newline="") as f: # Try to detect dialect sample = f.read(8192) f.seek(0) - + try: dialect = csv.Sniffer().sniff(sample) delimiter = dialect.delimiter except csv.Error: # Use default delimiter based on extension pass - + reader = csv.reader(f, delimiter=delimiter) - + for i, row in enumerate(reader): if i == 0 and csv.Sniffer().has_header(sample): headers = row rows.append(row) - + metadata = { "filename": path.name, "size": path.stat().st_size, @@ -445,17 +432,17 @@ def extract( "row_count": len(rows), "column_count": len(rows[0]) if rows else 0, } - + if headers: metadata["headers"] = headers - + # Extract text content extracted_rows = [] - + if text_columns is not None: # Extract specific columns column_indices = [] - + for col in text_columns: if isinstance(col, int): column_indices.append(col) @@ -465,7 +452,7 @@ def extract( column_indices.append(idx) except ValueError: warnings = [f"Column '{col}' not found in headers"] - + for i, row in enumerate(rows): if i == 0 and headers and not include_headers: continue @@ -481,25 +468,22 @@ def extract( if i == 0 and headers and not include_headers: continue extracted_rows.append(", ".join(row)) - + content = "\n".join(extracted_rows) - + # Store the structured data in metadata metadata["csv_data"] = rows - + return ExtractionResult( - content=content, - metadata=metadata, - source=path, - format=self.format_name + content=content, metadata=metadata, source=path, format=self.format_name ) - + except Exception as e: return ExtractionResult( content="", error=f"Failed to extract CSV file: {str(e)}", source=source, - format=self.format_name + format=self.format_name, ) @@ -508,4 +492,4 @@ def extract( registry.register(MarkdownExtractor()) registry.register(JSONExtractor()) registry.register(YAMLExtractor()) -registry.register(CSVExtractor()) \ No newline at end of file +registry.register(CSVExtractor()) diff --git a/contextframe/io/exporter.py b/contextframe/io/exporter.py index 44f2da8..a56c6e2 100644 --- a/contextframe/io/exporter.py +++ b/contextframe/io/exporter.py @@ -157,7 +157,9 @@ def _export_markdown( # Add relationship visualization relationships = frameset.metadata.get("relationships", []) if relationships: - content.extend(self._build_relationship_diagram(frameset, relationships)) + content.extend( + self._build_relationship_diagram(frameset, relationships) + ) content.append("") frames = self.dataset.get_frameset_frames(frameset.uuid) frame_count = len(frames) @@ -262,14 +264,14 @@ def _build_relationship_diagram( ) -> list[str]: """Build a Mermaid diagram showing relationships.""" lines = [] - + lines.append("## Relationship Visualization") lines.append("") lines.append("```mermaid") lines.append("graph TD") lines.append(f' FS["{frameset.title}
FrameSet"]') lines.append("") - + # Group relationships by type rel_by_type = {} for rel in relationships: @@ -277,7 +279,7 @@ def _build_relationship_diagram( if rel_type not in rel_by_type: rel_by_type[rel_type] = [] rel_by_type[rel_type].append(rel) - + # Add nodes and connections node_count = 0 for rel_type, rels in rel_by_type.items(): @@ -287,9 +289,9 @@ def _build_relationship_diagram( title = rel.get("title", "Unknown") # Escape quotes in title for Mermaid title = title.replace('"', "'") - + lines.append(f' {node_id}["{title}"]') - + # Define edge style based on relationship type if rel_type == "contains": edge_style = "-->|contains|" @@ -301,18 +303,18 @@ def _build_relationship_diagram( edge_style = "---|related|" else: edge_style = f"---|{rel_type}|" - + lines.append(f" FS {edge_style} {node_id}") - + lines.append("```") lines.append("") - + # Add legend lines.append("**Relationship Types:**") lines.append("- `contains`: Direct inclusion in the frameset") lines.append("- `references`: External reference or citation") lines.append("- `member_of`: Part of a collection or group") - + return lines def _add_usage_instructions(self) -> list[str]: diff --git a/contextframe/mcp/__init__.py b/contextframe/mcp/__init__.py index 9c69c9a..89881ab 100644 --- a/contextframe/mcp/__init__.py +++ b/contextframe/mcp/__init__.py @@ -6,4 +6,4 @@ from contextframe.mcp.server import ContextFrameMCPServer -__all__ = ["ContextFrameMCPServer"] \ No newline at end of file +__all__ = ["ContextFrameMCPServer"] diff --git a/contextframe/mcp/__main__.py b/contextframe/mcp/__main__.py index ae14b84..baeefd4 100644 --- a/contextframe/mcp/__main__.py +++ b/contextframe/mcp/__main__.py @@ -8,4 +8,4 @@ from contextframe.mcp.server import main if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/contextframe/mcp/batch/__init__.py b/contextframe/mcp/batch/__init__.py index c1473d2..b9f9bf4 100644 --- a/contextframe/mcp/batch/__init__.py +++ b/contextframe/mcp/batch/__init__.py @@ -3,4 +3,4 @@ from .handler import BatchOperationHandler from .tools import BatchTools -__all__ = ["BatchOperationHandler", "BatchTools"] \ No newline at end of file +__all__ = ["BatchOperationHandler", "BatchTools"] diff --git a/contextframe/mcp/batch/handler.py b/contextframe/mcp/batch/handler.py index 0576bd0..20698f2 100644 --- a/contextframe/mcp/batch/handler.py +++ b/contextframe/mcp/batch/handler.py @@ -2,13 +2,12 @@ import asyncio import logging -from typing import Any, Callable, Dict, List, Optional, TypeVar -from dataclasses import dataclass - +from collections.abc import Callable from contextframe.frame import FrameDataset -from contextframe.mcp.core.transport import TransportAdapter, Progress from contextframe.mcp.core.streaming import StreamingAdapter - +from contextframe.mcp.core.transport import Progress, TransportAdapter +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, TypeVar logger = logging.getLogger(__name__) @@ -20,113 +19,113 @@ @dataclass class BatchResult: """Result of a batch operation.""" - + total_processed: int total_errors: int - results: List[Any] - errors: List[Dict[str, Any]] + results: list[Any] + errors: list[dict[str, Any]] operation: str class BatchOperationHandler: """Base class for batch operations with progress tracking. - + Provides a consistent interface for batch operations across different transports (stdio, HTTP) with proper progress tracking and error handling. """ - + def __init__(self, dataset: FrameDataset, transport: TransportAdapter): """Initialize batch handler. - + Args: dataset: The FrameDataset to operate on transport: Transport adapter for progress/streaming """ self.dataset = dataset self.transport = transport - self.streaming: Optional[StreamingAdapter] = None - + self.streaming: StreamingAdapter | None = None + # Get streaming adapter if transport is StdioAdapter if hasattr(transport, 'get_streaming_adapter'): self.streaming = transport.get_streaming_adapter() - + async def execute_batch( self, operation: str, - items: List[T], + items: list[T], processor: Callable[[T], R], atomic: bool = False, - max_errors: Optional[int] = None + max_errors: int | None = None, ) -> BatchResult: """Execute batch operation with progress tracking. - + Args: operation: Name of the operation for progress tracking items: List of items to process processor: Async function to process each item atomic: If True, rollback all on any failure max_errors: Stop after this many errors (None = no limit) - + Returns: BatchResult with processed items and errors """ total = len(items) results = [] errors = [] - + # Start streaming if available if self.streaming: await self.streaming.start_stream(operation, total) - + try: for i, item in enumerate(items): # Send progress - await self.transport.send_progress(Progress( - operation=operation, - current=i + 1, - total=total, - status=f"Processing item {i + 1} of {total}" - )) - + await self.transport.send_progress( + Progress( + operation=operation, + current=i + 1, + total=total, + status=f"Processing item {i + 1} of {total}", + ) + ) + try: # Process item if asyncio.iscoroutinefunction(processor): result = await processor(item) else: result = processor(item) - + results.append(result) - + # Stream result if available if self.streaming: await self.streaming.send_item(result) - + except Exception as e: error = { "item_index": i, "item": item, "error": str(e), - "type": type(e).__name__ + "type": type(e).__name__, } errors.append(error) - + # Stream error if available if self.streaming: await self.streaming.send_error(error) - + # Check if we should stop if atomic: raise BatchOperationError( f"Atomic operation failed at item {i}: {e}" ) - + if max_errors and len(errors) >= max_errors: - logger.warning( - f"Stopping batch after {max_errors} errors" - ) + logger.warning(f"Stopping batch after {max_errors} errors") break - + # Complete streaming if self.streaming: batch_result = BatchResult( @@ -134,63 +133,63 @@ async def execute_batch( total_errors=len(errors), results=results, errors=errors, - operation=operation + operation=operation, ) - + summary = { "total_processed": batch_result.total_processed, "total_errors": batch_result.total_errors, - "errors": batch_result.errors + "errors": batch_result.errors, } - + return await self.streaming.complete_stream(summary) - + return BatchResult( total_processed=len(results), total_errors=len(errors), results=results, errors=errors, - operation=operation + operation=operation, ) - + except Exception as e: # Ensure stream is properly closed on error if self.streaming: - await self.streaming.complete_stream({ - "error": str(e), - "total_processed": len(results), - "total_errors": len(errors) + 1 - }) + await self.streaming.complete_stream( + { + "error": str(e), + "total_processed": len(results), + "total_errors": len(errors) + 1, + } + ) raise class BatchOperationError(Exception): """Error in batch operation.""" + pass async def execute_parallel( - tasks: List[Callable[[], Any]], - max_parallel: int = 5 -) -> List[Any]: + tasks: list[Callable[[], Any]], max_parallel: int = 5 +) -> list[Any]: """Execute tasks with controlled parallelism. - + Args: tasks: List of async callables to execute max_parallel: Maximum concurrent tasks - + Returns: List of results in same order as tasks """ semaphore = asyncio.Semaphore(max_parallel) - + async def run_with_semaphore(task: Callable[[], Any]) -> Any: async with semaphore: result = task() if asyncio.iscoroutine(result): return await result return result - - return await asyncio.gather(*[ - run_with_semaphore(task) for task in tasks - ]) \ No newline at end of file + + return await asyncio.gather(*[run_with_semaphore(task) for task in tasks]) diff --git a/contextframe/mcp/batch/tools.py b/contextframe/mcp/batch/tools.py index ddc3705..75a7aab 100644 --- a/contextframe/mcp/batch/tools.py +++ b/contextframe/mcp/batch/tools.py @@ -3,38 +3,41 @@ import asyncio import json import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Union -from uuid import UUID - +from .handler import BatchOperationHandler, execute_parallel +from .transaction import BatchTransaction from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.core.transport import TransportAdapter + # DocumentTools functionality is in ToolRegistry for now # ValidationError is in pydantic from contextframe.mcp.schemas import ( - BatchSearchParams, BatchAddParams, BatchUpdateParams, - BatchDeleteParams, BatchEnhanceParams, BatchExtractParams, - BatchExportParams, BatchImportParams + BatchAddParams, + BatchDeleteParams, + BatchEnhanceParams, + BatchExportParams, + BatchExtractParams, + BatchImportParams, + BatchSearchParams, + BatchUpdateParams, ) - -from .handler import BatchOperationHandler, execute_parallel -from .transaction import BatchTransaction - +from pathlib import Path +from typing import Any, Dict, List, Optional, Union +from uuid import UUID logger = logging.getLogger(__name__) class BatchTools: """Batch operation tools for efficient bulk operations.""" - + def __init__( self, dataset: FrameDataset, transport: TransportAdapter, - document_tools: Optional[Any] = None + document_tools: Any | None = None, ): """Initialize batch tools. - + Args: dataset: The dataset to operate on transport: Transport adapter for progress @@ -43,10 +46,10 @@ def __init__( self.dataset = dataset self.transport = transport self.handler = BatchOperationHandler(dataset, transport) - + # Reuse document tools if provided self.doc_tools = document_tools # Should be ToolRegistry instance - + def register_tools(self, tool_registry): """Register batch tools with the tool registry.""" tools = [ @@ -59,26 +62,26 @@ def register_tools(self, tool_registry): ("batch_export", self.batch_export, BatchExportParams), ("batch_import", self.batch_import, BatchImportParams), ] - + for name, handler, schema in tools: tool_registry.register_tool( name=name, handler=handler, schema=schema, - description=schema.__doc__ or f"Batch {name.split('_')[1]} operation" + description=schema.__doc__ or f"Batch {name.split('_')[1]} operation", ) - - async def batch_search(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_search(self, params: dict[str, Any]) -> dict[str, Any]: """Execute multiple searches in parallel. - + Returns results grouped by query with progress tracking. """ validated = BatchSearchParams(**params) queries = [q.model_dump() for q in validated.queries] max_parallel = validated.max_parallel - + # Create search tasks - async def search_task(query_params: Dict[str, Any]) -> Dict[str, Any]: + async def search_task(query_params: dict[str, Any]) -> dict[str, Any]: try: # Call search through tool registry search_result = await self.doc_tools.call_tool( @@ -87,16 +90,16 @@ async def search_task(query_params: Dict[str, Any]) -> Dict[str, Any]: "query": query_params["query"], "search_type": query_params.get("search_type", "hybrid"), "limit": query_params.get("limit", 10), - "filter": query_params.get("filter") - } + "filter": query_params.get("filter"), + }, ) results = search_result.get("documents", []) - + return { "query": query_params["query"], "success": True, "results": results, - "count": len(results) + "count": len(results), } except Exception as e: return { @@ -104,81 +107,79 @@ async def search_task(query_params: Dict[str, Any]) -> Dict[str, Any]: "success": False, "error": str(e), "results": [], - "count": 0 + "count": 0, } - + # Execute searches with controlled parallelism tasks = [lambda q=q: search_task(q) for q in queries] - + result = await self.handler.execute_batch( operation="batch_search", items=tasks, processor=lambda task: task(), - max_errors=len(queries) # Continue despite errors + max_errors=len(queries), # Continue despite errors ) - + return { "searches_completed": result.total_processed, "searches_failed": result.total_errors, "results": result.results, - "errors": result.errors + "errors": result.errors, } - - async def batch_add(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_add(self, params: dict[str, Any]) -> dict[str, Any]: """Add multiple documents efficiently. - + Supports atomic transactions and shared settings. """ validated = BatchAddParams(**params) documents = validated.documents shared = validated.shared_settings atomic = validated.atomic - + # Prepare records records = [] for doc_data in documents: # Merge with shared settings content = doc_data.content metadata = {**shared.get("metadata", {}), **doc_data.metadata} - + # Create record - record = FrameRecord( - text_content=content, - metadata=metadata - ) - + record = FrameRecord(text_content=content, metadata=metadata) + # Generate embeddings if requested if shared.get("generate_embeddings", True): try: from contextframe.embed.litellm_provider import LiteLLMProvider + provider = LiteLLMProvider() embedding = await provider.embed_async(content) record.vector = embedding except Exception as e: logger.warning(f"Failed to generate embedding: {e}") - + records.append(record) - + # Execute batch add if atomic: # Use transaction for atomic operation transaction = BatchTransaction(self.dataset) transaction.add_operation("add", records) - + try: await transaction.commit() return { "success": True, "documents_added": len(records), "atomic": True, - "document_ids": [str(r.id) for r in records] + "document_ids": [str(r.id) for r in records], } except Exception as e: return { "success": False, "documents_added": 0, "atomic": True, - "error": str(e) + "error": str(e), } else: # Non-atomic batch add @@ -186,24 +187,24 @@ async def batch_add(self, params: Dict[str, Any]) -> Dict[str, Any]: operation="batch_add", items=records, processor=lambda r: self.dataset.add(r), - max_errors=10 + max_errors=10, ) - + return { "success": result.total_errors == 0, "documents_added": result.total_processed, "documents_failed": result.total_errors, "atomic": False, - "errors": result.errors + "errors": result.errors, } - - async def batch_update(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_update(self, params: dict[str, Any]) -> dict[str, Any]: """Update multiple documents by filter or IDs. - + Supports metadata updates and content regeneration. """ validated = BatchUpdateParams(**params) - + # Get documents to update if validated.document_ids: # Update specific documents @@ -225,73 +226,65 @@ async def batch_update(self, params: Dict[str, Any]) -> Dict[str, Any]: else: return { "success": False, - "error": "Either document_ids or filter must be provided" + "error": "Either document_ids or filter must be provided", } - + # Prepare update function updates = validated.updates - - async def update_document(doc: FrameRecord) -> Dict[str, Any]: + + async def update_document(doc: FrameRecord) -> dict[str, Any]: try: # Apply metadata updates if updates.get("metadata_updates"): doc.metadata.update(updates["metadata_updates"]) - + # Apply content template if provided if updates.get("content_template"): # Simple template substitution doc.text_content = updates["content_template"].format( content=doc.text_content, title=doc.metadata.get("title", ""), - **doc.metadata + **doc.metadata, ) - + # Regenerate embeddings if requested if updates.get("regenerate_embeddings"): try: from contextframe.embed.litellm_provider import LiteLLMProvider + provider = LiteLLMProvider() doc.vector = await provider.embed_async(doc.text_content) except Exception as e: logger.warning(f"Failed to regenerate embedding: {e}") - + # Update in dataset (delete + add) self.dataset.delete(doc.id) self.dataset.add(doc) - - return { - "id": str(doc.id), - "success": True - } - + + return {"id": str(doc.id), "success": True} + except Exception as e: - return { - "id": str(doc.id), - "success": False, - "error": str(e) - } - + return {"id": str(doc.id), "success": False, "error": str(e)} + # Execute batch update result = await self.handler.execute_batch( - operation="batch_update", - items=docs, - processor=update_document + operation="batch_update", items=docs, processor=update_document ) - + return { "documents_updated": result.total_processed, "documents_failed": result.total_errors, "total_documents": len(docs), - "errors": result.errors + "errors": result.errors, } - - async def batch_delete(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_delete(self, params: dict[str, Any]) -> dict[str, Any]: """Delete multiple documents with safety checks. - + Supports dry run to preview deletions. """ validated = BatchDeleteParams(**params) - + # Get documents to delete if validated.document_ids: doc_ids = [UUID(doc_id) for doc_id in validated.document_ids] @@ -306,9 +299,9 @@ async def batch_delete(self, params: Dict[str, Any]) -> Dict[str, Any]: else: return { "success": False, - "error": "Either document_ids or filter must be provided" + "error": "Either document_ids or filter must be provided", } - + # Check confirm count if provided if validated.confirm_count is not None: if len(doc_ids) != validated.confirm_count: @@ -316,40 +309,42 @@ async def batch_delete(self, params: Dict[str, Any]) -> Dict[str, Any]: "success": False, "error": f"Expected {validated.confirm_count} documents, found {len(doc_ids)}", "dry_run": validated.dry_run, - "documents_found": len(doc_ids) + "documents_found": len(doc_ids), } - + # Dry run - just return what would be deleted if validated.dry_run: return { "success": True, "dry_run": True, "documents_to_delete": len(doc_ids), - "document_ids": [str(doc_id) for doc_id in doc_ids[:100]], # Limit preview - "message": f"Dry run - {len(doc_ids)} documents would be deleted" + "document_ids": [ + str(doc_id) for doc_id in doc_ids[:100] + ], # Limit preview + "message": f"Dry run - {len(doc_ids)} documents would be deleted", } - + # Execute deletion result = await self.handler.execute_batch( operation="batch_delete", items=doc_ids, - processor=lambda doc_id: self.dataset.delete(doc_id) + processor=lambda doc_id: self.dataset.delete(doc_id), ) - + return { "success": result.total_errors == 0, "documents_deleted": result.total_processed, "documents_failed": result.total_errors, - "errors": result.errors + "errors": result.errors, } - - async def batch_enhance(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_enhance(self, params: dict[str, Any]) -> dict[str, Any]: """Enhance multiple documents with LLM. - + Uses the enhance module to add context, tags, metadata etc. """ validated = BatchEnhanceParams(**params) - + # Get documents to enhance if validated.document_ids: doc_ids = [UUID(doc_id) for doc_id in validated.document_ids] @@ -365,23 +360,23 @@ async def batch_enhance(self, params: Dict[str, Any]) -> Dict[str, Any]: else: return { "success": False, - "error": "Either document_ids or filter must be provided" + "error": "Either document_ids or filter must be provided", } - + # Check if enhancement tools are available if not hasattr(self.tools, 'enhancement_tools'): # Try to initialize enhancement tools + import os from contextframe.enhance import ContextEnhancer from contextframe.mcp.enhancement_tools import EnhancementTools - import os - + api_key = os.environ.get("OPENAI_API_KEY") if not api_key: return { "success": False, - "error": "No OpenAI API key found. Set OPENAI_API_KEY environment variable." + "error": "No OpenAI API key found. Set OPENAI_API_KEY environment variable.", } - + try: model = os.environ.get("CONTEXTFRAME_ENHANCE_MODEL", "gpt-4") enhancer = ContextEnhancer(model=model, api_key=api_key) @@ -389,24 +384,20 @@ async def batch_enhance(self, params: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: return { "success": False, - "error": f"Failed to initialize enhancement tools: {str(e)}" + "error": f"Failed to initialize enhancement tools: {str(e)}", } - + # Prepare enhancement processor enhancement_tools = self.tools.enhancement_tools - - async def enhance_document(doc_id: UUID) -> Dict[str, Any]: + + async def enhance_document(doc_id: UUID) -> dict[str, Any]: # Get document record = self.dataset.get(doc_id) if not record: raise ValueError(f"Document {doc_id} not found") - - result = { - "document_id": str(doc_id), - "enhancements": {}, - "errors": [] - } - + + result = {"document_id": str(doc_id), "enhancements": {}, "errors": []} + # Apply each enhancement for enhancement in validated.enhancements: try: @@ -414,40 +405,40 @@ async def enhance_document(doc_id: UUID) -> Dict[str, Any]: new_context = enhancement_tools.enhance_context( content=record.text_content, purpose=validated.purpose or "general understanding", - current_context=record.context + current_context=record.context, ) result["enhancements"]["context"] = new_context - + elif enhancement == "tags": new_tags = enhancement_tools.generate_tags( content=record.text_content, tag_types="topics, technologies, concepts", - max_tags=10 + max_tags=10, ) result["enhancements"]["tags"] = new_tags - + elif enhancement == "title": new_title = enhancement_tools.improve_title( content=record.text_content, current_title=record.title, - style="descriptive" + style="descriptive", ) result["enhancements"]["title"] = new_title - + elif enhancement == "metadata": new_metadata = enhancement_tools.extract_metadata( content=record.text_content, - schema=validated.purpose or "Extract key facts and insights", - format="json" + schema=validated.purpose + or "Extract key facts and insights", + format="json", ) result["enhancements"]["custom_metadata"] = new_metadata - + except Exception as e: - result["errors"].append({ - "enhancement": enhancement, - "error": str(e) - }) - + result["errors"].append( + {"enhancement": enhancement, "error": str(e)} + ) + # Update document if we have enhancements if result["enhancements"] and not result["errors"]: updates = {} @@ -460,75 +451,77 @@ async def enhance_document(doc_id: UUID) -> Dict[str, Any]: if "custom_metadata" in result["enhancements"]: # Merge with existing metadata existing_metadata = record.custom_metadata or {} - updates["custom_metadata"] = {**existing_metadata, **result["enhancements"]["custom_metadata"]} - + updates["custom_metadata"] = { + **existing_metadata, + **result["enhancements"]["custom_metadata"], + } + # Update the record self.dataset.update(doc_id, **updates) - + return result - + # Process in batches if batch_size is specified batch_size = validated.batch_size if batch_size and batch_size > 1: # Process documents in groups for efficiency results = [] for i in range(0, len(doc_ids), batch_size): - batch_ids = doc_ids[i:i + batch_size] + batch_ids = doc_ids[i : i + batch_size] batch_result = await self.handler.execute_batch( - operation=f"batch_enhance_{i//batch_size + 1}", + operation=f"batch_enhance_{i // batch_size + 1}", items=batch_ids, - processor=enhance_document + processor=enhance_document, ) results.extend(batch_result.results) - + # Combine results total_processed = sum(1 for r in results if r.get("enhancements")) total_errors = sum(1 for r in results if r.get("errors")) - + return { "success": total_errors == 0, "documents_enhanced": total_processed, "documents_failed": total_errors, "total_documents": len(doc_ids), - "results": results + "results": results, } else: # Process all at once result = await self.handler.execute_batch( - operation="batch_enhance", - items=doc_ids, - processor=enhance_document + operation="batch_enhance", items=doc_ids, processor=enhance_document ) - + return { "success": result.total_errors == 0, "documents_enhanced": result.total_processed, "documents_failed": result.total_errors, "total_documents": len(doc_ids), - "results": result.results + "results": result.results, } - - async def batch_extract(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_extract(self, params: dict[str, Any]) -> dict[str, Any]: """Extract from multiple sources. - + Uses the extract module to process files and URLs. """ validated = BatchExtractParams(**params) - + # Import extractors - from contextframe.extract import registry as extractor_registry - from contextframe.extract import ExtractionResult - from pathlib import Path - + from contextframe.extract import ( + ExtractionResult, + registry as extractor_registry, + ) + # Prepare extraction processor - async def extract_source(source: Dict[str, Any]) -> Dict[str, Any]: + async def extract_source(source: dict[str, Any]) -> dict[str, Any]: result = { "source": source, "success": False, "document_id": None, - "error": None + "error": None, } - + try: # Determine source path if source.get("type") == "file" or source.get("path"): @@ -541,72 +534,75 @@ async def extract_source(source: Dict[str, Any]) -> Dict[str, Any]: raise NotImplementedError("URL extraction not yet implemented") else: raise ValueError("Source must have either 'path' or 'url'") - + # Find appropriate extractor extractor = extractor_registry.find_extractor(source_path) if not extractor: raise ValueError(f"No extractor found for: {source_path}") - + # Extract content extraction_result: ExtractionResult = extractor.extract(source_path) - + if extraction_result.error: raise ValueError(extraction_result.error) - + # Convert to FrameRecord if adding to dataset if validated.add_to_dataset: record_kwargs = extraction_result.to_frame_record_kwargs() - + # Add shared metadata if validated.shared_metadata: existing_metadata = record_kwargs.get("custom_metadata", {}) # Add x_ prefix to custom metadata fields prefixed_metadata = { - f"x_{k}" if not k.startswith("x_") else k: v + f"x_{k}" if not k.startswith("x_") else k: v for k, v in validated.shared_metadata.items() } - record_kwargs["custom_metadata"] = {**existing_metadata, **prefixed_metadata} - + record_kwargs["custom_metadata"] = { + **existing_metadata, + **prefixed_metadata, + } + # Set collection if specified if validated.collection: record_kwargs["metadata"] = record_kwargs.get("metadata", {}) record_kwargs["metadata"]["collection"] = validated.collection - + # Create record record = FrameRecord(**record_kwargs) self.dataset.add(record) - + result["document_id"] = str(record.id) - + result["success"] = True result["content_length"] = len(extraction_result.content) result["metadata"] = extraction_result.metadata result["format"] = extraction_result.format - + if extraction_result.warnings: result["warnings"] = extraction_result.warnings - + except Exception as e: result["error"] = str(e) - + # Check if we should continue on error if not validated.continue_on_error: raise - + return result - + # Execute batch extraction result = await self.handler.execute_batch( operation="batch_extract", items=validated.sources, processor=extract_source, - max_errors=None if validated.continue_on_error else 1 + max_errors=None if validated.continue_on_error else 1, ) - + # Count successes successful_extractions = sum(1 for r in result.results if r.get("success")) documents_added = sum(1 for r in result.results if r.get("document_id")) - + return { "success": result.total_errors == 0, "sources_processed": len(validated.sources), @@ -614,22 +610,20 @@ async def extract_source(source: Dict[str, Any]) -> Dict[str, Any]: "sources_failed": result.total_errors, "documents_added": documents_added if validated.add_to_dataset else 0, "results": result.results, - "errors": result.errors + "errors": result.errors, } - - async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: + + async def batch_export(self, params: dict[str, Any]) -> dict[str, Any]: """Export documents in bulk. - + Uses the io.exporter module to export documents in various formats. """ validated = BatchExportParams(**params) - + # Import export utilities - from contextframe.io.formats import ExportFormat - from pathlib import Path - import json import csv - + from contextframe.io.formats import ExportFormat + # Get documents to export if validated.document_ids: doc_ids = [UUID(doc_id) for doc_id in validated.document_ids] @@ -641,34 +635,30 @@ async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: scanner = self.dataset.scanner(filter=validated.filter) tbl = scanner.to_table() docs = [ - FrameRecord.from_arrow(tbl.slice(i, 1)) - for i in range(tbl.num_rows) + FrameRecord.from_arrow(tbl.slice(i, 1)) for i in range(tbl.num_rows) ] else: return { "success": False, - "error": "Either document_ids or filter must be provided" + "error": "Either document_ids or filter must be provided", } - + if not docs: - return { - "success": False, - "error": "No documents found to export" - } - + return {"success": False, "error": "No documents found to export"} + # Prepare output path output_path = Path(validated.output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - + # Determine format try: format_enum = ExportFormat(validated.format.lower()) except ValueError: return { "success": False, - "error": f"Unsupported format: {validated.format}" + "error": f"Unsupported format: {validated.format}", } - + # Process documents based on format try: if format_enum == ExportFormat.JSON: @@ -683,40 +673,47 @@ async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: "context": doc.context, "tags": doc.tags, "custom_metadata": doc.custom_metadata, - "created_at": doc.created_at.isoformat() if doc.created_at else None, - "updated_at": doc.updated_at.isoformat() if doc.updated_at else None, + "created_at": doc.created_at.isoformat() + if doc.created_at + else None, + "updated_at": doc.updated_at.isoformat() + if doc.updated_at + else None, } - + if validated.include_embeddings and doc.vector is not None: doc_dict["embeddings"] = doc.vector.tolist() - + export_data.append(doc_dict) - + # Handle chunking for large exports if validated.chunk_size and len(export_data) > validated.chunk_size: # Export in chunks exported_files = [] for i in range(0, len(export_data), validated.chunk_size): - chunk = export_data[i:i + validated.chunk_size] - chunk_path = output_path.parent / f"{output_path.stem}_chunk_{i//validated.chunk_size}{output_path.suffix}" - + chunk = export_data[i : i + validated.chunk_size] + chunk_path = ( + output_path.parent + / f"{output_path.stem}_chunk_{i // validated.chunk_size}{output_path.suffix}" + ) + with open(chunk_path, "w") as f: json.dump(chunk, f, indent=2) - + exported_files.append(str(chunk_path)) - + return { "success": True, "format": validated.format, "documents_exported": len(docs), "files_created": len(exported_files), - "output_files": exported_files + "output_files": exported_files, } else: # Export as single file with open(output_path, "w") as f: json.dump(export_data, f, indent=2) - + elif format_enum == ExportFormat.JSONL: # Export as JSONL (newline-delimited JSON) with open(output_path, "w") as f: @@ -730,28 +727,36 @@ async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: "tags": doc.tags, "custom_metadata": doc.custom_metadata, } - + if validated.include_embeddings and doc.vector is not None: doc_dict["embeddings"] = doc.vector.tolist() - + f.write(json.dumps(doc_dict) + "\n") - + elif format_enum == ExportFormat.CSV: # Export as CSV - fieldnames = ["id", "title", "content", "context", "tags", "created_at", "updated_at"] - + fieldnames = [ + "id", + "title", + "content", + "context", + "tags", + "created_at", + "updated_at", + ] + # Add custom metadata fields all_custom_fields = set() for doc in docs: if doc.custom_metadata: all_custom_fields.update(doc.custom_metadata.keys()) - + fieldnames.extend(sorted(all_custom_fields)) - + with open(output_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() - + for doc in docs: row = { "id": str(doc.id), @@ -759,23 +764,27 @@ async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: "content": doc.text_content, "context": doc.context or "", "tags": ", ".join(doc.tags) if doc.tags else "", - "created_at": doc.created_at.isoformat() if doc.created_at else "", - "updated_at": doc.updated_at.isoformat() if doc.updated_at else "", + "created_at": doc.created_at.isoformat() + if doc.created_at + else "", + "updated_at": doc.updated_at.isoformat() + if doc.updated_at + else "", } - + # Add custom metadata if doc.custom_metadata: for key, value in doc.custom_metadata.items(): row[key] = str(value) - + writer.writerow(row) - + elif format_enum == ExportFormat.PARQUET: # Export as Parquet (requires pyarrow) try: import pyarrow as pa import pyarrow.parquet as pq - + # Convert documents to arrow table table_data = { "id": [str(doc.id) for doc in docs], @@ -786,83 +795,81 @@ async def batch_export(self, params: Dict[str, Any]) -> Dict[str, Any]: "created_at": [doc.created_at for doc in docs], "updated_at": [doc.updated_at for doc in docs], } - + if validated.include_embeddings: table_data["embeddings"] = [doc.vector for doc in docs] - + table = pa.table(table_data) pq.write_table(table, output_path) - + except ImportError: return { "success": False, - "error": "Parquet export requires pyarrow. Install with: pip install pyarrow" + "error": "Parquet export requires pyarrow. Install with: pip install pyarrow", } else: return { "success": False, - "error": f"Format {format_enum} not yet implemented for batch export" + "error": f"Format {format_enum} not yet implemented for batch export", } - + return { "success": True, "format": validated.format, "documents_exported": len(docs), "output_path": str(output_path), - "file_size_bytes": output_path.stat().st_size + "file_size_bytes": output_path.stat().st_size, } - + except Exception as e: - return { - "success": False, - "error": f"Export failed: {str(e)}" - } - - async def batch_import(self, params: Dict[str, Any]) -> Dict[str, Any]: + return {"success": False, "error": f"Export failed: {str(e)}"} + + async def batch_import(self, params: dict[str, Any]) -> dict[str, Any]: """Import documents from files. - + Uses the io module to import documents from various formats. """ validated = BatchImportParams(**params) - + # Import utilities - from contextframe.io.formats import ExportFormat - from pathlib import Path - import json import csv - + from contextframe.io.formats import ExportFormat + source_path = Path(validated.source_path) if not source_path.exists(): - return { - "success": False, - "error": f"Source path not found: {source_path}" - } - + return {"success": False, "error": f"Source path not found: {source_path}"} + # Determine format try: format_enum = ExportFormat(validated.format.lower()) except ValueError: return { "success": False, - "error": f"Unsupported format: {validated.format}" + "error": f"Unsupported format: {validated.format}", } - + # Prepare validation settings - max_errors = validated.validation.get("max_errors", 10) if validated.validation else 10 - require_schema_match = validated.validation.get("require_schema_match", False) if validated.validation else False - + max_errors = ( + validated.validation.get("max_errors", 10) if validated.validation else 10 + ) + require_schema_match = ( + validated.validation.get("require_schema_match", False) + if validated.validation + else False + ) + # Track import progress import_results = [] error_count = 0 - - async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: + + async def import_document(doc_data: dict[str, Any]) -> dict[str, Any]: result = { "success": False, "document_id": None, "error": None, - "source_id": doc_data.get("id", "unknown") + "source_id": doc_data.get("id", "unknown"), } - + try: # Apply field mapping if provided if validated.mapping: @@ -871,13 +878,15 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: if source_field in doc_data: mapped_data[target_field] = doc_data[source_field] doc_data.update(mapped_data) - + # Extract fields according to schema record_kwargs = { - "text_content": doc_data.get("content", doc_data.get("text_content", "")), - "metadata": doc_data.get("metadata", {}) + "text_content": doc_data.get( + "content", doc_data.get("text_content", "") + ), + "metadata": doc_data.get("metadata", {}), } - + # Optional fields if "title" in doc_data: record_kwargs["title"] = doc_data["title"] @@ -887,7 +896,9 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: tags = doc_data["tags"] if isinstance(tags, str): # Handle comma-separated tags - record_kwargs["tags"] = [t.strip() for t in tags.split(",") if t.strip()] + record_kwargs["tags"] = [ + t.strip() for t in tags.split(",") if t.strip() + ] else: record_kwargs["tags"] = tags if "custom_metadata" in doc_data: @@ -897,52 +908,52 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: key = f"x_{k}" if not k.startswith("x_") else k custom_metadata[key] = v record_kwargs["custom_metadata"] = custom_metadata - + # Handle embeddings if present if "embeddings" in doc_data and not validated.generate_embeddings: record_kwargs["vector"] = doc_data["embeddings"] - + # Create and add record record = FrameRecord(**record_kwargs) self.dataset.add(record) - + # Generate embeddings if requested if validated.generate_embeddings and not record.vector: # Would need to integrate with embed module here pass - + result["success"] = True result["document_id"] = str(record.id) - + except Exception as e: result["error"] = str(e) if require_schema_match: raise - + return result - + try: documents_to_import = [] - + if format_enum == ExportFormat.JSON: # Import from JSON - with open(source_path, "r") as f: + with open(source_path) as f: data = json.load(f) if isinstance(data, list): documents_to_import = data else: documents_to_import = [data] - + elif format_enum == ExportFormat.JSONL: # Import from JSONL - with open(source_path, "r") as f: + with open(source_path) as f: for line in f: if line.strip(): documents_to_import.append(json.loads(line)) - + elif format_enum == ExportFormat.CSV: # Import from CSV - with open(source_path, "r", newline="") as f: + with open(source_path, newline="") as f: reader = csv.DictReader(f) for row in reader: # Convert CSV row to document format @@ -952,26 +963,34 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: "context": row.get("context", ""), "tags": row.get("tags", ""), } - + # Extract custom metadata from remaining fields - standard_fields = {"id", "content", "title", "context", "tags", "created_at", "updated_at"} + standard_fields = { + "id", + "content", + "title", + "context", + "tags", + "created_at", + "updated_at", + } custom_metadata = {} for k, v in row.items(): if k not in standard_fields and v: custom_metadata[k] = v - + if custom_metadata: doc["custom_metadata"] = custom_metadata - + documents_to_import.append(doc) - + elif format_enum == ExportFormat.PARQUET: # Import from Parquet try: import pyarrow.parquet as pq - + table = pq.read_table(source_path) - + # Convert to list of dicts for i in range(table.num_rows): doc = {} @@ -980,26 +999,26 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: if value is not None: doc[field.name] = value documents_to_import.append(doc) - + except ImportError: return { "success": False, - "error": "Parquet import requires pyarrow. Install with: pip install pyarrow" + "error": "Parquet import requires pyarrow. Install with: pip install pyarrow", } else: return { "success": False, - "error": f"Format {format_enum} not yet implemented for batch import" + "error": f"Format {format_enum} not yet implemented for batch import", } - + # Execute batch import result = await self.handler.execute_batch( operation="batch_import", items=documents_to_import, processor=import_document, - max_errors=max_errors + max_errors=max_errors, ) - + return { "success": result.total_errors == 0, "source_path": str(source_path), @@ -1007,11 +1026,10 @@ async def import_document(doc_data: Dict[str, Any]) -> Dict[str, Any]: "documents_found": len(documents_to_import), "documents_imported": result.total_processed, "documents_failed": result.total_errors, - "errors": result.errors[:10] if result.errors else [] # Limit error details + "errors": result.errors[:10] + if result.errors + else [], # Limit error details } - + except Exception as e: - return { - "success": False, - "error": f"Import failed: {str(e)}" - } \ No newline at end of file + return {"success": False, "error": f"Import failed: {str(e)}"} diff --git a/contextframe/mcp/batch/transaction.py b/contextframe/mcp/batch/transaction.py index 6b9a41c..d14f09f 100644 --- a/contextframe/mcp/batch/transaction.py +++ b/contextframe/mcp/batch/transaction.py @@ -1,53 +1,52 @@ """Transaction support for atomic batch operations.""" import logging -from typing import Any, Callable, Dict, List, Tuple +from collections.abc import Callable +from contextframe.frame import FrameDataset, FrameRecord from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple from uuid import UUID -from contextframe.frame import FrameDataset, FrameRecord - - logger = logging.getLogger(__name__) @dataclass class Operation: """Represents a single operation in a transaction.""" - + op_type: str # 'add', 'update', 'delete' data: Any rollback_data: Any = None -@dataclass +@dataclass class BatchTransaction: """Manages atomic batch operations with rollback support. - + Provides transaction semantics for batch operations on FrameDataset. If any operation fails, all completed operations are rolled back. """ - + dataset: FrameDataset - operations: List[Operation] = field(default_factory=list) - completed_ops: List[Tuple[int, Operation]] = field(default_factory=list) - + operations: list[Operation] = field(default_factory=list) + completed_ops: list[tuple[int, Operation]] = field(default_factory=list) + def add_operation(self, op_type: str, data: Any, rollback_data: Any = None): """Add an operation to the transaction. - + Args: op_type: Type of operation ('add', 'update', 'delete') data: Data for the operation rollback_data: Data needed to undo the operation """ self.operations.append(Operation(op_type, data, rollback_data)) - - async def commit(self) -> Dict[str, Any]: + + async def commit(self) -> dict[str, Any]: """Execute all operations atomically. - + Returns: Summary of transaction results - + Raises: Exception: If any operation fails (after rollback) """ @@ -55,36 +54,36 @@ async def commit(self) -> Dict[str, Any]: for i, op in enumerate(self.operations): await self._execute_operation(op) self.completed_ops.append((i, op)) - + return { "success": True, "operations_completed": len(self.completed_ops), - "total_operations": len(self.operations) + "total_operations": len(self.operations), } - + except Exception as e: - logger.error(f"Transaction failed at operation {len(self.completed_ops)}: {e}") + logger.error( + f"Transaction failed at operation {len(self.completed_ops)}: {e}" + ) await self.rollback() raise TransactionError( f"Transaction rolled back due to: {e}", completed=len(self.completed_ops), - total=len(self.operations) + total=len(self.operations), ) - + async def rollback(self): """Undo all completed operations.""" logger.info(f"Rolling back {len(self.completed_ops)} operations") - + # Rollback in reverse order for i, op in reversed(self.completed_ops): try: await self._rollback_operation(op) except Exception as e: - logger.error( - f"Failed to rollback operation {i} ({op.op_type}): {e}" - ) + logger.error(f"Failed to rollback operation {i} ({op.op_type}): {e}") # Continue rollback despite errors - + async def _execute_operation(self, op: Operation): """Execute a single operation.""" if op.op_type == "add": @@ -92,34 +91,34 @@ async def _execute_operation(self, op: Operation): self.dataset.add_many(op.data) else: self.dataset.add(op.data) - + elif op.op_type == "update": # For update, data should be (record_id, updated_record) record_id, updated_record = op.data - + # Store original for rollback if op.rollback_data is None: original = self.dataset.get(record_id) op.rollback_data = original - + # Delete and re-add (Lance pattern) self.dataset.delete(record_id) self.dataset.add(updated_record) - + elif op.op_type == "delete": # For delete, data is the record ID record_id = op.data - + # Store record for rollback if op.rollback_data is None: original = self.dataset.get(record_id) op.rollback_data = original - + self.dataset.delete(record_id) - + else: raise ValueError(f"Unknown operation type: {op.op_type}") - + async def _rollback_operation(self, op: Operation): """Rollback a single operation.""" if op.op_type == "add": @@ -135,7 +134,7 @@ async def _rollback_operation(self, op: Operation): self.dataset.delete(op.data.id) except: pass - + elif op.op_type == "update": # Restore original record if op.rollback_data: @@ -145,7 +144,7 @@ async def _rollback_operation(self, op: Operation): except: pass self.dataset.add(op.rollback_data) - + elif op.op_type == "delete": # Restore deleted record if op.rollback_data: @@ -154,8 +153,8 @@ async def _rollback_operation(self, op: Operation): class TransactionError(Exception): """Error during transaction execution.""" - + def __init__(self, message: str, completed: int, total: int): super().__init__(message) self.completed = completed - self.total = total \ No newline at end of file + self.total = total diff --git a/contextframe/mcp/collections/__init__.py b/contextframe/mcp/collections/__init__.py index 5f84b74..ce1bdb2 100644 --- a/contextframe/mcp/collections/__init__.py +++ b/contextframe/mcp/collections/__init__.py @@ -2,4 +2,4 @@ from .tools import CollectionTools -__all__ = ["CollectionTools"] \ No newline at end of file +__all__ = ["CollectionTools"] diff --git a/contextframe/mcp/collections/templates.py b/contextframe/mcp/collections/templates.py index e225de6..32bb5ae 100644 --- a/contextframe/mcp/collections/templates.py +++ b/contextframe/mcp/collections/templates.py @@ -1,396 +1,393 @@ """Collection template system for pre-configured collection structures.""" -from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field +from typing import Any, Dict, List, Optional class CollectionTemplate(BaseModel): """Defines a collection template structure.""" - + name: str = Field(..., description="Template identifier") display_name: str = Field(..., description="Human-readable template name") description: str = Field(..., description="Template description") - structure: Dict[str, Any] = Field(..., description="Hierarchical structure definition") - default_metadata: Dict[str, Any] = Field(default_factory=dict, description="Default metadata for collections") - naming_pattern: Optional[str] = Field(None, description="Naming pattern for collections") - auto_organize_rules: List[Dict[str, Any]] = Field(default_factory=list, description="Auto-organization rules") - icon: Optional[str] = Field(None, description="Icon identifier for UI") + structure: dict[str, Any] = Field( + ..., description="Hierarchical structure definition" + ) + default_metadata: dict[str, Any] = Field( + default_factory=dict, description="Default metadata for collections" + ) + naming_pattern: str | None = Field( + None, description="Naming pattern for collections" + ) + auto_organize_rules: list[dict[str, Any]] = Field( + default_factory=list, description="Auto-organization rules" + ) + icon: str | None = Field(None, description="Icon identifier for UI") class TemplateRegistry: """Registry for collection templates.""" - + def __init__(self): """Initialize with built-in templates.""" - self.templates: Dict[str, CollectionTemplate] = {} + self.templates: dict[str, CollectionTemplate] = {} self._register_builtin_templates() - + def register_template(self, template: CollectionTemplate) -> None: """Register a new template.""" self.templates[template.name] = template - - def get_template(self, name: str) -> Optional[CollectionTemplate]: + + def get_template(self, name: str) -> CollectionTemplate | None: """Get template by name.""" return self.templates.get(name) - - def list_templates(self) -> List[Dict[str, str]]: + + def list_templates(self) -> list[dict[str, str]]: """List all available templates.""" return [ { "name": template.name, "display_name": template.display_name, "description": template.description, - "icon": template.icon + "icon": template.icon, } for template in self.templates.values() ] - + def _register_builtin_templates(self) -> None: """Register built-in templates.""" - + # Project template - self.register_template(CollectionTemplate( - name="project", - display_name="Software Project", - description="Organize software project documentation, code, and resources", - structure={ - "root": { - "name": "{project_name}", - "description": "Project root collection", - "subcollections": { - "docs": { - "name": "Documentation", - "description": "Project documentation", - "metadata": {"x_category": "documentation"} - }, - "src": { - "name": "Source Code", - "description": "Implementation files", - "metadata": {"x_category": "implementation"} - }, - "tests": { - "name": "Tests", - "description": "Test files and fixtures", - "metadata": {"x_category": "testing"} + self.register_template( + CollectionTemplate( + name="project", + display_name="Software Project", + description="Organize software project documentation, code, and resources", + structure={ + "root": { + "name": "{project_name}", + "description": "Project root collection", + "subcollections": { + "docs": { + "name": "Documentation", + "description": "Project documentation", + "metadata": {"x_category": "documentation"}, + }, + "src": { + "name": "Source Code", + "description": "Implementation files", + "metadata": {"x_category": "implementation"}, + }, + "tests": { + "name": "Tests", + "description": "Test files and fixtures", + "metadata": {"x_category": "testing"}, + }, + "examples": { + "name": "Examples", + "description": "Usage examples and tutorials", + "metadata": {"x_category": "examples"}, + }, }, - "examples": { - "name": "Examples", - "description": "Usage examples and tutorials", - "metadata": {"x_category": "examples"} - } } - } - }, - default_metadata={ - "x_template": "project", - "x_domain": "software" - }, - naming_pattern="{name}-{category}", - auto_organize_rules=[ - { - "pattern": "*.md", - "target": "docs", - "exclude": ["README.md", "CHANGELOG.md"] }, - { - "pattern": "src/**/*", - "target": "src" - }, - { - "pattern": "test*/**/*", - "target": "tests" - }, - { - "pattern": "example*/**/*", - "target": "examples" - } - ], - icon="folder-code" - )) - + default_metadata={"x_template": "project", "x_domain": "software"}, + naming_pattern="{name}-{category}", + auto_organize_rules=[ + { + "pattern": "*.md", + "target": "docs", + "exclude": ["README.md", "CHANGELOG.md"], + }, + {"pattern": "src/**/*", "target": "src"}, + {"pattern": "test*/**/*", "target": "tests"}, + {"pattern": "example*/**/*", "target": "examples"}, + ], + icon="folder-code", + ) + ) + # Research template - self.register_template(CollectionTemplate( - name="research", - display_name="Research Papers", - description="Organize academic papers, citations, and research materials", - structure={ - "root": { - "name": "{research_topic}", - "description": "Research collection", - "subcollections": { - "by_year": { - "name": "Papers by Year", - "description": "Organized by publication year", - "dynamic": True, - "pattern": "{year}" - }, - "by_topic": { - "name": "Papers by Topic", - "description": "Organized by research topic", - "dynamic": True, - "pattern": "{topic}" - }, - "by_author": { - "name": "Papers by Author", - "description": "Organized by primary author", - "dynamic": True, - "pattern": "{author_lastname}" - }, - "citations": { - "name": "Citation Network", - "description": "Citation relationships", - "metadata": {"x_type": "citations"} + self.register_template( + CollectionTemplate( + name="research", + display_name="Research Papers", + description="Organize academic papers, citations, and research materials", + structure={ + "root": { + "name": "{research_topic}", + "description": "Research collection", + "subcollections": { + "by_year": { + "name": "Papers by Year", + "description": "Organized by publication year", + "dynamic": True, + "pattern": "{year}", + }, + "by_topic": { + "name": "Papers by Topic", + "description": "Organized by research topic", + "dynamic": True, + "pattern": "{topic}", + }, + "by_author": { + "name": "Papers by Author", + "description": "Organized by primary author", + "dynamic": True, + "pattern": "{author_lastname}", + }, + "citations": { + "name": "Citation Network", + "description": "Citation relationships", + "metadata": {"x_type": "citations"}, + }, + "notes": { + "name": "Research Notes", + "description": "Personal notes and summaries", + "metadata": {"x_type": "notes"}, + }, }, - "notes": { - "name": "Research Notes", - "description": "Personal notes and summaries", - "metadata": {"x_type": "notes"} - } } - } - }, - default_metadata={ - "x_template": "research", - "x_domain": "academic" - }, - naming_pattern="{year}-{authors}-{title}", - auto_organize_rules=[ - { - "metadata_field": "year", - "target": "by_year/{value}" }, - { - "metadata_field": "primary_topic", - "target": "by_topic/{value}" - }, - { - "metadata_field": "first_author_lastname", - "target": "by_author/{value}" - } - ], - icon="academic-cap" - )) - + default_metadata={"x_template": "research", "x_domain": "academic"}, + naming_pattern="{year}-{authors}-{title}", + auto_organize_rules=[ + {"metadata_field": "year", "target": "by_year/{value}"}, + {"metadata_field": "primary_topic", "target": "by_topic/{value}"}, + { + "metadata_field": "first_author_lastname", + "target": "by_author/{value}", + }, + ], + icon="academic-cap", + ) + ) + # Knowledge base template - self.register_template(CollectionTemplate( - name="knowledge_base", - display_name="Knowledge Base", - description="Hierarchical organization for documentation and guides", - structure={ - "root": { - "name": "{kb_name} Knowledge Base", - "description": "Knowledge base root", - "subcollections": { - "getting_started": { - "name": "Getting Started", - "description": "Introduction and quick start guides", - "metadata": {"x_priority": "high"} - }, - "tutorials": { - "name": "Tutorials", - "description": "Step-by-step tutorials", - "metadata": {"x_difficulty": "intermediate"} - }, - "reference": { - "name": "Reference", - "description": "API and reference documentation", - "metadata": {"x_type": "reference"} - }, - "troubleshooting": { - "name": "Troubleshooting", - "description": "Common issues and solutions", - "metadata": {"x_type": "troubleshooting"} + self.register_template( + CollectionTemplate( + name="knowledge_base", + display_name="Knowledge Base", + description="Hierarchical organization for documentation and guides", + structure={ + "root": { + "name": "{kb_name} Knowledge Base", + "description": "Knowledge base root", + "subcollections": { + "getting_started": { + "name": "Getting Started", + "description": "Introduction and quick start guides", + "metadata": {"x_priority": "high"}, + }, + "tutorials": { + "name": "Tutorials", + "description": "Step-by-step tutorials", + "metadata": {"x_difficulty": "intermediate"}, + }, + "reference": { + "name": "Reference", + "description": "API and reference documentation", + "metadata": {"x_type": "reference"}, + }, + "troubleshooting": { + "name": "Troubleshooting", + "description": "Common issues and solutions", + "metadata": {"x_type": "troubleshooting"}, + }, + "faq": { + "name": "FAQ", + "description": "Frequently asked questions", + "metadata": {"x_type": "faq"}, + }, }, - "faq": { - "name": "FAQ", - "description": "Frequently asked questions", - "metadata": {"x_type": "faq"} - } } - } - }, - default_metadata={ - "x_template": "knowledge_base", - "x_domain": "documentation", - "x_searchable": True - }, - naming_pattern="{category}-{title}", - auto_organize_rules=[ - { - "content_pattern": "getting started|quick start|introduction", - "target": "getting_started" }, - { - "content_pattern": "tutorial|how to|step by step", - "target": "tutorials" + default_metadata={ + "x_template": "knowledge_base", + "x_domain": "documentation", + "x_searchable": True, }, - { - "content_pattern": "api|reference|specification", - "target": "reference" - }, - { - "content_pattern": "error|issue|problem|fix", - "target": "troubleshooting" - }, - { - "content_pattern": "frequently asked|faq|common question", - "target": "faq" - } - ], - icon="book-open" - )) - + naming_pattern="{category}-{title}", + auto_organize_rules=[ + { + "content_pattern": "getting started|quick start|introduction", + "target": "getting_started", + }, + { + "content_pattern": "tutorial|how to|step by step", + "target": "tutorials", + }, + { + "content_pattern": "api|reference|specification", + "target": "reference", + }, + { + "content_pattern": "error|issue|problem|fix", + "target": "troubleshooting", + }, + { + "content_pattern": "frequently asked|faq|common question", + "target": "faq", + }, + ], + icon="book-open", + ) + ) + # Dataset template - self.register_template(CollectionTemplate( - name="dataset", - display_name="Training Dataset", - description="Organize datasets for machine learning and AI training", - structure={ - "root": { - "name": "{dataset_name}", - "description": "Dataset collection", - "subcollections": { - "train": { - "name": "Training Set", - "description": "Training data", - "metadata": {"x_split": "train", "x_ratio": 0.8} - }, - "validation": { - "name": "Validation Set", - "description": "Validation data", - "metadata": {"x_split": "validation", "x_ratio": 0.1} - }, - "test": { - "name": "Test Set", - "description": "Test data", - "metadata": {"x_split": "test", "x_ratio": 0.1} - }, - "raw": { - "name": "Raw Data", - "description": "Unprocessed source data", - "metadata": {"x_processed": False} + self.register_template( + CollectionTemplate( + name="dataset", + display_name="Training Dataset", + description="Organize datasets for machine learning and AI training", + structure={ + "root": { + "name": "{dataset_name}", + "description": "Dataset collection", + "subcollections": { + "train": { + "name": "Training Set", + "description": "Training data", + "metadata": {"x_split": "train", "x_ratio": 0.8}, + }, + "validation": { + "name": "Validation Set", + "description": "Validation data", + "metadata": {"x_split": "validation", "x_ratio": 0.1}, + }, + "test": { + "name": "Test Set", + "description": "Test data", + "metadata": {"x_split": "test", "x_ratio": 0.1}, + }, + "raw": { + "name": "Raw Data", + "description": "Unprocessed source data", + "metadata": {"x_processed": False}, + }, + "metadata": { + "name": "Dataset Metadata", + "description": "Labels, annotations, and dataset info", + "metadata": {"x_type": "metadata"}, + }, }, - "metadata": { - "name": "Dataset Metadata", - "description": "Labels, annotations, and dataset info", - "metadata": {"x_type": "metadata"} - } } - } - }, - default_metadata={ - "x_template": "dataset", - "x_domain": "ml", - "x_version": "1.0" - }, - naming_pattern="{split}-{index:06d}", - auto_organize_rules=[ - { - "random_split": True, - "ratios": { - "train": 0.8, - "validation": 0.1, - "test": 0.1 + }, + default_metadata={ + "x_template": "dataset", + "x_domain": "ml", + "x_version": "1.0", + }, + naming_pattern="{split}-{index:06d}", + auto_organize_rules=[ + { + "random_split": True, + "ratios": {"train": 0.8, "validation": 0.1, "test": 0.1}, } - } - ], - icon="database" - )) - + ], + icon="database", + ) + ) + # Legal template - self.register_template(CollectionTemplate( - name="legal", - display_name="Legal Documents", - description="Organize contracts, agreements, and legal documents", - structure={ - "root": { - "name": "{case_or_matter_name}", - "description": "Legal matter collection", - "subcollections": { - "contracts": { - "name": "Contracts & Agreements", - "description": "Executed contracts and agreements", - "metadata": {"x_type": "contract", "x_confidential": True} - }, - "correspondence": { - "name": "Correspondence", - "description": "Letters, emails, and communications", - "metadata": {"x_type": "correspondence"} - }, - "filings": { - "name": "Court Filings", - "description": "Court documents and filings", - "metadata": {"x_type": "filing"} + self.register_template( + CollectionTemplate( + name="legal", + display_name="Legal Documents", + description="Organize contracts, agreements, and legal documents", + structure={ + "root": { + "name": "{case_or_matter_name}", + "description": "Legal matter collection", + "subcollections": { + "contracts": { + "name": "Contracts & Agreements", + "description": "Executed contracts and agreements", + "metadata": { + "x_type": "contract", + "x_confidential": True, + }, + }, + "correspondence": { + "name": "Correspondence", + "description": "Letters, emails, and communications", + "metadata": {"x_type": "correspondence"}, + }, + "filings": { + "name": "Court Filings", + "description": "Court documents and filings", + "metadata": {"x_type": "filing"}, + }, + "research": { + "name": "Legal Research", + "description": "Case law, statutes, and research", + "metadata": {"x_type": "research"}, + }, + "internal": { + "name": "Internal Documents", + "description": "Internal memos and work product", + "metadata": { + "x_type": "internal", + "x_privileged": True, + }, + }, }, - "research": { - "name": "Legal Research", - "description": "Case law, statutes, and research", - "metadata": {"x_type": "research"} - }, - "internal": { - "name": "Internal Documents", - "description": "Internal memos and work product", - "metadata": {"x_type": "internal", "x_privileged": True} - } } - } - }, - default_metadata={ - "x_template": "legal", - "x_domain": "legal", - "x_access_control": "restricted" - }, - naming_pattern="{date}-{type}-{party}", - auto_organize_rules=[ - { - "metadata_field": "document_type", - "mapping": { - "contract": "contracts", - "agreement": "contracts", - "letter": "correspondence", - "email": "correspondence", - "motion": "filings", - "brief": "filings", - "memo": "internal" + }, + default_metadata={ + "x_template": "legal", + "x_domain": "legal", + "x_access_control": "restricted", + }, + naming_pattern="{date}-{type}-{party}", + auto_organize_rules=[ + { + "metadata_field": "document_type", + "mapping": { + "contract": "contracts", + "agreement": "contracts", + "letter": "correspondence", + "email": "correspondence", + "motion": "filings", + "brief": "filings", + "memo": "internal", + }, } - } - ], - icon="scale" - )) + ], + icon="scale", + ) + ) def apply_template( - template: CollectionTemplate, - params: Dict[str, Any], - parent_id: Optional[str] = None -) -> Dict[str, Any]: + template: CollectionTemplate, params: dict[str, Any], parent_id: str | None = None +) -> dict[str, Any]: """Apply a template to create collection structure. - + Args: template: The template to apply params: Parameters for template variables parent_id: Parent collection ID if creating under existing collection - + Returns: Dictionary describing the collection structure to create """ # Replace template variables in structure structure = _replace_template_vars(template.structure, params) - + # Add template metadata if parent_id: structure["root"]["parent_id"] = parent_id - + structure["root"]["metadata"] = { **template.default_metadata, **structure["root"].get("metadata", {}), - "x_created_from_template": template.name + "x_created_from_template": template.name, } - + return structure -def _replace_template_vars(obj: Any, params: Dict[str, str]) -> Any: +def _replace_template_vars(obj: Any, params: dict[str, str]) -> Any: """Recursively replace template variables in structure.""" if isinstance(obj, str): for key, value in params.items(): @@ -401,4 +398,4 @@ def _replace_template_vars(obj: Any, params: Dict[str, str]) -> Any: elif isinstance(obj, list): return [_replace_template_vars(item, params) for item in obj] else: - return obj \ No newline at end of file + return obj diff --git a/contextframe/mcp/collections/tools.py b/contextframe/mcp/collections/tools.py index 912f6db..bf97023 100644 --- a/contextframe/mcp/collections/tools.py +++ b/contextframe/mcp/collections/tools.py @@ -28,7 +28,7 @@ class CollectionTools: """Collection management tools for MCP server. - + Provides comprehensive collection management including: - Collection CRUD operations - Document membership management @@ -36,15 +36,15 @@ class CollectionTools: - Collection templates - Statistics and analytics """ - + def __init__( self, dataset: FrameDataset, transport: TransportAdapter, - template_registry: Any | None = None + template_registry: Any | None = None, ): """Initialize collection tools. - + Args: dataset: The dataset to operate on transport: Transport adapter for progress @@ -53,7 +53,7 @@ def __init__( self.dataset = dataset self.transport = transport self.template_registry = template_registry - + def register_tools(self, tool_registry): """Register collection tools with the tool registry.""" tools = [ @@ -62,51 +62,53 @@ def register_tools(self, tool_registry): ("delete_collection", self.delete_collection, DeleteCollectionParams), ("list_collections", self.list_collections, ListCollectionsParams), ("move_documents", self.move_documents, MoveDocumentsParams), - ("get_collection_stats", self.get_collection_stats, GetCollectionStatsParams), + ( + "get_collection_stats", + self.get_collection_stats, + GetCollectionStatsParams, + ), ] - + for name, handler, schema in tools: tool_registry.register_tool( name=name, handler=handler, schema=schema, - description=schema.__doc__ or f"Collection {name.split('_')[1]} operation" + description=schema.__doc__ + or f"Collection {name.split('_')[1]} operation", ) - + async def create_collection(self, params: dict[str, Any]) -> dict[str, Any]: """Create a new collection with header and initial configuration.""" validated = CreateCollectionParams(**params) - + # Create collection header document - header_metadata = { - "record_type": "collection_header", - "title": validated.name - } - + header_metadata = {"record_type": "collection_header", "title": validated.name} + # Store parent in collection_id for Lance-native filtering if validated.parent_collection: header_metadata["collection_id"] = validated.parent_collection header_metadata["collection_id_type"] = "uuid" - + if validated.description: header_metadata["context"] = validated.description - + # Create header record header_record = FrameRecord( text_content=f"Collection: {validated.name}\n\n{validated.description or 'No description provided.'}", - metadata=header_metadata + metadata=header_metadata, ) - + # Set collection metadata using helper coll_meta = { "created_at": datetime.date.today().isoformat(), "template": validated.template, "member_count": 0, "total_size": 0, - "shared_metadata": validated.metadata + "shared_metadata": validated.metadata, } self._set_collection_metadata(header_record, coll_meta) - + # Apply template if specified if validated.template and self.template_registry: template = self.template_registry.get_template(validated.template) @@ -117,7 +119,7 @@ async def create_collection(self, params: dict[str, Any]) -> dict[str, Any]: coll_meta["shared_metadata"][key] = value # Update the record self._set_collection_metadata(header_record, coll_meta) - + # Add relationships if parent collection exists parent_header = None if validated.parent_collection: @@ -130,17 +132,17 @@ async def create_collection(self, params: dict[str, Any]) -> dict[str, Any]: create_relationship( validated.parent_collection, rel_type="parent", - title=f"Parent: {parent_header.metadata.get('title', 'Unknown')}" - ) + title=f"Parent: {parent_header.metadata.get('title', 'Unknown')}", + ), ) except Exception as e: logger.warning(f"Parent collection not found: {e}") - + # Save header to dataset self.dataset.add(header_record) # Use the UUID from metadata collection_id = header_record.metadata.get("uuid") - + # Update parent to add child relationship if validated.parent_collection and parent_header: try: @@ -149,58 +151,62 @@ async def create_collection(self, params: dict[str, Any]) -> dict[str, Any]: create_relationship( collection_id, rel_type="child", - title=f"Subcollection: {validated.name}" - ) + title=f"Subcollection: {validated.name}", + ), ) self.dataset.update_record(parent_header) except Exception as e: logger.warning(f"Could not update parent: {e}") - + # Add initial members if specified added_members = 0 if validated.initial_members: for doc_id in validated.initial_members: try: - self._add_document_to_collection(doc_id, collection_id, header_record.metadata.get("uuid")) + self._add_document_to_collection( + doc_id, collection_id, header_record.metadata.get("uuid") + ) added_members += 1 except Exception as e: - logger.warning(f"Failed to add document {doc_id} to collection: {e}") - + logger.warning( + f"Failed to add document {doc_id} to collection: {e}" + ) + # Update member count if we added any if added_members > 0: coll_meta["member_count"] = added_members self._set_collection_metadata(header_record, coll_meta) self.dataset.update_record(header_record) - + return { "collection_id": collection_id, "header_id": collection_id, "name": validated.name, "created_at": coll_meta["created_at"], "member_count": added_members, - "metadata": validated.metadata + "metadata": validated.metadata, } - + async def update_collection(self, params: dict[str, Any]) -> dict[str, Any]: """Update collection properties and membership.""" validated = UpdateCollectionParams(**params) - + # Get collection header header = self._get_collection_header(validated.collection_id) if not header: raise ValueError(f"Collection not found: {validated.collection_id}") - + # Update metadata updated = False - + if validated.name: header.metadata["title"] = validated.name updated = True - + if validated.description is not None: header.metadata["context"] = validated.description updated = True - + if validated.metadata_updates: # Get current collection metadata coll_meta = self._get_collection_metadata(header) @@ -209,27 +215,31 @@ async def update_collection(self, params: dict[str, Any]) -> dict[str, Any]: # Save back self._set_collection_metadata(header, coll_meta) updated = True - + # Remove members removed_count = 0 if validated.remove_members: for doc_id in validated.remove_members: try: - self._remove_document_from_collection(doc_id, validated.collection_id) + self._remove_document_from_collection( + doc_id, validated.collection_id + ) removed_count += 1 except Exception as e: logger.warning(f"Failed to remove document {doc_id}: {e}") - + # Add members added_count = 0 if validated.add_members: for doc_id in validated.add_members: try: - self._add_document_to_collection(doc_id, validated.collection_id, header.metadata.get("uuid")) + self._add_document_to_collection( + doc_id, validated.collection_id, header.metadata.get("uuid") + ) added_count += 1 except Exception as e: logger.warning(f"Failed to add document {doc_id}: {e}") - + # Update member count coll_meta = self._get_collection_metadata(header) current_count = coll_meta["member_count"] @@ -237,48 +247,50 @@ async def update_collection(self, params: dict[str, Any]) -> dict[str, Any]: coll_meta["member_count"] = new_count coll_meta["updated_at"] = datetime.date.today().isoformat() self._set_collection_metadata(header, coll_meta) - + # Save updates if updated or removed_count > 0 or added_count > 0: self.dataset.update_record(header) - + return { "collection_id": validated.collection_id, "updated": updated, "members_added": added_count, "members_removed": removed_count, - "total_members": new_count + "total_members": new_count, } - + async def delete_collection(self, params: dict[str, Any]) -> dict[str, Any]: """Delete a collection and optionally its members.""" validated = DeleteCollectionParams(**params) - + # Get collection header header = self._get_collection_header(validated.collection_id) if not header: raise ValueError(f"Collection not found: {validated.collection_id}") - + deleted_collections = [] deleted_members = [] - + # Handle recursive deletion if validated.recursive: # Find all subcollections subcollections = self._find_subcollections(validated.collection_id) for subcoll in subcollections: # Recursively delete each subcollection - sub_result = await self.delete_collection({ - "collection_id": subcoll["collection_id"], - "delete_members": validated.delete_members, - "recursive": True - }) + sub_result = await self.delete_collection( + { + "collection_id": subcoll["collection_id"], + "delete_members": validated.delete_members, + "recursive": True, + } + ) deleted_collections.extend(sub_result["deleted_collections"]) deleted_members.extend(sub_result["deleted_members"]) - + # Get all member documents members = self._get_collection_members(validated.collection_id) - + # Delete members if requested if validated.delete_members: for member in members: @@ -291,123 +303,158 @@ async def delete_collection(self, params: dict[str, Any]) -> dict[str, Any]: # Just remove collection relationships for member in members: try: - self._remove_document_from_collection(member["uuid"], validated.collection_id) + self._remove_document_from_collection( + member["uuid"], validated.collection_id + ) except Exception as e: logger.warning(f"Failed to remove collection relationship: {e}") - + # Delete the collection header self.dataset.delete_record(header.metadata.get("uuid")) deleted_collections.append(validated.collection_id) - + return { "deleted_collections": deleted_collections, "deleted_members": deleted_members, "total_collections_deleted": len(deleted_collections), - "total_members_deleted": len(deleted_members) + "total_members_deleted": len(deleted_members), } - + async def list_collections(self, params: dict[str, Any]) -> dict[str, Any]: """List collections with filtering and statistics.""" validated = ListCollectionsParams(**params) - + # Build filter for collection headers filters = ["record_type = 'collection_header'"] - + if validated.parent_id: # Use collection_id field for Lance-native parent filtering filters.append(f"collection_id = '{validated.parent_id}'") - + filter_str = " AND ".join(filters) - + # Query collections # Exclude raw_data columns to avoid issues - columns = [col for col in self.dataset._dataset.schema.names if col not in ["raw_data", "raw_data_type"]] + columns = [ + col + for col in self.dataset._dataset.schema.names + if col not in ["raw_data", "raw_data_type"] + ] scanner = self.dataset.scanner(filter=filter_str, columns=columns) collections = [] - + for batch in scanner.to_batches(): for i in range(len(batch)): row_table = batch.slice(i, 1) record = self._safe_load_record(row_table) - + # Build collection info using helper coll_meta = self._get_collection_metadata(record) member_count = coll_meta["member_count"] - + # Skip empty collections if requested if not validated.include_empty and member_count == 0: continue - + coll_info = CollectionInfo( collection_id=str(record.metadata.get("uuid")), header_id=str(record.metadata.get("uuid")), name=record.metadata.get("title", "Unnamed"), description=record.metadata.get("context"), - parent_id=record.metadata.get("collection_id") if record.metadata.get("collection_id_type") == "uuid" else None, - created_at=coll_meta["created_at"] or record.metadata.get("created_at", ""), - updated_at=coll_meta["updated_at"] or record.metadata.get("updated_at", ""), + parent_id=record.metadata.get("collection_id") + if record.metadata.get("collection_id_type") == "uuid" + else None, + created_at=coll_meta["created_at"] + or record.metadata.get("created_at", ""), + updated_at=coll_meta["updated_at"] + or record.metadata.get("updated_at", ""), metadata=coll_meta["shared_metadata"], member_count=member_count, - total_size_bytes=coll_meta["total_size"] if coll_meta["total_size"] > 0 else None + total_size_bytes=coll_meta["total_size"] + if coll_meta["total_size"] > 0 + else None, ) - + # Add statistics if requested if validated.include_stats: - stats = await self._calculate_collection_stats(str(record.metadata.get("uuid")), include_subcollections=False) - collections.append({ - "collection": coll_info.model_dump(), - "statistics": stats - }) + stats = await self._calculate_collection_stats( + str(record.metadata.get("uuid")), include_subcollections=False + ) + collections.append( + {"collection": coll_info.model_dump(), "statistics": stats} + ) else: collections.append(coll_info.model_dump()) - + # Sort collections if validated.sort_by == "name": - collections.sort(key=lambda x: x.get("name", x.get("collection", {}).get("name", "")) if isinstance(x, dict) else x.name) + collections.sort( + key=lambda x: x.get("name", x.get("collection", {}).get("name", "")) + if isinstance(x, dict) + else x.name + ) elif validated.sort_by == "created_at": - collections.sort(key=lambda x: x.get("created_at", x.get("collection", {}).get("created_at", "")) if isinstance(x, dict) else x.created_at, reverse=True) + collections.sort( + key=lambda x: x.get( + "created_at", x.get("collection", {}).get("created_at", "") + ) + if isinstance(x, dict) + else x.created_at, + reverse=True, + ) elif validated.sort_by == "member_count": - collections.sort(key=lambda x: x.get("member_count", x.get("collection", {}).get("member_count", 0)) if isinstance(x, dict) else x.member_count, reverse=True) - + collections.sort( + key=lambda x: x.get( + "member_count", x.get("collection", {}).get("member_count", 0) + ) + if isinstance(x, dict) + else x.member_count, + reverse=True, + ) + # Apply pagination total_count = len(collections) - collections = collections[validated.offset:validated.offset + validated.limit] - + collections = collections[validated.offset : validated.offset + validated.limit] + return { "collections": collections, "total_count": total_count, "offset": validated.offset, - "limit": validated.limit + "limit": validated.limit, } - + async def move_documents(self, params: dict[str, Any]) -> dict[str, Any]: """Move documents between collections.""" validated = MoveDocumentsParams(**params) - + moved_count = 0 failed_moves = [] - + # Validate target collection exists if specified target_header = None if validated.target_collection: target_header = self._get_collection_header(validated.target_collection) if not target_header: - raise ValueError(f"Target collection not found: {validated.target_collection}") - + raise ValueError( + f"Target collection not found: {validated.target_collection}" + ) + for doc_id in validated.document_ids: try: # Remove from source collection if specified if validated.source_collection: - self._remove_document_from_collection(doc_id, validated.source_collection) - + self._remove_document_from_collection( + doc_id, validated.source_collection + ) + # Add to target collection if specified if validated.target_collection: self._add_document_to_collection( - doc_id, + doc_id, validated.target_collection, - target_header.metadata.get("uuid") + target_header.metadata.get("uuid"), ) - + # Apply shared metadata if requested if validated.update_metadata and target_header: doc = self.dataset.get_by_uuid(doc_id) @@ -416,65 +463,64 @@ async def move_documents(self, params: dict[str, Any]) -> dict[str, Any]: coll_meta = self._get_collection_metadata(target_header) doc.metadata.update(coll_meta["shared_metadata"]) self.dataset.update_record(doc) - + moved_count += 1 - + except Exception as e: logger.error(f"Failed to move document {doc_id}: {e}") - failed_moves.append({ - "document_id": doc_id, - "error": str(e) - }) - + failed_moves.append({"document_id": doc_id, "error": str(e)}) + return { "moved_count": moved_count, "failed_count": len(failed_moves), "failed_moves": failed_moves, "source_collection": validated.source_collection, - "target_collection": validated.target_collection + "target_collection": validated.target_collection, } - + async def get_collection_stats(self, params: dict[str, Any]) -> dict[str, Any]: """Get detailed statistics for a collection.""" validated = GetCollectionStatsParams(**params) - + # Get collection header header = self._get_collection_header(validated.collection_id) if not header: raise ValueError(f"Collection not found: {validated.collection_id}") - + # Calculate statistics stats = await self._calculate_collection_stats( validated.collection_id, - include_subcollections=validated.include_subcollections + include_subcollections=validated.include_subcollections, ) - + # Build response result = { "collection_id": validated.collection_id, "name": header.metadata.get("title", "Unnamed"), - "statistics": stats + "statistics": stats, } - + # Add subcollection info if requested if validated.include_subcollections: subcollections = self._find_subcollections(validated.collection_id) result["subcollections"] = subcollections - + # Add member details if requested if validated.include_member_details: - members = self._get_collection_members(validated.collection_id, include_content=False) + members = self._get_collection_members( + validated.collection_id, include_content=False + ) result["members"] = members[:100] # Limit to first 100 - + return result - + # Helper methods - + def _safe_load_record(self, row_table) -> FrameRecord: """Safely load a FrameRecord from Arrow.""" # Since we're excluding raw_data columns from scans, we can just load directly return FrameRecord.from_arrow(row_table) - + def _get_collection_metadata(self, record: FrameRecord) -> dict[str, Any]: """Extract collection metadata from custom_metadata.""" custom = record.metadata.get("custom_metadata", {}) @@ -485,27 +531,30 @@ def _get_collection_metadata(self, record: FrameRecord) -> dict[str, Any]: "total_size": int(custom.get("collection_total_size", "0")), "template": custom.get("collection_template", ""), "shared_metadata": { - k[7:]: v for k, v in custom.items() - if k.startswith("shared_") - } + k[7:]: v for k, v in custom.items() if k.startswith("shared_") + }, } - - def _set_collection_metadata(self, record: FrameRecord, coll_meta: dict[str, Any]) -> None: + + def _set_collection_metadata( + self, record: FrameRecord, coll_meta: dict[str, Any] + ) -> None: """Store collection metadata in custom_metadata.""" if "custom_metadata" not in record.metadata: record.metadata["custom_metadata"] = {} - + custom = record.metadata["custom_metadata"] custom["collection_created_at"] = coll_meta.get("created_at", "") - custom["collection_updated_at"] = coll_meta.get("updated_at", datetime.date.today().isoformat()) + custom["collection_updated_at"] = coll_meta.get( + "updated_at", datetime.date.today().isoformat() + ) custom["collection_member_count"] = str(coll_meta.get("member_count", 0)) custom["collection_total_size"] = str(coll_meta.get("total_size", 0)) custom["collection_template"] = str(coll_meta.get("template") or "") - + # Store shared metadata for key, value in coll_meta.get("shared_metadata", {}).items(): custom[f"shared_{key}"] = str(value) - + def _get_collection_header(self, collection_id: str) -> FrameRecord | None: """Get collection header by ID.""" try: @@ -513,175 +562,191 @@ def _get_collection_header(self, collection_id: str) -> FrameRecord | None: record = self.dataset.get_by_uuid(collection_id) if record and record.metadata.get("record_type") == "collection_header": return record - + # Search by collection_id using uuid field - filter_str = f"record_type = 'collection_header' AND uuid = '{collection_id}'" - columns = [col for col in self.dataset._dataset.schema.names if col not in ["raw_data", "raw_data_type"]] + filter_str = ( + f"record_type = 'collection_header' AND uuid = '{collection_id}'" + ) + columns = [ + col + for col in self.dataset._dataset.schema.names + if col not in ["raw_data", "raw_data_type"] + ] scanner = self.dataset.scanner(filter=filter_str, columns=columns) - + for batch in scanner.to_batches(): if len(batch) > 0: return self._safe_load_record(batch.slice(0, 1)) - + return None - + except Exception as e: logger.error(f"Error getting collection header: {e}") return None - + def _add_document_to_collection( - self, - doc_id: str, - collection_id: str, - header_uuid: str + self, doc_id: str, collection_id: str, header_uuid: str ) -> None: """Add document to collection by updating relationships.""" doc = self.dataset.get_by_uuid(doc_id) if not doc: raise ValueError(f"Document not found: {doc_id}") - + # Add reference relationship from document to collection header add_relationship_to_metadata( doc.metadata, create_relationship( header_uuid, rel_type="reference", - title=f"Member of collection {collection_id}" - ) + title=f"Member of collection {collection_id}", + ), ) - + # Update document self.dataset.update_record(doc) - + def _remove_document_from_collection(self, doc_id: str, collection_id: str) -> None: """Remove document from collection by removing relationships.""" doc = self.dataset.get_by_uuid(doc_id) if not doc: raise ValueError(f"Document not found: {doc_id}") - + # Remove reference relationship relationships = doc.metadata.get("relationships", []) doc.metadata["relationships"] = [ - rel for rel in relationships - if not (rel.get("type") == "reference" and collection_id in str(rel.get("id", ""))) + rel + for rel in relationships + if not ( + rel.get("type") == "reference" + and collection_id in str(rel.get("id", "")) + ) ] - + # Update document self.dataset.update_record(doc) - + def _get_collection_members( - self, - collection_id: str, - include_content: bool = True + self, collection_id: str, include_content: bool = True ) -> list[dict[str, Any]]: """Get all members of a collection.""" members = [] - + # Find documents with member_of relationship to this collection # Exclude raw_data to avoid loading large binary data - columns = [col for col in self.dataset._dataset.schema.names if col not in ["raw_data", "raw_data_type"]] + columns = [ + col + for col in self.dataset._dataset.schema.names + if col not in ["raw_data", "raw_data_type"] + ] scanner = self.dataset.scanner(columns=columns) - + for batch in scanner.to_batches(): for i in range(len(batch)): row_table = batch.slice(i, 1) record = self._safe_load_record(row_table) - + # Check relationships relationships = record.metadata.get("relationships", []) for rel in relationships: - if (rel.get("type") == "reference" and - collection_id in str(rel.get("id", ""))): - + if rel.get("type") == "reference" and collection_id in str( + rel.get("id", "") + ): member_info = { "uuid": str(record.metadata.get("uuid")), "title": record.metadata.get("title", ""), - "metadata": record.metadata + "metadata": record.metadata, } - + if include_content: member_info["content"] = record.text_content - + members.append(member_info) break - + return members - + def _find_subcollections(self, parent_id: str) -> list[dict[str, Any]]: """Find all subcollections of a parent collection.""" # Use Lance-native filtering on collection_id field - filter_str = f"record_type = 'collection_header' AND collection_id = '{parent_id}'" - columns = [col for col in self.dataset._dataset.schema.names if col not in ["raw_data", "raw_data_type"]] + filter_str = ( + f"record_type = 'collection_header' AND collection_id = '{parent_id}'" + ) + columns = [ + col + for col in self.dataset._dataset.schema.names + if col not in ["raw_data", "raw_data_type"] + ] scanner = self.dataset.scanner(filter=filter_str, columns=columns) - + subcollections = [] for batch in scanner.to_batches(): for i in range(len(batch)): row_table = batch.slice(i, 1) record = self._safe_load_record(row_table) - - subcollections.append({ - "collection_id": str(record.metadata.get("uuid")), - "name": record.metadata.get("title", "Unnamed"), - "member_count": self._get_collection_metadata(record)["member_count"] - }) - + + subcollections.append( + { + "collection_id": str(record.metadata.get("uuid")), + "name": record.metadata.get("title", "Unnamed"), + "member_count": self._get_collection_metadata(record)[ + "member_count" + ], + } + ) + return subcollections - + async def _calculate_collection_stats( - self, - collection_id: str, - include_subcollections: bool = True + self, collection_id: str, include_subcollections: bool = True ) -> dict[str, Any]: """Calculate detailed statistics for a collection.""" members = self._get_collection_members(collection_id) - + # Basic counts direct_members = len(members) subcollection_members = 0 - + # Calculate subcollection members if requested if include_subcollections: subcollections = self._find_subcollections(collection_id) for subcoll in subcollections: sub_stats = await self._calculate_collection_stats( - subcoll["collection_id"], - include_subcollections=True + subcoll["collection_id"], include_subcollections=True ) subcollection_members += sub_stats["total_members"] - + # Calculate sizes and metadata total_size = 0 unique_tags = set() dates = [] member_types = {} - + for member in members: # Size (approximate) content_size = len(member.get("content", "").encode('utf-8')) total_size += content_size - + # Tags tags = member["metadata"].get("tags", []) unique_tags.update(tags) - + # Dates created_at = member["metadata"].get("created_at") if created_at: dates.append(created_at) - + # Types record_type = member["metadata"].get("record_type", "document") member_types[record_type] = member_types.get(record_type, 0) + 1 - + # Calculate averages and ranges avg_size = total_size / direct_members if direct_members > 0 else 0 - + date_range = { "earliest": min(dates) if dates else None, - "latest": max(dates) if dates else None + "latest": max(dates) if dates else None, } - + return { "total_members": direct_members + subcollection_members, "direct_members": direct_members, @@ -690,5 +755,5 @@ async def _calculate_collection_stats( "avg_document_size": avg_size, "unique_tags": sorted(list(unique_tags)), "date_range": date_range, - "member_types": member_types - } \ No newline at end of file + "member_types": member_types, + } diff --git a/contextframe/mcp/core/__init__.py b/contextframe/mcp/core/__init__.py index c1ef43f..2151416 100644 --- a/contextframe/mcp/core/__init__.py +++ b/contextframe/mcp/core/__init__.py @@ -1,10 +1,6 @@ """Core abstractions for transport-agnostic MCP implementation.""" -from contextframe.mcp.core.transport import TransportAdapter, Progress from contextframe.mcp.core.streaming import StreamingAdapter +from contextframe.mcp.core.transport import Progress, TransportAdapter -__all__ = [ - "TransportAdapter", - "Progress", - "StreamingAdapter" -] \ No newline at end of file +__all__ = ["TransportAdapter", "Progress", "StreamingAdapter"] diff --git a/contextframe/mcp/core/streaming.py b/contextframe/mcp/core/streaming.py index a13af22..764fd02 100644 --- a/contextframe/mcp/core/streaming.py +++ b/contextframe/mcp/core/streaming.py @@ -1,78 +1,82 @@ """Streaming abstraction for transport-agnostic responses.""" from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Dict, List, Optional +from collections.abc import AsyncIterator from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional @dataclass class StreamingResponse: """Container for streaming response data.""" - + operation: str - total_items: Optional[int] = None - items: List[Dict[str, Any]] = field(default_factory=list) - metadata: Dict[str, Any] = field(default_factory=dict) - error: Optional[str] = None + total_items: int | None = None + items: list[dict[str, Any]] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + error: str | None = None completed: bool = False class StreamingAdapter(ABC): """Adapter for handling streaming responses across transports.""" - + @abstractmethod - async def start_stream(self, operation: str, total_items: Optional[int] = None) -> None: + async def start_stream( + self, operation: str, total_items: int | None = None + ) -> None: """Start a streaming operation.""" pass - + @abstractmethod - async def send_item(self, item: Dict[str, Any]) -> None: + async def send_item(self, item: dict[str, Any]) -> None: """Send a single item in the stream.""" pass - - @abstractmethod + + @abstractmethod async def send_error(self, error: str) -> None: """Send an error in the stream.""" pass - + @abstractmethod - async def complete_stream(self, metadata: Optional[Dict[str, Any]] = None) -> Any: + async def complete_stream(self, metadata: dict[str, Any] | None = None) -> Any: """Complete the streaming operation and return final result.""" pass class BufferedStreamingAdapter(StreamingAdapter): """Streaming adapter that buffers all items for non-streaming transports.""" - + def __init__(self): - self._response: Optional[StreamingResponse] = None - - async def start_stream(self, operation: str, total_items: Optional[int] = None) -> None: + self._response: StreamingResponse | None = None + + async def start_stream( + self, operation: str, total_items: int | None = None + ) -> None: """Start buffering items.""" - self._response = StreamingResponse( - operation=operation, - total_items=total_items - ) - - async def send_item(self, item: Dict[str, Any]) -> None: + self._response = StreamingResponse(operation=operation, total_items=total_items) + + async def send_item(self, item: dict[str, Any]) -> None: """Add item to buffer.""" if self._response: self._response.items.append(item) - + async def send_error(self, error: str) -> None: """Record error.""" if self._response: self._response.error = error - - async def complete_stream(self, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + + async def complete_stream( + self, metadata: dict[str, Any] | None = None + ) -> dict[str, Any]: """Return buffered response.""" if not self._response: raise RuntimeError("No streaming operation in progress") - + self._response.completed = True if metadata: self._response.metadata.update(metadata) - + # Convert to dict for JSON serialization return { "operation": self._response.operation, @@ -80,51 +84,52 @@ async def complete_stream(self, metadata: Optional[Dict[str, Any]] = None) -> Di "items": self._response.items, "metadata": self._response.metadata, "error": self._response.error, - "completed": self._response.completed + "completed": self._response.completed, } class SSEStreamingAdapter(StreamingAdapter): """Streaming adapter for Server-Sent Events (HTTP transport).""" - + def __init__(self, send_sse_func): self.send_sse = send_sse_func - self._operation: Optional[str] = None + self._operation: str | None = None self._item_count = 0 - - async def start_stream(self, operation: str, total_items: Optional[int] = None) -> None: + + async def start_stream( + self, operation: str, total_items: int | None = None + ) -> None: """Send stream start event.""" self._operation = operation self._item_count = 0 - await self.send_sse({ - "event": "stream_start", - "operation": operation, - "total_items": total_items - }) - - async def send_item(self, item: Dict[str, Any]) -> None: + await self.send_sse( + { + "event": "stream_start", + "operation": operation, + "total_items": total_items, + } + ) + + async def send_item(self, item: dict[str, Any]) -> None: """Send item via SSE.""" self._item_count += 1 - await self.send_sse({ - "event": "stream_item", - "item": item, - "index": self._item_count - }) - + await self.send_sse( + {"event": "stream_item", "item": item, "index": self._item_count} + ) + async def send_error(self, error: str) -> None: """Send error via SSE.""" - await self.send_sse({ - "event": "stream_error", - "error": error - }) - - async def complete_stream(self, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + await self.send_sse({"event": "stream_error", "error": error}) + + async def complete_stream( + self, metadata: dict[str, Any] | None = None + ) -> dict[str, Any]: """Send completion event and return summary.""" result = { "event": "stream_complete", "operation": self._operation, "total_items": self._item_count, - "metadata": metadata or {} + "metadata": metadata or {}, } await self.send_sse(result) - return result \ No newline at end of file + return result diff --git a/contextframe/mcp/core/transport.py b/contextframe/mcp/core/transport.py index 3d8d59f..5de391e 100644 --- a/contextframe/mcp/core/transport.py +++ b/contextframe/mcp/core/transport.py @@ -5,105 +5,108 @@ transport types. """ +import asyncio from abc import ABC, abstractmethod +from collections.abc import AsyncIterator from dataclasses import dataclass -from typing import Any, Dict, Optional, AsyncIterator -import asyncio +from typing import Any, Dict, Optional @dataclass class Progress: """Progress update for long-running operations.""" - + operation: str current: int total: int status: str - details: Optional[Dict[str, Any]] = None + details: dict[str, Any] | None = None -@dataclass +@dataclass class Subscription: """Subscription for change notifications.""" - + id: str resource_type: str - filter: Optional[str] = None - last_poll: Optional[str] = None + filter: str | None = None + last_poll: str | None = None class TransportAdapter(ABC): """Base class for transport adapters. - - This abstraction ensures that all MCP features (tools, resources, + + This abstraction ensures that all MCP features (tools, resources, subscriptions, etc.) work identically across different transports. """ - + def __init__(self): - self._subscriptions: Dict[str, Subscription] = {} + self._subscriptions: dict[str, Subscription] = {} self._progress_handlers = [] - + @abstractmethod async def initialize(self) -> None: """Initialize the transport.""" pass - + @abstractmethod async def shutdown(self) -> None: """Shutdown the transport cleanly.""" pass - + @abstractmethod - async def send_message(self, message: Dict[str, Any]) -> None: + async def send_message(self, message: dict[str, Any]) -> None: """Send a message through the transport.""" pass - + @abstractmethod - async def receive_message(self) -> Optional[Dict[str, Any]]: + async def receive_message(self) -> dict[str, Any] | None: """Receive a message from the transport.""" pass - + async def send_progress(self, progress: Progress) -> None: """Send progress update in transport-appropriate way. - + - Stdio: Include in structured response - HTTP: Send via SSE """ # Default implementation stores progress for inclusion in response for handler in self._progress_handlers: await handler(progress) - + def add_progress_handler(self, handler): """Add a progress handler callback.""" self._progress_handlers.append(handler) - - async def handle_subscription(self, subscription: Subscription) -> AsyncIterator[Dict[str, Any]]: + + async def handle_subscription( + self, subscription: Subscription + ) -> AsyncIterator[dict[str, Any]]: """Handle subscription in transport-appropriate way. - + - Stdio: Polling-based with change tokens - HTTP: SSE streaming """ self._subscriptions[subscription.id] = subscription - + # Base implementation - subclasses override while subscription.id in self._subscriptions: # This would be overridden by transport-specific logic await asyncio.sleep(1) yield {"subscription_id": subscription.id, "changes": []} - + def cancel_subscription(self, subscription_id: str) -> bool: """Cancel an active subscription.""" if subscription_id in self._subscriptions: del self._subscriptions[subscription_id] return True return False - + @property def supports_streaming(self) -> bool: """Whether this transport supports streaming responses.""" return False - - @property + + @property def transport_type(self) -> str: """Identifier for the transport type.""" - return "base" \ No newline at end of file + return "base" diff --git a/contextframe/mcp/enhancement_tools.py b/contextframe/mcp/enhancement_tools.py index 2ff13de..a9a3907 100644 --- a/contextframe/mcp/enhancement_tools.py +++ b/contextframe/mcp/enhancement_tools.py @@ -1,44 +1,42 @@ """Enhancement and extraction tools for MCP server.""" -import os import logging -from typing import Any, Dict, List, Optional -from pathlib import Path - +import os from contextframe.enhance import ContextEnhancer, EnhancementTools from contextframe.extract import ( BatchExtractor, - MarkdownExtractor, + CSVExtractor, JSONExtractor, + MarkdownExtractor, + TextFileExtractor, YAMLExtractor, - CSVExtractor, - TextFileExtractor ) -from contextframe.mcp.errors import InvalidParams, InternalError +from contextframe.mcp.errors import InternalError, InvalidParams from contextframe.mcp.schemas import Tool - +from pathlib import Path +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) def register_enhancement_tools(tool_registry, dataset): """Register enhancement tools with the MCP tool registry.""" - + # Initialize enhancer model = os.environ.get("CONTEXTFRAME_ENHANCE_MODEL", "gpt-4") api_key = os.environ.get("OPENAI_API_KEY") - + if not api_key: logger.warning("No OpenAI API key found. Enhancement tools will be disabled.") return - + try: enhancer = ContextEnhancer(model=model, api_key=api_key) enhancement_tools = EnhancementTools(enhancer) except Exception as e: logger.warning(f"Failed to initialize enhancer: {e}") return - + # Register enhance_context tool tool_registry.register( "enhance_context", @@ -50,23 +48,23 @@ def register_enhancement_tools(tool_registry, dataset): "properties": { "document_id": { "type": "string", - "description": "Document UUID to enhance" + "description": "Document UUID to enhance", }, "purpose": { "type": "string", - "description": "What the context should focus on" + "description": "What the context should focus on", }, "current_context": { "type": "string", - "description": "Existing context if any" - } + "description": "Existing context if any", + }, }, - "required": ["document_id", "purpose"] - } + "required": ["document_id", "purpose"], + }, ), - lambda args: _enhance_context(dataset, enhancement_tools, args) + lambda args: _enhance_context(dataset, enhancement_tools, args), ) - + # Register extract_metadata tool tool_registry.register( "extract_metadata", @@ -76,27 +74,24 @@ def register_enhancement_tools(tool_registry, dataset): inputSchema={ "type": "object", "properties": { - "document_id": { - "type": "string", - "description": "Document UUID" - }, + "document_id": {"type": "string", "description": "Document UUID"}, "schema": { "type": "string", - "description": "What metadata to extract (as prompt)" + "description": "What metadata to extract (as prompt)", }, "format": { "type": "string", "enum": ["json", "text"], "default": "json", - "description": "Output format" - } + "description": "Output format", + }, }, - "required": ["document_id", "schema"] - } + "required": ["document_id", "schema"], + }, ), - lambda args: _extract_metadata(dataset, enhancement_tools, args) + lambda args: _extract_metadata(dataset, enhancement_tools, args), ) - + # Register generate_tags tool tool_registry.register( "generate_tags", @@ -106,29 +101,26 @@ def register_enhancement_tools(tool_registry, dataset): inputSchema={ "type": "object", "properties": { - "document_id": { - "type": "string", - "description": "Document UUID" - }, + "document_id": {"type": "string", "description": "Document UUID"}, "tag_types": { "type": "string", "default": "topics, technologies, concepts", - "description": "Types of tags to generate" + "description": "Types of tags to generate", }, "max_tags": { "type": "integer", "minimum": 1, "maximum": 20, "default": 5, - "description": "Maximum number of tags" - } + "description": "Maximum number of tags", + }, }, - "required": ["document_id"] - } + "required": ["document_id"], + }, ), - lambda args: _generate_tags(dataset, enhancement_tools, args) + lambda args: _generate_tags(dataset, enhancement_tools, args), ) - + # Register improve_title tool tool_registry.register( "improve_title", @@ -138,23 +130,20 @@ def register_enhancement_tools(tool_registry, dataset): inputSchema={ "type": "object", "properties": { - "document_id": { - "type": "string", - "description": "Document UUID" - }, + "document_id": {"type": "string", "description": "Document UUID"}, "style": { "type": "string", "enum": ["descriptive", "technical", "concise"], "default": "descriptive", - "description": "Title style" - } + "description": "Title style", + }, }, - "required": ["document_id"] - } + "required": ["document_id"], + }, ), - lambda args: _improve_title(dataset, enhancement_tools, args) + lambda args: _improve_title(dataset, enhancement_tools, args), ) - + # Register enhance_for_purpose tool tool_registry.register( "enhance_for_purpose", @@ -164,34 +153,31 @@ def register_enhancement_tools(tool_registry, dataset): inputSchema={ "type": "object", "properties": { - "document_id": { - "type": "string", - "description": "Document UUID" - }, + "document_id": {"type": "string", "description": "Document UUID"}, "purpose": { "type": "string", - "description": "Purpose or use case for enhancement" + "description": "Purpose or use case for enhancement", }, "fields": { "type": "array", "items": { "type": "string", - "enum": ["context", "tags", "custom_metadata"] + "enum": ["context", "tags", "custom_metadata"], }, "default": ["context", "tags", "custom_metadata"], - "description": "Which fields to enhance" - } + "description": "Which fields to enhance", + }, }, - "required": ["document_id", "purpose"] - } + "required": ["document_id", "purpose"], + }, ), - lambda args: _enhance_for_purpose(dataset, enhancement_tools, args) + lambda args: _enhance_for_purpose(dataset, enhancement_tools, args), ) def register_extraction_tools(tool_registry, dataset): """Register extraction tools with the MCP tool registry.""" - + # Register extract_from_file tool tool_registry.register( "extract_from_file", @@ -203,29 +189,29 @@ def register_extraction_tools(tool_registry, dataset): "properties": { "file_path": { "type": "string", - "description": "Path to file to extract" + "description": "Path to file to extract", }, "add_to_dataset": { "type": "boolean", "default": True, - "description": "Whether to add extracted content to dataset" + "description": "Whether to add extracted content to dataset", }, "generate_embedding": { "type": "boolean", "default": True, - "description": "Whether to generate embeddings" + "description": "Whether to generate embeddings", }, "collection": { "type": "string", - "description": "Collection to add document to" - } + "description": "Collection to add document to", + }, }, - "required": ["file_path"] - } + "required": ["file_path"], + }, ), - lambda args: _extract_from_file(dataset, args) + lambda args: _extract_from_file(dataset, args), ) - + # Register batch_extract tool tool_registry.register( "batch_extract", @@ -237,192 +223,180 @@ def register_extraction_tools(tool_registry, dataset): "properties": { "directory": { "type": "string", - "description": "Directory path to process" + "description": "Directory path to process", }, "patterns": { "type": "array", "items": {"type": "string"}, "default": ["*.md", "*.txt", "*.json", "*.yaml", "*.yml"], - "description": "File patterns to match" + "description": "File patterns to match", }, "recursive": { "type": "boolean", "default": True, - "description": "Process subdirectories" + "description": "Process subdirectories", }, "add_to_dataset": { "type": "boolean", "default": True, - "description": "Add to dataset" + "description": "Add to dataset", }, - "collection": { - "type": "string", - "description": "Collection name" - } + "collection": {"type": "string", "description": "Collection name"}, }, - "required": ["directory"] - } + "required": ["directory"], + }, ), - lambda args: _batch_extract(dataset, args) + lambda args: _batch_extract(dataset, args), ) # Implementation functions -async def _enhance_context(dataset, enhancement_tools, args: Dict[str, Any]) -> Dict[str, Any]: +async def _enhance_context( + dataset, enhancement_tools, args: dict[str, Any] +) -> dict[str, Any]: """Implement enhance_context tool.""" # Get document doc_id = args["document_id"] results = dataset.query(f"uuid = '{doc_id}'", limit=1) if not results: raise InvalidParams(f"Document not found: {doc_id}") - + record = results[0] - + # Enhance context new_context = enhancement_tools.enhance_context( content=record.content, purpose=args["purpose"], - current_context=args.get("current_context", record.metadata.get("context")) + current_context=args.get("current_context", record.metadata.get("context")), ) - + # Update document record.metadata["context"] = new_context dataset.delete(f"uuid = '{doc_id}'") dataset.add([record]) - - return { - "document_id": doc_id, - "context": new_context - } + return {"document_id": doc_id, "context": new_context} -async def _extract_metadata(dataset, enhancement_tools, args: Dict[str, Any]) -> Dict[str, Any]: + +async def _extract_metadata( + dataset, enhancement_tools, args: dict[str, Any] +) -> dict[str, Any]: """Implement extract_metadata tool.""" doc_id = args["document_id"] results = dataset.query(f"uuid = '{doc_id}'", limit=1) if not results: raise InvalidParams(f"Document not found: {doc_id}") - + record = results[0] - + # Extract metadata metadata = enhancement_tools.extract_metadata( - content=record.content, - schema=args["schema"], - format=args.get("format", "json") + content=record.content, schema=args["schema"], format=args.get("format", "json") ) - + # Update document if isinstance(metadata, dict): record.metadata.get("custom_metadata", {}).update(metadata) else: record.metadata["custom_metadata"] = metadata - + dataset.delete(f"uuid = '{doc_id}'") dataset.add([record]) - - return { - "document_id": doc_id, - "metadata": metadata - } + + return {"document_id": doc_id, "metadata": metadata} -async def _generate_tags(dataset, enhancement_tools, args: Dict[str, Any]) -> Dict[str, Any]: +async def _generate_tags( + dataset, enhancement_tools, args: dict[str, Any] +) -> dict[str, Any]: """Implement generate_tags tool.""" doc_id = args["document_id"] results = dataset.query(f"uuid = '{doc_id}'", limit=1) if not results: raise InvalidParams(f"Document not found: {doc_id}") - + record = results[0] - + # Generate tags tags = enhancement_tools.generate_tags( content=record.content, tag_types=args.get("tag_types", "topics, technologies, concepts"), - max_tags=args.get("max_tags", 5) + max_tags=args.get("max_tags", 5), ) - + # Update document record.metadata["tags"] = tags dataset.delete(f"uuid = '{doc_id}'") dataset.add([record]) - - return { - "document_id": doc_id, - "tags": tags - } + return {"document_id": doc_id, "tags": tags} -async def _improve_title(dataset, enhancement_tools, args: Dict[str, Any]) -> Dict[str, Any]: + +async def _improve_title( + dataset, enhancement_tools, args: dict[str, Any] +) -> dict[str, Any]: """Implement improve_title tool.""" doc_id = args["document_id"] results = dataset.query(f"uuid = '{doc_id}'", limit=1) if not results: raise InvalidParams(f"Document not found: {doc_id}") - + record = results[0] - + # Improve title new_title = enhancement_tools.improve_title( content=record.content, current_title=record.metadata.get("title"), - style=args.get("style", "descriptive") + style=args.get("style", "descriptive"), ) - + # Update document record.metadata["title"] = new_title dataset.delete(f"uuid = '{doc_id}'") dataset.add([record]) - - return { - "document_id": doc_id, - "title": new_title - } + return {"document_id": doc_id, "title": new_title} -async def _enhance_for_purpose(dataset, enhancement_tools, args: Dict[str, Any]) -> Dict[str, Any]: + +async def _enhance_for_purpose( + dataset, enhancement_tools, args: dict[str, Any] +) -> dict[str, Any]: """Implement enhance_for_purpose tool.""" doc_id = args["document_id"] results = dataset.query(f"uuid = '{doc_id}'", limit=1) if not results: raise InvalidParams(f"Document not found: {doc_id}") - + record = results[0] - + # Enhance for purpose enhancements = enhancement_tools.enhance_for_purpose( - content=record.content, - purpose=args["purpose"], - fields=args.get("fields") + content=record.content, purpose=args["purpose"], fields=args.get("fields") ) - + # Update document with enhancements for field, value in enhancements.items(): if field == "custom_metadata" and isinstance(value, dict): record.metadata.get("custom_metadata", {}).update(value) else: record.metadata[field] = value - + dataset.delete(f"uuid = '{doc_id}'") dataset.add([record]) - - return { - "document_id": doc_id, - "enhancements": enhancements - } + return {"document_id": doc_id, "enhancements": enhancements} -async def _extract_from_file(dataset, args: Dict[str, Any]) -> Dict[str, Any]: + +async def _extract_from_file(dataset, args: dict[str, Any]) -> dict[str, Any]: """Implement extract_from_file tool.""" file_path = Path(args["file_path"]) - + if not file_path.exists(): raise InvalidParams(f"File not found: {file_path}") - + # Determine extractor based on file extension ext = file_path.suffix.lower() - + if ext == ".md": extractor = MarkdownExtractor() elif ext == ".json": @@ -433,109 +407,107 @@ async def _extract_from_file(dataset, args: Dict[str, Any]) -> Dict[str, Any]: extractor = CSVExtractor() else: extractor = TextFileExtractor() - + try: # Extract content result = extractor.extract(str(file_path)) - + if args.get("add_to_dataset", True): # Create record from extraction from contextframe.frame import FrameRecord - - record = FrameRecord( - content=result.content, - metadata=result.metadata - ) - + + record = FrameRecord(content=result.content, metadata=result.metadata) + # Add collection if specified if args.get("collection"): record.metadata["collection"] = args["collection"] - + # Generate embedding if requested if args.get("generate_embedding", True): - model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") + model = os.environ.get( + "CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002" + ) api_key = os.environ.get("OPENAI_API_KEY") - + if api_key: from contextframe.embed import LiteLLMProvider + provider = LiteLLMProvider(model, api_key=api_key) embed_result = provider.embed(record.content) record.embeddings = embed_result.embeddings[0] - + # Add to dataset dataset.add([record]) - + return { "file_path": str(file_path), "document_id": record.uuid, "content_length": len(result.content), - "metadata": result.metadata + "metadata": result.metadata, } else: return { "file_path": str(file_path), "content": result.content, - "metadata": result.metadata + "metadata": result.metadata, } - + except Exception as e: raise InternalError(f"Extraction failed: {str(e)}") -async def _batch_extract(dataset, args: Dict[str, Any]) -> Dict[str, Any]: +async def _batch_extract(dataset, args: dict[str, Any]) -> dict[str, Any]: """Implement batch_extract tool.""" directory = Path(args["directory"]) - + if not directory.exists() or not directory.is_dir(): raise InvalidParams(f"Directory not found: {directory}") - + batch_extractor = BatchExtractor() patterns = args.get("patterns", ["*.md", "*.txt", "*.json", "*.yaml", "*.yml"]) - + try: # Extract from directory results = batch_extractor.extract_directory( - str(directory), - patterns=patterns, - recursive=args.get("recursive", True) + str(directory), patterns=patterns, recursive=args.get("recursive", True) ) - + added_documents = [] - + if args.get("add_to_dataset", True): from contextframe.frame import FrameRecord - + for result in results: - record = FrameRecord( - content=result.content, - metadata=result.metadata - ) - + record = FrameRecord(content=result.content, metadata=result.metadata) + # Add collection if specified if args.get("collection"): record.metadata["collection"] = args["collection"] - + # Generate embeddings in batch if API key available if args.get("generate_embedding", True): - model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") + model = os.environ.get( + "CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002" + ) api_key = os.environ.get("OPENAI_API_KEY") - + if api_key: from contextframe.embed import LiteLLMProvider + provider = LiteLLMProvider(model, api_key=api_key) embed_result = provider.embed(record.content) record.embeddings = embed_result.embeddings[0] - + added_documents.append(record) - + # Add all documents dataset.add(added_documents) - + return { "directory": str(directory), "files_processed": len(results), "documents_added": len(added_documents), - "patterns": patterns + "patterns": patterns, } else: return { @@ -545,11 +517,11 @@ async def _batch_extract(dataset, args: Dict[str, Any]) -> Dict[str, Any]: { "file_path": r.metadata.get("source", "unknown"), "content_length": len(r.content), - "metadata": r.metadata + "metadata": r.metadata, } for r in results - ] + ], } - + except Exception as e: - raise InternalError(f"Batch extraction failed: {str(e)}") \ No newline at end of file + raise InternalError(f"Batch extraction failed: {str(e)}") diff --git a/contextframe/mcp/example_client.py b/contextframe/mcp/example_client.py index 1c50c24..5798f79 100644 --- a/contextframe/mcp/example_client.py +++ b/contextframe/mcp/example_client.py @@ -9,9 +9,9 @@ python -m contextframe.mcp /path/to/dataset.lance | python example_client.py """ +import asyncio import json import sys -import asyncio from typing import Any, Dict, Optional @@ -26,130 +26,136 @@ def _next_id(self) -> int: self._message_id += 1 return self._message_id - async def send_message(self, method: str, params: Optional[Dict[str, Any]] = None) -> None: + async def send_message( + self, method: str, params: dict[str, Any] | None = None + ) -> None: """Send a JSON-RPC message to stdout.""" - message = { - "jsonrpc": "2.0", - "method": method, - "id": self._next_id() - } + message = {"jsonrpc": "2.0", "method": method, "id": self._next_id()} if params: message["params"] = params - + print(json.dumps(message)) sys.stdout.flush() - async def read_response(self) -> Dict[str, Any]: + async def read_response(self) -> dict[str, Any]: """Read a JSON-RPC response from stdin.""" line = sys.stdin.readline() if not line: raise EOFError("Connection closed") - + return json.loads(line.strip()) - async def call(self, method: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def call( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: """Make an RPC call and wait for response.""" await self.send_message(method, params) response = await self.read_response() - + if "error" in response: raise Exception(f"RPC Error: {response['error']}") - + return response.get("result", {}) async def main(): """Example client interaction.""" client = MCPClient() - + print("=== MCP Client Example ===") - + try: # 1. Initialize print("\n1. Initializing...") - result = await client.call("initialize", { - "protocolVersion": "0.1.0", - "capabilities": {} - }) - print(f"Server: {result['serverInfo']['name']} v{result['serverInfo']['version']}") + result = await client.call( + "initialize", {"protocolVersion": "0.1.0", "capabilities": {}} + ) + print( + f"Server: {result['serverInfo']['name']} v{result['serverInfo']['version']}" + ) print(f"Capabilities: {result['capabilities']}") - + # 2. List tools print("\n2. Listing tools...") result = await client.call("tools/list") print(f"Available tools: {len(result['tools'])}") for tool in result['tools']: print(f" - {tool['name']}: {tool['description']}") - + # 3. List resources print("\n3. Listing resources...") result = await client.call("resources/list") print(f"Available resources: {len(result['resources'])}") for resource in result['resources']: print(f" - {resource['name']}: {resource['uri']}") - + # 4. Read dataset info print("\n4. Reading dataset info...") - result = await client.call("resources/read", { - "uri": "contextframe://dataset/info" - }) + result = await client.call( + "resources/read", {"uri": "contextframe://dataset/info"} + ) info = json.loads(result['contents'][0]['text']) print(f"Dataset path: {info['dataset_path']}") print(f"Total documents: {info.get('total_documents', 'Unknown')}") - + # 5. Search documents print("\n5. Searching documents...") - result = await client.call("tools/call", { - "name": "search_documents", - "arguments": { - "query": "test", - "search_type": "text", - "limit": 3 - } - }) + result = await client.call( + "tools/call", + { + "name": "search_documents", + "arguments": {"query": "test", "search_type": "text", "limit": 3}, + }, + ) print(f"Found {len(result['documents'])} documents") for doc in result['documents']: print(f" - {doc['uuid']}: {doc['content'][:50]}...") - + # 6. Add a document print("\n6. Adding a document...") - result = await client.call("tools/call", { - "name": "add_document", - "arguments": { - "content": "This is a test document added via MCP", - "metadata": { - "title": "MCP Test Document", - "source": "example_client.py" + result = await client.call( + "tools/call", + { + "name": "add_document", + "arguments": { + "content": "This is a test document added via MCP", + "metadata": { + "title": "MCP Test Document", + "source": "example_client.py", + }, + "generate_embedding": False, }, - "generate_embedding": False - } - }) + }, + ) doc_id = result['document']['uuid'] print(f"Added document: {doc_id}") - + # 7. Get the document back print("\n7. Retrieving document...") - result = await client.call("tools/call", { - "name": "get_document", - "arguments": { - "document_id": doc_id, - "include_content": True, - "include_metadata": True - } - }) + result = await client.call( + "tools/call", + { + "name": "get_document", + "arguments": { + "document_id": doc_id, + "include_content": True, + "include_metadata": True, + }, + }, + ) doc = result['document'] print(f"Retrieved: {doc['content']}") print(f"Metadata: {doc['metadata']}") - + # 8. Shutdown print("\n8. Shutting down...") await client.send_message("shutdown") print("Client complete!") - + except Exception as e: print(f"Error: {e}") sys.exit(1) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/contextframe/mcp/handlers.py b/contextframe/mcp/handlers.py index b6ce835..8674bab 100644 --- a/contextframe/mcp/handlers.py +++ b/contextframe/mcp/handlers.py @@ -1,26 +1,24 @@ """Message handlers for MCP server.""" import logging -from typing import Any, Dict, Optional -from pydantic import ValidationError - from contextframe.mcp.errors import ( + InvalidParams, InvalidRequest, - MethodNotFound, MCPError, - InvalidParams + MethodNotFound, ) from contextframe.mcp.schemas import ( InitializeParams, InitializeResult, + JSONRPCError, JSONRPCRequest, JSONRPCResponse, - JSONRPCError, MCPCapabilities, + ResourceReadParams, ToolCallParams, - ResourceReadParams ) - +from pydantic import ValidationError +from typing import Any, Dict, Optional logger = logging.getLogger(__name__) @@ -40,13 +38,13 @@ def __init__(self, server: "ContextFrameMCPServer"): "shutdown": self.handle_shutdown, } - async def handle(self, message: Dict[str, Any]) -> Dict[str, Any]: + async def handle(self, message: dict[str, Any]) -> dict[str, Any]: """Handle incoming JSON-RPC message and return response.""" try: # Check for jsonrpc field first if "jsonrpc" not in message: raise InvalidRequest("Missing jsonrpc field") - + # Parse request try: request = JSONRPCRequest(**message) @@ -64,45 +62,38 @@ async def handle(self, message: Dict[str, Any]) -> Dict[str, Any]: # Build response (notifications don't get responses) if request.id is None: return None - - response = JSONRPCResponse( - jsonrpc="2.0", - result=result, - id=request.id - ) + + response = JSONRPCResponse(jsonrpc="2.0", result=result, id=request.id) except MCPError as e: # MCP-specific errors response = JSONRPCResponse( jsonrpc="2.0", error=JSONRPCError(**e.to_json_rpc()), - id=message.get("id") + id=message.get("id"), ) except Exception as e: # Unexpected errors logger.exception("Unexpected error handling message") - error = MCPError( - code=-32603, - message=f"Internal error: {str(e)}" - ) + error = MCPError(code=-32603, message=f"Internal error: {str(e)}") response = JSONRPCResponse( jsonrpc="2.0", error=JSONRPCError(**error.to_json_rpc()), - id=message.get("id") + id=message.get("id"), ) return response.model_dump(exclude_none=True) - async def handle_initialize(self, params: Dict[str, Any]) -> Dict[str, Any]: + async def handle_initialize(self, params: dict[str, Any]) -> dict[str, Any]: """Handle initialization handshake.""" try: init_params = InitializeParams(**params) except ValidationError as e: raise InvalidParams(f"Invalid initialize parameters: {str(e)}") - + # Initialize server state self.server._initialized = True - + # Build response result = InitializeResult( protocolVersion="0.1.0", # MCP protocol version @@ -110,50 +101,49 @@ async def handle_initialize(self, params: Dict[str, Any]) -> Dict[str, Any]: tools=True, resources=True, prompts=False, # Not implemented yet - logging=False # Not implemented yet + logging=False, # Not implemented yet ), serverInfo={ "name": "contextframe", "version": "0.1.0", - "description": "MCP server for ContextFrame datasets" - } + "description": "MCP server for ContextFrame datasets", + }, ) - + return result.model_dump() - async def handle_initialized(self, params: Dict[str, Any]) -> None: + async def handle_initialized(self, params: dict[str, Any]) -> None: """Handle initialized notification.""" # Client has confirmed initialization logger.info("MCP client initialized") return None # Notifications don't return results - async def handle_tools_list(self, params: Dict[str, Any]) -> Dict[str, Any]: + async def handle_tools_list(self, params: dict[str, Any]) -> dict[str, Any]: """List available tools.""" tools = self.server.tools.list_tools() return {"tools": [tool.model_dump() for tool in tools]} - async def handle_tool_call(self, params: Dict[str, Any]) -> Dict[str, Any]: + async def handle_tool_call(self, params: dict[str, Any]) -> dict[str, Any]: """Execute a tool.""" tool_params = ToolCallParams(**params) result = await self.server.tools.call_tool( - tool_params.name, - tool_params.arguments + tool_params.name, tool_params.arguments ) return result - async def handle_resources_list(self, params: Dict[str, Any]) -> Dict[str, Any]: + async def handle_resources_list(self, params: dict[str, Any]) -> dict[str, Any]: """List available resources.""" resources = self.server.resources.list_resources() return {"resources": [resource.model_dump() for resource in resources]} - async def handle_resource_read(self, params: Dict[str, Any]) -> Dict[str, Any]: + async def handle_resource_read(self, params: dict[str, Any]) -> dict[str, Any]: """Read a resource.""" resource_params = ResourceReadParams(**params) content = await self.server.resources.read_resource(resource_params.uri) return {"contents": [content]} - async def handle_shutdown(self, params: Dict[str, Any]) -> None: + async def handle_shutdown(self, params: dict[str, Any]) -> None: """Handle shutdown request.""" logger.info("Shutdown requested") self.server._shutdown_requested = True - return None # Shutdown is a notification \ No newline at end of file + return None # Shutdown is a notification diff --git a/contextframe/mcp/resources.py b/contextframe/mcp/resources.py index 3e48854..203da58 100644 --- a/contextframe/mcp/resources.py +++ b/contextframe/mcp/resources.py @@ -1,11 +1,10 @@ """Resource system for MCP server.""" import json -from typing import Any, Dict, List - from contextframe.frame import FrameDataset from contextframe.mcp.errors import InvalidParams from contextframe.mcp.schemas import Resource +from typing import Any, Dict, List class ResourceRegistry: @@ -15,50 +14,50 @@ def __init__(self, dataset: FrameDataset): self.dataset = dataset self._base_uri = "contextframe://" - def list_resources(self) -> List[Resource]: + def list_resources(self) -> list[Resource]: """List all available resources.""" resources = [ Resource( uri=f"{self._base_uri}dataset/info", name="Dataset Information", description="Dataset metadata, statistics, and configuration", - mimeType="application/json" + mimeType="application/json", ), Resource( uri=f"{self._base_uri}dataset/schema", name="Dataset Schema", description="Arrow schema information for the dataset", - mimeType="application/json" + mimeType="application/json", ), Resource( uri=f"{self._base_uri}dataset/stats", name="Dataset Statistics", description="Statistical information about the dataset", - mimeType="application/json" + mimeType="application/json", ), Resource( uri=f"{self._base_uri}collections", name="Document Collections", description="List of document collections in the dataset", - mimeType="application/json" + mimeType="application/json", ), Resource( uri=f"{self._base_uri}relationships", name="Document Relationships", description="Overview of document relationships in the dataset", - mimeType="application/json" - ) + mimeType="application/json", + ), ] - + return resources - async def read_resource(self, uri: str) -> Dict[str, Any]: + async def read_resource(self, uri: str) -> dict[str, Any]: """Read resource content by URI.""" if not uri.startswith(self._base_uri): raise InvalidParams(f"Invalid resource URI: {uri}") - - resource_path = uri[len(self._base_uri):] - + + resource_path = uri[len(self._base_uri) :] + if resource_path == "dataset/info": return await self._get_dataset_info() elif resource_path == "dataset/schema": @@ -72,77 +71,80 @@ async def read_resource(self, uri: str) -> Dict[str, Any]: else: raise InvalidParams(f"Unknown resource: {uri}") - async def _get_dataset_info(self) -> Dict[str, Any]: + async def _get_dataset_info(self) -> dict[str, Any]: """Get general dataset information.""" # Get dataset metadata try: # Get basic info from the dataset total_docs = self.dataset._dataset.count_rows() # Get total document count - + info = { "uri": f"{self._base_uri}dataset/info", "name": "Dataset Information", "mimeType": "application/json", - "text": json.dumps({ - "dataset_path": str(self.dataset._dataset.uri), # Lance dataset URI - "total_documents": total_docs, - "version": getattr(self.dataset._dataset, "version", "unknown"), - "storage_format": "lance", - "features": { - "vector_search": True, - "full_text_search": True, - "sql_filtering": True, - "relationships": True, - "collections": True - } - }, indent=2) + "text": json.dumps( + { + "dataset_path": str( + self.dataset._dataset.uri + ), # Lance dataset URI + "total_documents": total_docs, + "version": getattr(self.dataset._dataset, "version", "unknown"), + "storage_format": "lance", + "features": { + "vector_search": True, + "full_text_search": True, + "sql_filtering": True, + "relationships": True, + "collections": True, + }, + }, + indent=2, + ), } - + return info - + except Exception as e: return { "uri": f"{self._base_uri}dataset/info", "name": "Dataset Information", "mimeType": "application/json", - "text": json.dumps({"error": str(e)}, indent=2) + "text": json.dumps({"error": str(e)}, indent=2), } - async def _get_dataset_schema(self) -> Dict[str, Any]: + async def _get_dataset_schema(self) -> dict[str, Any]: """Get dataset schema information.""" try: # Get Arrow schema from the dataset schema = self.dataset._dataset.schema - + # Convert schema to dict representation - schema_dict = { - "fields": [] - } - + schema_dict = {"fields": []} + for field in schema: field_info = { "name": field.name, "type": str(field.type), - "nullable": field.nullable + "nullable": field.nullable, } schema_dict["fields"].append(field_info) - + return { "uri": f"{self._base_uri}dataset/schema", "name": "Dataset Schema", "mimeType": "application/json", - "text": json.dumps(schema_dict, indent=2) + "text": json.dumps(schema_dict, indent=2), } - + except Exception as e: return { "uri": f"{self._base_uri}dataset/schema", "name": "Dataset Schema", "mimeType": "application/json", - "text": json.dumps({"error": str(e)}, indent=2) + "text": json.dumps({"error": str(e)}, indent=2), } - async def _get_dataset_stats(self) -> Dict[str, Any]: + async def _get_dataset_stats(self) -> dict[str, Any]: """Get dataset statistics.""" try: # Gather statistics @@ -151,59 +153,65 @@ async def _get_dataset_stats(self) -> Dict[str, Any]: "collections": {}, "record_types": {}, "has_embeddings": 0, - "avg_content_length": 0 + "avg_content_length": 0, } - + # Sample documents for statistics sample = self.dataset.query("1=1", limit=1000) stats["document_count"] = len(sample) - + total_length = 0 for record in sample: # Count by collection collection = record.metadata.get("collection", "uncategorized") - stats["collections"][collection] = stats["collections"].get(collection, 0) + 1 - + stats["collections"][collection] = ( + stats["collections"].get(collection, 0) + 1 + ) + # Count by record type record_type = record.metadata.get("record_type", "document") - stats["record_types"][record_type] = stats["record_types"].get(record_type, 0) + 1 - + stats["record_types"][record_type] = ( + stats["record_types"].get(record_type, 0) + 1 + ) + # Check embeddings if record.embeddings is not None: stats["has_embeddings"] += 1 - + # Content length if record.content: total_length += len(record.content) - + if stats["document_count"] > 0: stats["avg_content_length"] = total_length / stats["document_count"] - stats["embedding_coverage"] = f"{(stats['has_embeddings'] / stats['document_count']) * 100:.1f}%" - + stats["embedding_coverage"] = ( + f"{(stats['has_embeddings'] / stats['document_count']) * 100:.1f}%" + ) + return { "uri": f"{self._base_uri}dataset/stats", "name": "Dataset Statistics", "mimeType": "application/json", - "text": json.dumps(stats, indent=2) + "text": json.dumps(stats, indent=2), } - + except Exception as e: return { "uri": f"{self._base_uri}dataset/stats", "name": "Dataset Statistics", "mimeType": "application/json", - "text": json.dumps({"error": str(e)}, indent=2) + "text": json.dumps({"error": str(e)}, indent=2), } - async def _get_collections(self) -> Dict[str, Any]: + async def _get_collections(self) -> dict[str, Any]: """Get information about document collections.""" try: # Find all unique collections collections = {} - + # Sample documents to find collections sample = self.dataset.query("1=1", limit=10000) - + for record in sample: collection = record.metadata.get("collection") if collection: @@ -211,34 +219,41 @@ async def _get_collections(self) -> Dict[str, Any]: collections[collection] = { "name": collection, "document_count": 0, - "has_header": False + "has_header": False, } collections[collection]["document_count"] += 1 - + # Check if it's a collection header if record.metadata.get("record_type") == "collection_header": collections[collection]["has_header"] = True - collections[collection]["description"] = record.content[:200] + "..." if len(record.content) > 200 else record.content - + collections[collection]["description"] = ( + record.content[:200] + "..." + if len(record.content) > 200 + else record.content + ) + return { "uri": f"{self._base_uri}collections", "name": "Document Collections", "mimeType": "application/json", - "text": json.dumps({ - "total_collections": len(collections), - "collections": list(collections.values()) - }, indent=2) + "text": json.dumps( + { + "total_collections": len(collections), + "collections": list(collections.values()), + }, + indent=2, + ), } - + except Exception as e: return { "uri": f"{self._base_uri}collections", "name": "Document Collections", "mimeType": "application/json", - "text": json.dumps({"error": str(e)}, indent=2) + "text": json.dumps({"error": str(e)}, indent=2), } - async def _get_relationships(self) -> Dict[str, Any]: + async def _get_relationships(self) -> dict[str, Any]: """Get information about document relationships.""" try: # Find relationships in metadata @@ -247,12 +262,12 @@ async def _get_relationships(self) -> Dict[str, Any]: "related": 0, "references": 0, "member_of": 0, - "total": 0 + "total": 0, } - + # Sample documents to find relationships sample = self.dataset.query("1=1", limit=10000) - + for record in sample: if "relationships" in record.metadata: for rel in record.metadata["relationships"]: @@ -260,21 +275,24 @@ async def _get_relationships(self) -> Dict[str, Any]: if rel_type in relationships: relationships[rel_type] += 1 relationships["total"] += 1 - + return { "uri": f"{self._base_uri}relationships", "name": "Document Relationships", "mimeType": "application/json", - "text": json.dumps({ - "relationship_counts": relationships, - "has_relationships": relationships["total"] > 0 - }, indent=2) + "text": json.dumps( + { + "relationship_counts": relationships, + "has_relationships": relationships["total"] > 0, + }, + indent=2, + ), } - + except Exception as e: return { "uri": f"{self._base_uri}relationships", "name": "Document Relationships", "mimeType": "application/json", - "text": json.dumps({"error": str(e)}, indent=2) - } \ No newline at end of file + "text": json.dumps({"error": str(e)}, indent=2), + } diff --git a/contextframe/mcp/schemas.py b/contextframe/mcp/schemas.py index 7be8626..399b821 100644 --- a/contextframe/mcp/schemas.py +++ b/contextframe/mcp/schemas.py @@ -1,116 +1,116 @@ """Pydantic schemas for MCP protocol messages and data structures.""" +from pydantic import BaseModel, ConfigDict, Field from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, ConfigDict # JSON-RPC 2.0 schemas class JSONRPCRequest(BaseModel): """JSON-RPC 2.0 request.""" - + jsonrpc: Literal["2.0"] = "2.0" method: str - params: Optional[Dict[str, Any]] = None - id: Optional[Union[str, int]] = None + params: dict[str, Any] | None = None + id: str | int | None = None class JSONRPCError(BaseModel): """JSON-RPC 2.0 error object.""" - + code: int message: str - data: Optional[Any] = None + data: Any | None = None class JSONRPCResponse(BaseModel): """JSON-RPC 2.0 response.""" - + jsonrpc: Literal["2.0"] = "2.0" - result: Optional[Any] = None - error: Optional[JSONRPCError] = None - id: Optional[Union[str, int]] = None + result: Any | None = None + error: JSONRPCError | None = None + id: str | int | None = None # MCP protocol schemas class MCPCapabilities(BaseModel): """Server capabilities.""" - - tools: Optional[bool] = None - resources: Optional[bool] = None - prompts: Optional[bool] = None - logging: Optional[bool] = None + + tools: bool | None = None + resources: bool | None = None + prompts: bool | None = None + logging: bool | None = None class InitializeParams(BaseModel): """Parameters for initialize method.""" - + protocolVersion: str capabilities: MCPCapabilities - clientInfo: Optional[Dict[str, Any]] = None + clientInfo: dict[str, Any] | None = None class InitializeResult(BaseModel): """Result of initialize method.""" - + protocolVersion: str capabilities: MCPCapabilities - serverInfo: Dict[str, Any] + serverInfo: dict[str, Any] class Tool(BaseModel): """Tool definition.""" - + name: str description: str - inputSchema: Dict[str, Any] + inputSchema: dict[str, Any] class ToolCallParams(BaseModel): """Parameters for tools/call method.""" - + name: str - arguments: Dict[str, Any] = Field(default_factory=dict) + arguments: dict[str, Any] = Field(default_factory=dict) class Resource(BaseModel): """Resource definition.""" - + uri: str name: str - description: Optional[str] = None - mimeType: Optional[str] = None + description: str | None = None + mimeType: str | None = None class ResourceReadParams(BaseModel): """Parameters for resources/read method.""" - + uri: str # ContextFrame-specific schemas class SearchDocumentsParams(BaseModel): """Parameters for search_documents tool.""" - + query: str search_type: Literal["vector", "text", "hybrid"] = "hybrid" limit: int = Field(default=10, ge=1, le=1000) - filter: Optional[str] = None + filter: str | None = None class AddDocumentParams(BaseModel): """Parameters for add_document tool.""" - + content: str - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) generate_embedding: bool = True - collection: Optional[str] = None - chunk_size: Optional[int] = Field(default=None, ge=100, le=10000) - chunk_overlap: Optional[int] = Field(default=None, ge=0, le=1000) + collection: str | None = None + chunk_size: int | None = Field(default=None, ge=100, le=10000) + chunk_overlap: int | None = Field(default=None, ge=0, le=1000) class GetDocumentParams(BaseModel): """Parameters for get_document tool.""" - + document_id: str include_content: bool = True include_metadata: bool = True @@ -119,54 +119,54 @@ class GetDocumentParams(BaseModel): class ListDocumentsParams(BaseModel): """Parameters for list_documents tool.""" - + limit: int = Field(default=100, ge=1, le=1000) offset: int = Field(default=0, ge=0) - filter: Optional[str] = None - order_by: Optional[str] = None + filter: str | None = None + order_by: str | None = None include_content: bool = False class UpdateDocumentParams(BaseModel): """Parameters for update_document tool.""" - + document_id: str - content: Optional[str] = None - metadata: Optional[Dict[str, Any]] = None + content: str | None = None + metadata: dict[str, Any] | None = None regenerate_embedding: bool = False class DeleteDocumentParams(BaseModel): """Parameters for delete_document tool.""" - + document_id: str # Response schemas class DocumentResult(BaseModel): """Result of a document operation.""" - + model_config = ConfigDict(extra='allow') - + uuid: str - content: Optional[str] = None - metadata: Dict[str, Any] = Field(default_factory=dict) - embedding: Optional[List[float]] = None - score: Optional[float] = None # For search results + content: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + embedding: list[float] | None = None + score: float | None = None # For search results class SearchResult(BaseModel): """Result of a search operation.""" - - documents: List[DocumentResult] + + documents: list[DocumentResult] total_count: int search_type_used: str class ListResult(BaseModel): """Result of a list operation.""" - - documents: List[DocumentResult] + + documents: list[DocumentResult] total_count: int offset: int limit: int @@ -175,104 +175,104 @@ class ListResult(BaseModel): # Batch operation schemas class BatchSearchQuery(BaseModel): """Individual search query for batch search.""" - + query: str search_type: Literal["vector", "text", "hybrid"] = "hybrid" limit: int = Field(default=10, ge=1, le=100) - filter: Optional[str] = None + filter: str | None = None class BatchSearchParams(BaseModel): """Execute multiple document searches in parallel.""" - - queries: List[BatchSearchQuery] + + queries: list[BatchSearchQuery] max_parallel: int = Field(default=5, ge=1, le=20) class BatchDocument(BaseModel): """Document for batch operations.""" - + content: str - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) class SharedSettings(BaseModel): """Shared settings for batch operations.""" - + generate_embeddings: bool = True - collection: Optional[str] = None - chunk_size: Optional[int] = None - chunk_overlap: Optional[int] = None - metadata: Dict[str, Any] = Field(default_factory=dict) + collection: str | None = None + chunk_size: int | None = None + chunk_overlap: int | None = None + metadata: dict[str, Any] = Field(default_factory=dict) class BatchAddParams(BaseModel): """Add multiple documents efficiently.""" - - documents: List[BatchDocument] + + documents: list[BatchDocument] shared_settings: SharedSettings = Field(default_factory=SharedSettings) atomic: bool = Field(default=True, description="Rollback all on any failure") class UpdateSpec(BaseModel): """Specification for batch updates.""" - - metadata_updates: Optional[Dict[str, Any]] = None - content_template: Optional[str] = None + + metadata_updates: dict[str, Any] | None = None + content_template: str | None = None regenerate_embeddings: bool = False class BatchUpdateParams(BaseModel): """Update multiple documents matching criteria.""" - - filter: Optional[str] = None - document_ids: Optional[List[str]] = None + + filter: str | None = None + document_ids: list[str] | None = None updates: UpdateSpec max_documents: int = Field(default=1000, ge=1, le=10000) class BatchDeleteParams(BaseModel): """Delete multiple documents with confirmation.""" - - filter: Optional[str] = None - document_ids: Optional[List[str]] = None + + filter: str | None = None + document_ids: list[str] | None = None dry_run: bool = Field(default=True, description="Preview what would be deleted") - confirm_count: Optional[int] = Field(None, description="Expected number of deletions") + confirm_count: int | None = Field(None, description="Expected number of deletions") class BatchEnhanceParams(BaseModel): """Enhance multiple documents with LLM.""" - - document_ids: Optional[List[str]] = None - filter: Optional[str] = None - enhancements: List[Literal["context", "tags", "title", "metadata"]] - purpose: Optional[str] = None + + document_ids: list[str] | None = None + filter: str | None = None + enhancements: list[Literal["context", "tags", "title", "metadata"]] + purpose: str | None = None batch_size: int = Field(default=10, ge=1, le=50) class SourceSpec(BaseModel): """Source specification for batch extract.""" - - path: Optional[str] = None - url: Optional[str] = None + + path: str | None = None + url: str | None = None type: Literal["file", "url"] class BatchExtractParams(BaseModel): """Extract from multiple files/URLs.""" - - sources: List[SourceSpec] + + sources: list[SourceSpec] add_to_dataset: bool = True - shared_metadata: Dict[str, Any] = Field(default_factory=dict) - collection: Optional[str] = None + shared_metadata: dict[str, Any] = Field(default_factory=dict) + collection: str | None = None continue_on_error: bool = True class BatchExportParams(BaseModel): """Export documents in bulk.""" - - filter: Optional[str] = None - document_ids: Optional[List[str]] = None + + filter: str | None = None + document_ids: list[str] | None = None format: Literal["json", "jsonl", "csv", "parquet"] include_embeddings: bool = False output_path: str @@ -281,40 +281,52 @@ class BatchExportParams(BaseModel): class BatchImportParams(BaseModel): """Import documents from files.""" - + source_path: str format: Literal["json", "jsonl", "csv", "parquet"] - mapping: Optional[Dict[str, str]] = None - validation: Dict[str, Any] = Field(default_factory=dict) + mapping: dict[str, str] | None = None + validation: dict[str, Any] = Field(default_factory=dict) generate_embeddings: bool = True # Collection management schemas class CreateCollectionParams(BaseModel): """Create a new collection with metadata and optional template.""" - + name: str = Field(..., description="Collection name") - description: Optional[str] = Field(None, description="Collection description") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Collection metadata") - parent_collection: Optional[str] = Field(None, description="Parent collection ID for hierarchies") - template: Optional[str] = Field(None, description="Template name to apply") - initial_members: List[str] = Field(default_factory=list, description="Document IDs to add") + description: str | None = Field(None, description="Collection description") + metadata: dict[str, Any] = Field( + default_factory=dict, description="Collection metadata" + ) + parent_collection: str | None = Field( + None, description="Parent collection ID for hierarchies" + ) + template: str | None = Field(None, description="Template name to apply") + initial_members: list[str] = Field( + default_factory=list, description="Document IDs to add" + ) class UpdateCollectionParams(BaseModel): """Update collection properties and membership.""" - + collection_id: str = Field(..., description="Collection ID to update") - name: Optional[str] = Field(None, description="New name") - description: Optional[str] = Field(None, description="New description") - metadata_updates: Optional[Dict[str, Any]] = Field(None, description="Metadata to update") - add_members: List[str] = Field(default_factory=list, description="Document IDs to add") - remove_members: List[str] = Field(default_factory=list, description="Document IDs to remove") + name: str | None = Field(None, description="New name") + description: str | None = Field(None, description="New description") + metadata_updates: dict[str, Any] | None = Field( + None, description="Metadata to update" + ) + add_members: list[str] = Field( + default_factory=list, description="Document IDs to add" + ) + remove_members: list[str] = Field( + default_factory=list, description="Document IDs to remove" + ) class DeleteCollectionParams(BaseModel): """Delete a collection and optionally its members.""" - + collection_id: str = Field(..., description="Collection ID to delete") delete_members: bool = Field(False, description="Also delete member documents") recursive: bool = Field(False, description="Delete sub-collections recursively") @@ -322,8 +334,8 @@ class DeleteCollectionParams(BaseModel): class ListCollectionsParams(BaseModel): """List collections with filtering and statistics.""" - - parent_id: Optional[str] = Field(None, description="Filter by parent collection") + + parent_id: str | None = Field(None, description="Filter by parent collection") include_stats: bool = Field(True, description="Include member statistics") include_empty: bool = Field(True, description="Include collections with no members") sort_by: Literal["name", "created_at", "member_count"] = Field("name") @@ -333,145 +345,147 @@ class ListCollectionsParams(BaseModel): class MoveDocumentsParams(BaseModel): """Move documents between collections.""" - - document_ids: List[str] = Field(..., description="Documents to move") - source_collection: Optional[str] = Field(None, description="Source collection (None for uncollected)") - target_collection: Optional[str] = Field(None, description="Target collection (None to remove)") + + document_ids: list[str] = Field(..., description="Documents to move") + source_collection: str | None = Field( + None, description="Source collection (None for uncollected)" + ) + target_collection: str | None = Field( + None, description="Target collection (None to remove)" + ) update_metadata: bool = Field(True, description="Apply target collection metadata") class GetCollectionStatsParams(BaseModel): """Get detailed statistics for a collection.""" - + collection_id: str = Field(..., description="Collection ID") - include_member_details: bool = Field(False, description="Include per-member statistics") - include_subcollections: bool = Field(True, description="Include sub-collection stats") + include_member_details: bool = Field( + False, description="Include per-member statistics" + ) + include_subcollections: bool = Field( + True, description="Include sub-collection stats" + ) # Collection response schemas class CollectionInfo(BaseModel): """Information about a collection.""" - + collection_id: str header_id: str name: str - description: Optional[str] = None - parent_id: Optional[str] = None + description: str | None = None + parent_id: str | None = None created_at: str updated_at: str - metadata: Dict[str, Any] = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) member_count: int = 0 - total_size_bytes: Optional[int] = None + total_size_bytes: int | None = None class CollectionStats(BaseModel): """Detailed statistics for a collection.""" - + total_members: int direct_members: int subcollection_members: int total_size_bytes: int avg_document_size: float - unique_tags: List[str] - date_range: Dict[str, str] - member_types: Dict[str, int] + unique_tags: list[str] + date_range: dict[str, str] + member_types: dict[str, int] class CollectionResult(BaseModel): """Result of a collection operation.""" - + collection: CollectionInfo - statistics: Optional[CollectionStats] = None - subcollections: List[CollectionInfo] = Field(default_factory=list) - members: List[DocumentResult] = Field(default_factory=list) + statistics: CollectionStats | None = None + subcollections: list[CollectionInfo] = Field(default_factory=list) + members: list[DocumentResult] = Field(default_factory=list) # Subscription schemas class SubscribeChangesParams(BaseModel): """Create a subscription to monitor dataset changes.""" - + resource_type: Literal["documents", "collections", "all"] = Field( - "all", - description="Type of resources to monitor" + "all", description="Type of resources to monitor" ) - filters: Optional[Dict[str, Any]] = Field( - None, - description="Optional filters (e.g., {'collection_id': '...'})" + filters: dict[str, Any] | None = Field( + None, description="Optional filters (e.g., {'collection_id': '...'})" ) - options: Dict[str, Any] = Field( + options: dict[str, Any] = Field( default_factory=lambda: { "polling_interval": 5, "include_data": False, - "batch_size": 100 + "batch_size": 100, }, - description="Subscription options" + description="Subscription options", ) class PollChangesParams(BaseModel): """Poll for changes since the last poll.""" - + subscription_id: str = Field(..., description="Active subscription ID") - poll_token: Optional[str] = Field(None, description="Token from last poll") + poll_token: str | None = Field(None, description="Token from last poll") timeout: int = Field( - 30, - ge=0, - le=300, - description="Max seconds to wait for changes (long polling)" + 30, ge=0, le=300, description="Max seconds to wait for changes (long polling)" ) class UnsubscribeParams(BaseModel): """Cancel an active subscription.""" - + subscription_id: str = Field(..., description="Subscription to cancel") class GetSubscriptionsParams(BaseModel): """Get list of active subscriptions.""" - - resource_type: Optional[Literal["documents", "collections", "all"]] = Field( - None, - description="Filter by resource type" + + resource_type: Literal["documents", "collections", "all"] | None = Field( + None, description="Filter by resource type" ) # Subscription response schemas class ChangeEvent(BaseModel): """Change event in the dataset.""" - + type: Literal["created", "updated", "deleted"] resource_type: Literal["document", "collection"] resource_id: str version: int timestamp: str - old_data: Optional[Dict[str, Any]] = None - new_data: Optional[Dict[str, Any]] = None + old_data: dict[str, Any] | None = None + new_data: dict[str, Any] | None = None class SubscriptionInfo(BaseModel): """Information about an active subscription.""" - + subscription_id: str resource_type: str - filters: Optional[Dict[str, Any]] + filters: dict[str, Any] | None created_at: str - last_poll: Optional[str] - options: Dict[str, Any] - - + last_poll: str | None + options: dict[str, Any] + + class SubscribeResult(BaseModel): """Result of creating a subscription.""" - + subscription_id: str poll_token: str polling_interval: int - - + + class PollResult(BaseModel): """Result of polling for changes.""" - - changes: List[ChangeEvent] + + changes: list[ChangeEvent] poll_token: str has_more: bool subscription_active: bool @@ -479,13 +493,13 @@ class PollResult(BaseModel): class UnsubscribeResult(BaseModel): """Result of cancelling a subscription.""" - + cancelled: bool - final_poll_token: Optional[str] - - + final_poll_token: str | None + + class GetSubscriptionsResult(BaseModel): """Result of listing subscriptions.""" - - subscriptions: List[SubscriptionInfo] - total_count: int \ No newline at end of file + + subscriptions: list[SubscriptionInfo] + total_count: int diff --git a/contextframe/mcp/server.py b/contextframe/mcp/server.py index 20ae8b7..d492f18 100644 --- a/contextframe/mcp/server.py +++ b/contextframe/mcp/server.py @@ -3,17 +3,15 @@ import asyncio import logging import signal -from typing import Optional -from dataclasses import dataclass - from contextframe.frame import FrameDataset +from contextframe.mcp.core.transport import TransportAdapter +from contextframe.mcp.errors import DatasetNotFound from contextframe.mcp.handlers import MessageHandler -from contextframe.mcp.tools import ToolRegistry from contextframe.mcp.resources import ResourceRegistry -from contextframe.mcp.errors import DatasetNotFound -from contextframe.mcp.core.transport import TransportAdapter +from contextframe.mcp.tools import ToolRegistry from contextframe.mcp.transports.stdio import StdioAdapter - +from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional logger = logging.getLogger(__name__) @@ -21,46 +19,54 @@ @dataclass class MCPConfig: """Configuration for MCP server.""" - + server_name: str = "contextframe" server_version: str = "0.1.0" protocol_version: str = "0.1.0" max_message_size: int = 10 * 1024 * 1024 # 10MB shutdown_timeout: float = 5.0 + # Transport configuration + transport: Literal["stdio", "http", "both"] = "stdio" + + # HTTP-specific configuration + http_host: str = "0.0.0.0" + http_port: int = 8080 + http_cors_origins: list[str] = None + http_auth_enabled: bool = False + http_rate_limit: dict[str, int] = None + http_ssl_cert: str | None = None + http_ssl_key: str | None = None + class ContextFrameMCPServer: """MCP server for ContextFrame datasets. - + Provides standardized access to ContextFrame datasets through the Model Context Protocol, enabling LLMs and AI agents to interact with document collections. """ - def __init__( - self, - dataset_path: str, - config: Optional[MCPConfig] = None - ): + def __init__(self, dataset_path: str, config: MCPConfig | None = None): """Initialize MCP server. - + Args: dataset_path: Path to Lance dataset config: Server configuration """ self.dataset_path = dataset_path self.config = config or MCPConfig() - + # Server state self._initialized = False self._shutdown_requested = False - + # Components (initialized in setup) - self.dataset: Optional[FrameDataset] = None - self.transport: Optional[TransportAdapter] = None - self.handler: Optional[MessageHandler] = None - self.tools: Optional[ToolRegistry] = None - self.resources: Optional[ResourceRegistry] = None + self.dataset: FrameDataset | None = None + self.transport: TransportAdapter | None = None + self.handler: MessageHandler | None = None + self.tools: ToolRegistry | None = None + self.resources: ResourceRegistry | None = None async def setup(self): """Set up server components.""" @@ -69,80 +75,164 @@ async def setup(self): self.dataset = FrameDataset.open(self.dataset_path) except Exception as e: raise DatasetNotFound(self.dataset_path) from e - - # Initialize components - self.transport = StdioAdapter() - self.handler = MessageHandler(self) + + # Initialize transport based on configuration + if self.config.transport == "stdio": + self.transport = StdioAdapter() + self.handler = MessageHandler(self) + self.tools = ToolRegistry(self.dataset, self.transport) + self.resources = ResourceRegistry(self.dataset) + await self.transport.initialize() + elif self.config.transport == "http": + await self._setup_http_transport() + elif self.config.transport == "both": + # For "both", we'll run HTTP server with stdio fallback + await self._setup_http_transport() + + logger.info( + f"MCP server initialized for dataset: {self.dataset_path} with {self.config.transport} transport" + ) + + async def _setup_http_transport(self): + """Set up HTTP transport and server.""" + from contextframe.mcp.transports.http import create_http_server + from contextframe.mcp.transports.http.config import HTTPTransportConfig + + # Create HTTP config from MCP config + http_config = HTTPTransportConfig( + host=self.config.http_host, + port=self.config.http_port, + cors_origins=self.config.http_cors_origins or ["*"], + auth_enabled=self.config.http_auth_enabled, + rate_limit_requests_per_minute=self.config.http_rate_limit.get( + "requests_per_minute", 60 + ) + if self.config.http_rate_limit + else 60, + rate_limit_burst=self.config.http_rate_limit.get("burst", 10) + if self.config.http_rate_limit + else 10, + ssl_cert=self.config.http_ssl_cert, + ssl_key=self.config.http_ssl_key, + ssl_enabled=bool(self.config.http_ssl_cert and self.config.http_ssl_key), + ) + + self.http_server = await create_http_server( + self.dataset_path, + config=http_config, + ) + + # For compatibility, set transport to the HTTP adapter + self.transport = self.http_server.adapter + self.handler = self.http_server.handler self.tools = ToolRegistry(self.dataset, self.transport) self.resources = ResourceRegistry(self.dataset) - - # Initialize transport - await self.transport.initialize() - - logger.info(f"MCP server initialized for dataset: {self.dataset_path}") async def run(self): """Main server loop.""" - if not self.transport: + if not self.transport and not hasattr(self, 'http_server'): await self.setup() - + # Set up signal handlers loop = asyncio.get_event_loop() for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler( - sig, - lambda: asyncio.create_task(self.shutdown()) + loop.add_signal_handler(sig, lambda: asyncio.create_task(self.shutdown())) + + if self.config.transport == "stdio": + logger.info( + "MCP server running with stdio transport, waiting for messages..." + ) + await self._run_stdio() + elif self.config.transport == "http" or self.config.transport == "both": + logger.info( + f"MCP server running with HTTP transport on {self.config.http_host}:{self.config.http_port}" ) - - logger.info("MCP server running, waiting for messages...") - + await self._run_http() + + async def _run_stdio(self): + """Run stdio transport loop.""" try: # Process messages while not self._shutdown_requested: message = await self.transport.receive_message() if message is None: break - + try: response = await self.handler.handle(message) if response: # Don't send response for notifications await self.transport.send_message(response) - except Exception as e: + except Exception: logger.exception("Error handling message") # Error response already sent by handler - + except KeyboardInterrupt: logger.info("Keyboard interrupt received") - except Exception as e: + except Exception: logger.exception("Server error") raise finally: await self.cleanup() + async def _run_http(self): + """Run HTTP server.""" + import uvicorn + from contextframe.mcp.transports.http.security import SecurityConfig + + try: + # Get SSL config if enabled + ssl_config = {} + if self.config.http_ssl_cert and self.config.http_ssl_key: + ssl_config = { + "ssl_keyfile": self.config.http_ssl_key, + "ssl_certfile": self.config.http_ssl_cert, + } + + # Run uvicorn server + config = uvicorn.Config( + app=self.http_server.app, + host=self.config.http_host, + port=self.config.http_port, + log_level="info", + **ssl_config, + ) + server = uvicorn.Server(config) + + # Run server with proper shutdown handling + await server.serve() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + except Exception: + logger.exception("HTTP server error") + raise + finally: + await self.cleanup() + async def shutdown(self): """Graceful shutdown.""" logger.info("Shutdown requested") self._shutdown_requested = True - + # Give ongoing operations time to complete await asyncio.sleep(0.1) async def cleanup(self): """Clean up resources.""" logger.info("Cleaning up server resources") - + if self.transport: await self.transport.shutdown() - + # Dataset cleanup if needed if self.dataset: # FrameDataset doesn't require explicit cleanup pass - + logger.info("Server cleanup complete") @classmethod - async def start(cls, dataset_path: str, config: Optional[MCPConfig] = None): + async def start(cls, dataset_path: str, config: MCPConfig | None = None): """Convenience method to start server.""" server = cls(dataset_path, config) await server.run() @@ -151,41 +241,67 @@ async def start(cls, dataset_path: str, config: Optional[MCPConfig] = None): # Entry point for running as module async def main(): """Main entry point when running as module.""" - import sys import argparse - - parser = argparse.ArgumentParser( - description="ContextFrame MCP Server" - ) - parser.add_argument( - "dataset", - help="Path to Lance dataset" - ) + import sys + + parser = argparse.ArgumentParser(description="ContextFrame MCP Server") + parser.add_argument("dataset", help="Path to Lance dataset") parser.add_argument( "--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], - help="Logging level" + help="Logging level", + ) + parser.add_argument( + "--transport", + default="stdio", + choices=["stdio", "http", "both"], + help="Transport type (default: stdio)", + ) + parser.add_argument( + "--host", default="0.0.0.0", help="HTTP server host (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=8080, help="HTTP server port (default: 8080)" + ) + parser.add_argument( + "--cors-origins", nargs="*", help="CORS allowed origins (default: *)" ) - + parser.add_argument( + "--auth", action="store_true", help="Enable OAuth 2.1 authentication" + ) + parser.add_argument("--ssl-cert", help="Path to SSL certificate file") + parser.add_argument("--ssl-key", help="Path to SSL key file") + args = parser.parse_args() - + # Configure logging logging.basicConfig( level=getattr(logging, args.log_level), format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + handlers=[logging.StreamHandler()], ) - + # Reduce noise from other loggers logging.getLogger("contextframe.frame").setLevel(logging.WARNING) - + + # Create configuration + config = MCPConfig( + transport=args.transport, + http_host=args.host, + http_port=args.port, + http_cors_origins=args.cors_origins, + http_auth_enabled=args.auth, + http_ssl_cert=args.ssl_cert, + http_ssl_key=args.ssl_key, + ) + try: - await ContextFrameMCPServer.start(args.dataset) + await ContextFrameMCPServer.start(args.dataset, config) except Exception as e: logger.error(f"Server failed: {e}") sys.exit(1) if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/contextframe/mcp/subscriptions/__init__.py b/contextframe/mcp/subscriptions/__init__.py index e25dfed..8597c8c 100644 --- a/contextframe/mcp/subscriptions/__init__.py +++ b/contextframe/mcp/subscriptions/__init__.py @@ -1,17 +1,12 @@ """Subscription system for monitoring dataset changes.""" from .manager import SubscriptionManager -from .tools import ( - subscribe_changes, - poll_changes, - unsubscribe, - get_subscriptions -) +from .tools import get_subscriptions, poll_changes, subscribe_changes, unsubscribe __all__ = [ "SubscriptionManager", "subscribe_changes", "poll_changes", "unsubscribe", - "get_subscriptions" -] \ No newline at end of file + "get_subscriptions", +] diff --git a/contextframe/mcp/subscriptions/manager.py b/contextframe/mcp/subscriptions/manager.py index f26ab2c..e18d212 100644 --- a/contextframe/mcp/subscriptions/manager.py +++ b/contextframe/mcp/subscriptions/manager.py @@ -1,68 +1,67 @@ """Subscription manager for tracking dataset changes.""" import asyncio +from contextframe import FrameDataset from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime, timezone from typing import Any, Dict, List, Optional, Set from uuid import uuid4 -from contextframe import FrameDataset - @dataclass class SubscriptionState: """State tracking for a subscription.""" - + id: str resource_type: str # "documents", "collections", "all" - filters: Dict[str, Any] + filters: dict[str, Any] created_at: datetime last_version: int last_poll_token: str - last_poll_time: Optional[datetime] = None - change_buffer: List["Change"] = field(default_factory=list) - options: Dict[str, Any] = field(default_factory=dict) + last_poll_time: datetime | None = None + change_buffer: list["Change"] = field(default_factory=list) + options: dict[str, Any] = field(default_factory=dict) is_active: bool = True @dataclass class Change: """Represents a change in the dataset.""" - + type: str # "created", "updated", "deleted" resource_type: str # "document", "collection" resource_id: str version: int timestamp: datetime - old_data: Optional[Dict[str, Any]] = None - new_data: Optional[Dict[str, Any]] = None + old_data: dict[str, Any] | None = None + new_data: dict[str, Any] | None = None class SubscriptionManager: """Manages subscriptions for dataset change monitoring.""" - + def __init__(self, dataset: FrameDataset): """Initialize subscription manager. - + Args: dataset: The FrameDataset to monitor """ self.dataset = dataset - self.subscriptions: Dict[str, SubscriptionState] = {} - self._polling_task: Optional[asyncio.Task] = None + self.subscriptions: dict[str, SubscriptionState] = {} + self._polling_task: asyncio.Task | None = None self._change_queue: asyncio.Queue = asyncio.Queue() - self._last_check_version: Optional[int] = None + self._last_check_version: int | None = None self._running = False - + async def start(self): """Start the subscription manager polling.""" if self._running: return - + self._running = True self._last_check_version = self.dataset.version self._polling_task = asyncio.create_task(self._poll_changes()) - + async def stop(self): """Stop the subscription manager.""" self._running = False @@ -72,61 +71,55 @@ async def stop(self): await self._polling_task except asyncio.CancelledError: pass - + async def create_subscription( self, resource_type: str, - filters: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None + filters: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, ) -> str: """Create a new subscription. - + Args: resource_type: Type of resources to monitor filters: Optional filters for the subscription options: Subscription options (polling_interval, include_data, etc.) - + Returns: Subscription ID """ subscription_id = str(uuid4()) poll_token = f"{subscription_id}:0" - + subscription = SubscriptionState( id=subscription_id, resource_type=resource_type, filters=filters or {}, - created_at=datetime.now(timezone.utc), + created_at=datetime.now(UTC), last_version=self.dataset.version, last_poll_token=poll_token, - options=options or { - "polling_interval": 5, - "include_data": False, - "batch_size": 100 - } + options=options + or {"polling_interval": 5, "include_data": False, "batch_size": 100}, ) - + self.subscriptions[subscription_id] = subscription - + # Ensure polling is running if not self._running: await self.start() - + return subscription_id - + async def poll_subscription( - self, - subscription_id: str, - poll_token: Optional[str] = None, - timeout: int = 30 - ) -> Dict[str, Any]: + self, subscription_id: str, poll_token: str | None = None, timeout: int = 30 + ) -> dict[str, Any]: """Poll for changes in a subscription. - + Args: subscription_id: The subscription to poll poll_token: Token from last poll (for ordering) timeout: Max seconds to wait for changes - + Returns: Dict with changes, new poll token, and status """ @@ -135,50 +128,49 @@ async def poll_subscription( "changes": [], "poll_token": None, "has_more": False, - "subscription_active": False + "subscription_active": False, } - + subscription = self.subscriptions[subscription_id] - + if not subscription.is_active: return { "changes": [], "poll_token": subscription.last_poll_token, "has_more": False, - "subscription_active": False + "subscription_active": False, } - + # Update last poll time - subscription.last_poll_time = datetime.now(timezone.utc) - + subscription.last_poll_time = datetime.now(UTC) + # Check for buffered changes changes = [] if subscription.change_buffer: batch_size = subscription.options.get("batch_size", 100) changes = subscription.change_buffer[:batch_size] subscription.change_buffer = subscription.change_buffer[batch_size:] - + # If no buffered changes, wait for new ones (with timeout) if not changes and timeout > 0: try: # Wait for changes with timeout await asyncio.wait_for( - self._wait_for_changes(subscription_id), - timeout=timeout + self._wait_for_changes(subscription_id), timeout=timeout ) # Check buffer again if subscription.change_buffer: batch_size = subscription.options.get("batch_size", 100) changes = subscription.change_buffer[:batch_size] subscription.change_buffer = subscription.change_buffer[batch_size:] - except asyncio.TimeoutError: + except TimeoutError: pass # No changes within timeout - + # Update poll token new_version = changes[-1].version if changes else subscription.last_version new_poll_token = f"{subscription_id}:{new_version}" subscription.last_poll_token = new_poll_token - + # Convert changes to dict format change_dicts = [] for change in changes: @@ -187,31 +179,31 @@ async def poll_subscription( "resource_type": change.resource_type, "resource_id": change.resource_id, "version": change.version, - "timestamp": change.timestamp.isoformat() + "timestamp": change.timestamp.isoformat(), } - + # Include data if requested if subscription.options.get("include_data", False): if change.old_data: change_dict["old_data"] = change.old_data if change.new_data: change_dict["new_data"] = change.new_data - + change_dicts.append(change_dict) - + return { "changes": change_dicts, "poll_token": new_poll_token, "has_more": len(subscription.change_buffer) > 0, - "subscription_active": subscription.is_active + "subscription_active": subscription.is_active, } - + async def cancel_subscription(self, subscription_id: str) -> bool: """Cancel a subscription. - + Args: subscription_id: The subscription to cancel - + Returns: Whether the subscription was cancelled """ @@ -220,242 +212,250 @@ async def cancel_subscription(self, subscription_id: str) -> bool: # Keep subscription for final poll return True return False - + def get_subscriptions( - self, - resource_type: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, resource_type: str | None = None + ) -> list[dict[str, Any]]: """Get list of active subscriptions. - + Args: resource_type: Optional filter by resource type - + Returns: List of subscription info """ subscriptions = [] - + for sub in self.subscriptions.values(): if not sub.is_active: continue - + if resource_type and sub.resource_type != resource_type: continue - - subscriptions.append({ - "subscription_id": sub.id, - "resource_type": sub.resource_type, - "filters": sub.filters, - "created_at": sub.created_at.isoformat(), - "last_poll": sub.last_poll_time.isoformat() if sub.last_poll_time else None, - "options": sub.options - }) - + + subscriptions.append( + { + "subscription_id": sub.id, + "resource_type": sub.resource_type, + "filters": sub.filters, + "created_at": sub.created_at.isoformat(), + "last_poll": sub.last_poll_time.isoformat() + if sub.last_poll_time + else None, + "options": sub.options, + } + ) + return subscriptions - + async def _poll_changes(self): """Background task to poll for dataset changes.""" while self._running: try: # Check current version current_version = self.dataset.version - - if self._last_check_version and current_version > self._last_check_version: + + if ( + self._last_check_version + and current_version > self._last_check_version + ): # Detect changes between versions changes = await self._detect_changes( - self._last_check_version, - current_version + self._last_check_version, current_version ) - + # Distribute changes to subscriptions for change in changes: await self._distribute_change(change) - + self._last_check_version = current_version - + # Sleep based on minimum polling interval min_interval = min( - (sub.options.get("polling_interval", 5) - for sub in self.subscriptions.values() - if sub.is_active), - default=5 + ( + sub.options.get("polling_interval", 5) + for sub in self.subscriptions.values() + if sub.is_active + ), + default=5, ) await asyncio.sleep(min_interval) - + except Exception as e: # Log error but keep polling print(f"Error in subscription polling: {e}") await asyncio.sleep(5) - - async def _detect_changes( - self, - old_version: int, - new_version: int - ) -> List[Change]: + + async def _detect_changes(self, old_version: int, new_version: int) -> list[Change]: """Detect changes between dataset versions. - + Args: old_version: Previous version number new_version: Current version number - + Returns: List of detected changes """ changes = [] - timestamp = datetime.now(timezone.utc) - + timestamp = datetime.now(UTC) + # Get all UUIDs from both versions old_uuids = await self._get_version_uuids(old_version) new_uuids = await self._get_version_uuids(new_version) - + # Detect created documents created = new_uuids - old_uuids for uuid in created: - changes.append(Change( - type="created", - resource_type="document", - resource_id=uuid, - version=new_version, - timestamp=timestamp - )) - + changes.append( + Change( + type="created", + resource_type="document", + resource_id=uuid, + version=new_version, + timestamp=timestamp, + ) + ) + # Detect deleted documents deleted = old_uuids - new_uuids for uuid in deleted: - changes.append(Change( - type="deleted", - resource_type="document", - resource_id=uuid, - version=new_version, - timestamp=timestamp - )) - + changes.append( + Change( + type="deleted", + resource_type="document", + resource_id=uuid, + version=new_version, + timestamp=timestamp, + ) + ) + # Detect updated documents (same UUID, different content/metadata) common = old_uuids & new_uuids for uuid in common: if await self._has_changed(uuid, old_version, new_version): - changes.append(Change( - type="updated", - resource_type="document", - resource_id=uuid, - version=new_version, - timestamp=timestamp - )) - + changes.append( + Change( + type="updated", + resource_type="document", + resource_id=uuid, + version=new_version, + timestamp=timestamp, + ) + ) + return changes - - async def _get_version_uuids(self, version: int) -> Set[str]: + + async def _get_version_uuids(self, version: int) -> set[str]: """Get all document UUIDs from a specific version. - + Args: version: Version number - + Returns: Set of UUIDs """ # Use Lance's checkout_version capability versioned_dataset = self.dataset.checkout_version(version) - + # Get all UUIDs scanner = versioned_dataset.scanner(columns=["uuid"]) uuids = set() - + for batch in scanner.to_batches(): for uuid in batch["uuid"]: if uuid: uuids.add(str(uuid)) - + return uuids - - async def _has_changed( - self, - uuid: str, - old_version: int, - new_version: int - ) -> bool: + + async def _has_changed(self, uuid: str, old_version: int, new_version: int) -> bool: """Check if a document has changed between versions. - + Args: uuid: Document UUID old_version: Previous version new_version: Current version - + Returns: Whether the document changed """ # Get document from both versions old_dataset = self.dataset.checkout_version(old_version) new_dataset = self.dataset.checkout_version(new_version) - + # Compare timestamps old_record = old_dataset.search(filter=f"uuid = '{uuid}'", limit=1) new_record = new_dataset.search(filter=f"uuid = '{uuid}'", limit=1) - + if not old_record or not new_record: return True # Something changed if we can't find it - + old_record = old_record[0] new_record = new_record[0] - + # Compare updated_at timestamps old_updated = old_record.metadata.get("updated_at", "") new_updated = new_record.metadata.get("updated_at", "") - + return old_updated != new_updated - + async def _distribute_change(self, change: Change): """Distribute a change to relevant subscriptions. - + Args: change: The change to distribute """ for subscription in self.subscriptions.values(): if not subscription.is_active: continue - + # Check if change matches subscription if not self._matches_subscription(change, subscription): continue - + # Add to buffer subscription.change_buffer.append(change) - + # Notify waiting pollers self._change_queue.put_nowait(subscription.id) - + def _matches_subscription( - self, - change: Change, - subscription: SubscriptionState + self, change: Change, subscription: SubscriptionState ) -> bool: """Check if a change matches a subscription's filters. - + Args: change: The change to check subscription: The subscription to match against - + Returns: Whether the change matches """ # Check resource type if subscription.resource_type != "all": - if subscription.resource_type == "documents" and change.resource_type != "document": + if ( + subscription.resource_type == "documents" + and change.resource_type != "document" + ): return False - if subscription.resource_type == "collections" and change.resource_type != "collection": + if ( + subscription.resource_type == "collections" + and change.resource_type != "collection" + ): return False - + # TODO: Apply additional filters from subscription.filters # For now, match all changes of the correct type - + return True - + async def _wait_for_changes(self, subscription_id: str): """Wait for changes to arrive for a subscription. - + Args: subscription_id: The subscription to wait for """ while True: sub_id = await self._change_queue.get() if sub_id == subscription_id: - return \ No newline at end of file + return diff --git a/contextframe/mcp/subscriptions/tools.py b/contextframe/mcp/subscriptions/tools.py index 6852db2..a20f07d 100644 --- a/contextframe/mcp/subscriptions/tools.py +++ b/contextframe/mcp/subscriptions/tools.py @@ -1,193 +1,179 @@ """MCP tools for subscription management.""" -from typing import Any, Dict, Optional - +from .manager import SubscriptionManager from contextframe import FrameDataset from contextframe.mcp.errors import InvalidParams from contextframe.mcp.schemas import ( - SubscribeChangesParams, - PollChangesParams, - UnsubscribeParams, GetSubscriptionsParams, - SubscribeResult, + GetSubscriptionsResult, + PollChangesParams, PollResult, + SubscribeChangesParams, + SubscribeResult, + UnsubscribeParams, UnsubscribeResult, - GetSubscriptionsResult ) - -from .manager import SubscriptionManager - +from typing import Any, Dict, Optional # Global subscription managers per dataset -_managers: Dict[str, SubscriptionManager] = {} +_managers: dict[str, SubscriptionManager] = {} def _get_or_create_manager(dataset: FrameDataset) -> SubscriptionManager: """Get or create a subscription manager for a dataset. - + Args: dataset: The dataset to manage - + Returns: The subscription manager """ dataset_id = id(dataset) # Use object ID as key - + if dataset_id not in _managers: _managers[dataset_id] = SubscriptionManager(dataset) - + return _managers[dataset_id] async def subscribe_changes( - params: SubscribeChangesParams, - dataset: FrameDataset, - **kwargs -) -> Dict[str, Any]: + params: SubscribeChangesParams, dataset: FrameDataset, **kwargs +) -> dict[str, Any]: """Create a subscription to monitor dataset changes. - + Creates a subscription that allows clients to watch for changes in the dataset. - Since Lance doesn't have built-in change notifications, this implements a + Since Lance doesn't have built-in change notifications, this implements a polling-based system that efficiently detects changes between versions. - + Args: params: Subscription parameters dataset: The dataset to monitor - + Returns: Subscription information including ID and polling details """ try: # Get or create manager manager = _get_or_create_manager(dataset) - + # Create subscription subscription_id = await manager.create_subscription( resource_type=params.resource_type, filters=params.filters, - options=params.options + options=params.options, ) - + # Generate initial poll token poll_token = f"{subscription_id}:0" - + result = SubscribeResult( subscription_id=subscription_id, poll_token=poll_token, - polling_interval=params.options.get("polling_interval", 5) + polling_interval=params.options.get("polling_interval", 5), ) - + return result.model_dump() - + except Exception as e: raise InvalidParams(f"Failed to create subscription: {str(e)}") async def poll_changes( - params: PollChangesParams, - dataset: FrameDataset, - **kwargs -) -> Dict[str, Any]: + params: PollChangesParams, dataset: FrameDataset, **kwargs +) -> dict[str, Any]: """Poll for changes since the last poll. - + This tool implements long polling for change detection. It will wait up to the specified timeout for changes to occur, returning immediately if changes are available. - + Args: params: Poll parameters dataset: The dataset being monitored - + Returns: Changes since last poll, new poll token, and subscription status """ try: # Get manager manager = _get_or_create_manager(dataset) - + # Poll for changes poll_result = await manager.poll_subscription( subscription_id=params.subscription_id, poll_token=params.poll_token, - timeout=params.timeout + timeout=params.timeout, ) - + result = PollResult(**poll_result) - + return result.model_dump() - + except Exception as e: raise InvalidParams(f"Failed to poll changes: {str(e)}") async def unsubscribe( - params: UnsubscribeParams, - dataset: FrameDataset, - **kwargs -) -> Dict[str, Any]: + params: UnsubscribeParams, dataset: FrameDataset, **kwargs +) -> dict[str, Any]: """Cancel an active subscription. - + Cancels a subscription and stops monitoring for changes. The subscription can still be polled one final time to retrieve any remaining buffered changes. - + Args: params: Unsubscribe parameters dataset: The dataset being monitored - + Returns: Cancellation status and final poll token """ try: # Get manager manager = _get_or_create_manager(dataset) - + # Cancel subscription cancelled = await manager.cancel_subscription(params.subscription_id) - + result = UnsubscribeResult( cancelled=cancelled, - final_poll_token=f"{params.subscription_id}:final" if cancelled else None + final_poll_token=f"{params.subscription_id}:final" if cancelled else None, ) - + return result.model_dump() - + except Exception as e: raise InvalidParams(f"Failed to unsubscribe: {str(e)}") async def get_subscriptions( - params: GetSubscriptionsParams, - dataset: FrameDataset, - **kwargs -) -> Dict[str, Any]: + params: GetSubscriptionsParams, dataset: FrameDataset, **kwargs +) -> dict[str, Any]: """Get list of active subscriptions. - + Returns information about all active subscriptions, optionally filtered by resource type. - + Args: params: Query parameters dataset: The dataset being monitored - + Returns: List of active subscriptions with details """ try: # Get manager manager = _get_or_create_manager(dataset) - + # Get subscriptions - subscriptions = manager.get_subscriptions( - resource_type=params.resource_type - ) - + subscriptions = manager.get_subscriptions(resource_type=params.resource_type) + result = GetSubscriptionsResult( - subscriptions=subscriptions, - total_count=len(subscriptions) + subscriptions=subscriptions, total_count=len(subscriptions) ) - + return result.model_dump() - + except Exception as e: raise InvalidParams(f"Failed to get subscriptions: {str(e)}") @@ -204,11 +190,11 @@ async def get_subscriptions( "type": "string", "enum": ["documents", "collections", "all"], "default": "all", - "description": "Type of resources to monitor" + "description": "Type of resources to monitor", }, "filters": { "type": "object", - "description": "Optional filters (e.g., {'collection_id': '...'})" + "description": "Optional filters (e.g., {'collection_id': '...'})", }, "options": { "type": "object", @@ -216,22 +202,22 @@ async def get_subscriptions( "polling_interval": { "type": "integer", "default": 5, - "description": "Seconds between polls" + "description": "Seconds between polls", }, "include_data": { "type": "boolean", "default": False, - "description": "Include full document data in changes" + "description": "Include full document data in changes", }, "batch_size": { "type": "integer", "default": 100, - "description": "Max changes per poll response" - } - } - } - } - } + "description": "Max changes per poll response", + }, + }, + }, + }, + }, }, { "name": "poll_changes", @@ -242,21 +228,21 @@ async def get_subscriptions( "properties": { "subscription_id": { "type": "string", - "description": "Active subscription ID" + "description": "Active subscription ID", }, "poll_token": { "type": "string", - "description": "Token from last poll (optional for first poll)" + "description": "Token from last poll (optional for first poll)", }, "timeout": { "type": "integer", "default": 30, "minimum": 0, "maximum": 300, - "description": "Max seconds to wait for changes (long polling)" - } - } - } + "description": "Max seconds to wait for changes (long polling)", + }, + }, + }, }, { "name": "unsubscribe", @@ -267,10 +253,10 @@ async def get_subscriptions( "properties": { "subscription_id": { "type": "string", - "description": "Subscription to cancel" + "description": "Subscription to cancel", } - } - } + }, + }, }, { "name": "get_subscriptions", @@ -281,9 +267,9 @@ async def get_subscriptions( "resource_type": { "type": "string", "enum": ["documents", "collections", "all"], - "description": "Filter by resource type (optional)" + "description": "Filter by resource type (optional)", } - } - } - } -] \ No newline at end of file + }, + }, + }, +] diff --git a/contextframe/mcp/tools.py b/contextframe/mcp/tools.py index c80cd40..fac872e 100644 --- a/contextframe/mcp/tools.py +++ b/contextframe/mcp/tools.py @@ -1,35 +1,34 @@ """Tool registry and implementations for MCP server.""" -import os import logging -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional import numpy as np -from pydantic import ValidationError - -from contextframe.frame import FrameDataset, FrameRecord +import os +from collections.abc import Callable from contextframe.embed import LiteLLMProvider +from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.errors import ( - MCPError, - InvalidParams, - InvalidSearchType, DocumentNotFound, EmbeddingError, - FilterError + FilterError, + InvalidParams, + InvalidSearchType, + MCPError, ) from contextframe.mcp.schemas import ( - Tool, - SearchDocumentsParams, AddDocumentParams, - GetDocumentParams, - ListDocumentsParams, - UpdateDocumentParams, DeleteDocumentParams, DocumentResult, + GetDocumentParams, + ListDocumentsParams, + ListResult, + SearchDocumentsParams, SearchResult, - ListResult + Tool, + UpdateDocumentParams, ) - +from pathlib import Path +from pydantic import ValidationError +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) @@ -37,91 +36,136 @@ class ToolRegistry: """Registry for MCP tools.""" - def __init__(self, dataset: FrameDataset, transport: Optional[Any] = None): + def __init__(self, dataset: FrameDataset, transport: Any | None = None): self.dataset = dataset self.transport = transport - self._tools: Dict[str, Tool] = {} - self._handlers: Dict[str, Callable] = {} - + self._tools: dict[str, Tool] = {} + self._handlers: dict[str, Callable] = {} + # Create document tools instance self._doc_tools = self # For now, self contains the document tools - + self._register_default_tools() - + # Register enhancement and extraction tools if available try: from contextframe.mcp.enhancement_tools import ( register_enhancement_tools, - register_extraction_tools + register_extraction_tools, ) + register_enhancement_tools(self, dataset) register_extraction_tools(self, dataset) except ImportError: logger.warning("Enhancement tools not available") - + # Register batch tools if transport is available if transport: try: from contextframe.mcp.batch.tools import BatchTools + batch_tools = BatchTools(dataset, transport, self._doc_tools) batch_tools.register_tools(self) except ImportError: logger.warning("Batch tools not available") - + # Register collection tools try: - from contextframe.mcp.collections.tools import CollectionTools from contextframe.mcp.collections.templates import TemplateRegistry + from contextframe.mcp.collections.tools import CollectionTools + template_registry = TemplateRegistry() - collection_tools = CollectionTools(dataset, transport, template_registry) + collection_tools = CollectionTools( + dataset, transport, template_registry + ) collection_tools.register_tools(self) except ImportError: logger.warning("Collection tools not available") - + # Register subscription tools try: - from contextframe.mcp.subscriptions.tools import ( - subscribe_changes, - poll_changes, - unsubscribe, - get_subscriptions, - SUBSCRIPTION_TOOLS - ) from contextframe.mcp.schemas import ( - SubscribeChangesParams, + GetSubscriptionsParams, PollChangesParams, + SubscribeChangesParams, UnsubscribeParams, - GetSubscriptionsParams ) - + from contextframe.mcp.subscriptions.tools import ( + SUBSCRIPTION_TOOLS, + get_subscriptions, + poll_changes, + subscribe_changes, + unsubscribe, + ) + # Register each subscription tool self.register_tool( "subscribe_changes", subscribe_changes, SubscribeChangesParams, - "Create a subscription to monitor dataset changes" + "Create a subscription to monitor dataset changes", ) self.register_tool( "poll_changes", poll_changes, PollChangesParams, - "Poll for changes since the last poll" + "Poll for changes since the last poll", ) self.register_tool( "unsubscribe", unsubscribe, UnsubscribeParams, - "Cancel an active subscription" + "Cancel an active subscription", ) self.register_tool( "get_subscriptions", get_subscriptions, GetSubscriptionsParams, - "Get list of active subscriptions" + "Get list of active subscriptions", ) except ImportError: logger.warning("Subscription tools not available") + # Register analytics tools + try: + from contextframe.mcp.analytics.tools import ( + AnalyzeUsageHandler, + BenchmarkOperationsHandler, + ExportMetricsHandler, + GetDatasetStatsHandler, + IndexRecommendationsHandler, + OptimizeStorageHandler, + QueryPerformanceHandler, + RelationshipAnalysisHandler, + ) + + # Analytics tools + analytics_handlers = [ + GetDatasetStatsHandler(dataset), + AnalyzeUsageHandler(dataset), + QueryPerformanceHandler(dataset), + RelationshipAnalysisHandler(dataset), + OptimizeStorageHandler(dataset), + IndexRecommendationsHandler(dataset), + BenchmarkOperationsHandler(dataset), + ExportMetricsHandler(dataset), + ] + + for handler in analytics_handlers: + self.register( + handler.name, + Tool( + name=handler.name, + description=handler.description, + inputSchema=handler.get_input_schema(), + ), + lambda params, h=handler: h.execute(**params), + ) + + logger.info(f"Registered {len(analytics_handlers)} analytics tools") + except ImportError: + logger.warning("Analytics tools not available") + def _register_default_tools(self): """Register the default set of tools.""" # Search documents tool @@ -133,32 +177,29 @@ def _register_default_tools(self): inputSchema={ "type": "object", "properties": { - "query": { - "type": "string", - "description": "Search query" - }, + "query": {"type": "string", "description": "Search query"}, "search_type": { "type": "string", "enum": ["vector", "text", "hybrid"], "default": "hybrid", - "description": "Type of search to perform" + "description": "Type of search to perform", }, "limit": { "type": "integer", "minimum": 1, "maximum": 1000, "default": 10, - "description": "Maximum number of results" + "description": "Maximum number of results", }, "filter": { "type": "string", - "description": "SQL filter expression" - } + "description": "SQL filter expression", + }, }, - "required": ["query"] - } + "required": ["query"], + }, ), - self._search_documents + self._search_documents, ) # Add document tool @@ -172,38 +213,38 @@ def _register_default_tools(self): "properties": { "content": { "type": "string", - "description": "Document content" + "description": "Document content", }, "metadata": { "type": "object", - "description": "Document metadata" + "description": "Document metadata", }, "generate_embedding": { "type": "boolean", "default": True, - "description": "Whether to generate embeddings" + "description": "Whether to generate embeddings", }, "collection": { "type": "string", - "description": "Collection to add document to" + "description": "Collection to add document to", }, "chunk_size": { "type": "integer", "minimum": 100, "maximum": 10000, - "description": "Size of chunks for large documents" + "description": "Size of chunks for large documents", }, "chunk_overlap": { "type": "integer", "minimum": 0, "maximum": 1000, - "description": "Overlap between chunks" - } + "description": "Overlap between chunks", + }, }, - "required": ["content"] - } + "required": ["content"], + }, ), - self._add_document + self._add_document, ) # Get document tool @@ -217,28 +258,28 @@ def _register_default_tools(self): "properties": { "document_id": { "type": "string", - "description": "Document UUID" + "description": "Document UUID", }, "include_content": { "type": "boolean", "default": True, - "description": "Include document content" + "description": "Include document content", }, "include_metadata": { "type": "boolean", "default": True, - "description": "Include document metadata" + "description": "Include document metadata", }, "include_embeddings": { "type": "boolean", "default": False, - "description": "Include embeddings" - } + "description": "Include embeddings", + }, }, - "required": ["document_id"] - } + "required": ["document_id"], + }, ), - self._get_document + self._get_document, ) # List documents tool @@ -255,31 +296,31 @@ def _register_default_tools(self): "minimum": 1, "maximum": 1000, "default": 100, - "description": "Maximum number of results" + "description": "Maximum number of results", }, "offset": { "type": "integer", "minimum": 0, "default": 0, - "description": "Number of results to skip" + "description": "Number of results to skip", }, "filter": { "type": "string", - "description": "SQL filter expression" + "description": "SQL filter expression", }, "order_by": { "type": "string", - "description": "Order by expression" + "description": "Order by expression", }, "include_content": { "type": "boolean", "default": False, - "description": "Include document content" - } - } - } + "description": "Include document content", + }, + }, + }, ), - self._list_documents + self._list_documents, ) # Update document tool @@ -293,26 +334,26 @@ def _register_default_tools(self): "properties": { "document_id": { "type": "string", - "description": "Document UUID" + "description": "Document UUID", }, "content": { "type": "string", - "description": "New document content" + "description": "New document content", }, "metadata": { "type": "object", - "description": "New or updated metadata" + "description": "New or updated metadata", }, "regenerate_embedding": { "type": "boolean", "default": False, - "description": "Regenerate embeddings if content changed" - } + "description": "Regenerate embeddings if content changed", + }, }, - "required": ["document_id"] - } + "required": ["document_id"], + }, ), - self._update_document + self._update_document, ) # Delete document tool @@ -326,29 +367,29 @@ def _register_default_tools(self): "properties": { "document_id": { "type": "string", - "description": "Document UUID" + "description": "Document UUID", } }, - "required": ["document_id"] - } + "required": ["document_id"], + }, ), - self._delete_document + self._delete_document, ) def register(self, name: str, tool: Tool, handler: Callable): """Register a new tool.""" self._tools[name] = tool self._handlers[name] = handler - + def register_tool( self, name: str, handler: Callable, - schema: Optional[Any] = None, - description: Optional[str] = None + schema: Any | None = None, + description: str | None = None, ): """Register a tool with flexible parameters. - + Args: name: Tool name handler: Async callable handler @@ -364,21 +405,21 @@ def register_tool( input_schema = schema else: input_schema = {"type": "object", "properties": {}} - + # Create tool tool = Tool( name=name, description=description or f"{name} tool", - inputSchema=input_schema + inputSchema=input_schema, ) - + self.register(name, tool, handler) - def list_tools(self) -> List[Tool]: + def list_tools(self) -> list[Tool]: """List all registered tools.""" return list(self._tools.values()) - async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]: """Call a tool by name with arguments.""" if name not in self._handlers: raise InvalidParams(f"Unknown tool: {name}") @@ -395,18 +436,18 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any except MCPError: # Re-raise other MCP errors as-is raise - except Exception as e: + except Exception: logger.exception(f"Error calling tool {name}") raise # Tool implementations - async def _search_documents(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _search_documents(self, arguments: dict[str, Any]) -> dict[str, Any]: """Implement document search.""" params = SearchDocumentsParams(**arguments) - + results = [] search_type_used = params.search_type - + try: if params.search_type == "vector": results = await self._vector_search( @@ -429,7 +470,7 @@ async def _search_documents(self, arguments: Dict[str, Any]) -> Dict[str, Any]: params.query, params.limit, params.filter ) search_type_used = "text" - + except Exception as e: if "filter" in str(e).lower(): raise FilterError(str(e), params.filter or "") @@ -442,67 +483,60 @@ async def _search_documents(self, arguments: Dict[str, Any]) -> Dict[str, Any]: uuid=record.uuid, content=record.text_content, metadata=record.metadata, - score=getattr(record, '_score', None) + score=getattr(record, '_score', None), ) documents.append(doc) return SearchResult( documents=documents, total_count=len(documents), - search_type_used=search_type_used + search_type_used=search_type_used, ).model_dump() async def _vector_search( - self, query: str, limit: int, filter_expr: Optional[str] - ) -> List[FrameRecord]: + self, query: str, limit: int, filter_expr: str | None + ) -> list[FrameRecord]: """Perform vector search with embedding generation.""" # Get embedding model configuration model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") api_key = os.environ.get("OPENAI_API_KEY") - + if not api_key: raise EmbeddingError( "No API key found. Set OPENAI_API_KEY environment variable.", - {"model": model} + {"model": model}, ) - + try: # Generate query embedding provider = LiteLLMProvider(model, api_key=api_key) result = provider.embed(query) query_vector = np.array(result.embeddings[0], dtype=np.float32) - + # Perform KNN search return self.dataset.knn_search( - query_vector=query_vector, - k=limit, - filter=filter_expr + query_vector=query_vector, k=limit, filter=filter_expr ) except Exception as e: raise EmbeddingError(str(e), {"model": model}) async def _text_search( - self, query: str, limit: int, filter_expr: Optional[str] - ) -> List[FrameRecord]: + self, query: str, limit: int, filter_expr: str | None + ) -> list[FrameRecord]: """Perform text search with optional filtering.""" # If no filter, use the simpler full_text_search if not filter_expr: return self.dataset.full_text_search(query, k=limit) - + # With filter, use scanner with both full_text_query and filter ftq = {"query": query, "columns": ["text_content"]} - scanner_kwargs = { - "full_text_query": ftq, - "filter": filter_expr, - "limit": limit - } - + scanner_kwargs = {"full_text_query": ftq, "filter": filter_expr, "limit": limit} + try: tbl = self.dataset.scanner(**scanner_kwargs).to_table() return [ FrameRecord.from_arrow( - tbl.slice(i, 1), - dataset_path=Path(self.dataset._dataset.uri) + tbl.slice(i, 1), dataset_path=Path(self.dataset._dataset.uri) ) for i in range(tbl.num_rows) ] @@ -511,72 +545,63 @@ async def _text_search( raise FilterError(str(e), filter_expr) raise - async def _add_document(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _add_document(self, arguments: dict[str, Any]) -> dict[str, Any]: """Add a new document.""" params = AddDocumentParams(**arguments) - + # Check if we need to chunk the document if params.chunk_size and len(params.content) > params.chunk_size: chunks = self._chunk_text( - params.content, - params.chunk_size, - params.chunk_overlap or 100 + params.content, params.chunk_size, params.chunk_overlap or 100 ) - + # Add each chunk as a separate document added_docs = [] for i, chunk in enumerate(chunks): chunk_metadata = params.metadata.copy() - chunk_metadata.update({ - "chunk_index": i, - "total_chunks": len(chunks), - "original_length": len(params.content) - }) - + chunk_metadata.update( + { + "chunk_index": i, + "total_chunks": len(chunks), + "original_length": len(params.content), + } + ) + doc = await self._add_single_document( - chunk, - chunk_metadata, - params.generate_embedding, - params.collection + chunk, chunk_metadata, params.generate_embedding, params.collection ) added_docs.append(doc) - - return { - "documents": added_docs, - "total_chunks": len(chunks) - } + + return {"documents": added_docs, "total_chunks": len(chunks)} else: # Add single document doc = await self._add_single_document( params.content, params.metadata, params.generate_embedding, - params.collection + params.collection, ) return {"document": doc} async def _add_single_document( self, content: str, - metadata: Dict[str, Any], + metadata: dict[str, Any], generate_embedding: bool, - collection: Optional[str] - ) -> Dict[str, Any]: + collection: str | None, + ) -> dict[str, Any]: """Add a single document to the dataset.""" # Create record - record = FrameRecord( - text_content=content, - metadata=metadata - ) - + record = FrameRecord(text_content=content, metadata=metadata) + if collection: record.metadata["collection"] = collection - + # Generate embedding if requested if generate_embedding: model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") api_key = os.environ.get("OPENAI_API_KEY") - + if api_key: try: provider = LiteLLMProvider(model, api_key=api_key) @@ -584,134 +609,129 @@ async def _add_single_document( record.vector = np.array(result.embeddings[0], dtype=np.float32) except Exception as e: logger.warning(f"Failed to generate embedding: {e}") - + # Add to dataset self.dataset.add(record) - + return DocumentResult( - uuid=record.uuid, - content=record.text_content, - metadata=record.metadata + uuid=record.uuid, content=record.text_content, metadata=record.metadata ).model_dump() - def _chunk_text(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]: + def _chunk_text(self, text: str, chunk_size: int, chunk_overlap: int) -> list[str]: """Split text into overlapping chunks.""" chunks = [] start = 0 - + while start < len(text): end = start + chunk_size chunk = text[start:end] - + # Try to break at sentence or paragraph boundary if end < len(text): last_period = chunk.rfind('. ') last_newline = chunk.rfind('\n') boundary = max(last_period, last_newline) if boundary > chunk_size * 0.5: - chunk = text[start:start + boundary + 1] + chunk = text[start : start + boundary + 1] end = start + boundary + 1 - + chunks.append(chunk.strip()) start = end - chunk_overlap - + return [c for c in chunks if c] - async def _get_document(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _get_document(self, arguments: dict[str, Any]) -> dict[str, Any]: """Get a document by ID.""" params = GetDocumentParams(**arguments) - + # Query for the document results = self.dataset.query(f"uuid = '{params.document_id}'", limit=1) - + if not results: raise DocumentNotFound(params.document_id) - + record = results[0] - + # Build response based on requested fields doc = DocumentResult( uuid=record.uuid, - metadata=record.metadata if params.include_metadata else {} + metadata=record.metadata if params.include_metadata else {}, ) - + if params.include_content: doc.content = record.text_content - + if params.include_embeddings and record.vector is not None: doc.embedding = record.vector.tolist() - + return {"document": doc.model_dump()} - async def _list_documents(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _list_documents(self, arguments: dict[str, Any]) -> dict[str, Any]: """List documents with pagination.""" params = ListDocumentsParams(**arguments) - + # Build query if params.filter: try: results = self.dataset.query( - params.filter, - limit=params.limit, - offset=params.offset + params.filter, limit=params.limit, offset=params.offset ) except Exception as e: raise FilterError(str(e), params.filter) else: # No filter, get all documents # Note: This is a simplified approach, ideally we'd have a list method - results = self.dataset.query("1=1", limit=params.limit, offset=params.offset) - + results = self.dataset.query( + "1=1", limit=params.limit, offset=params.offset + ) + # Get total count (simplified - in production, use separate count query) total_count = len(results) - + # Convert to response format documents = [] for record in results: - doc = DocumentResult( - uuid=record.uuid, - metadata=record.metadata - ) + doc = DocumentResult(uuid=record.uuid, metadata=record.metadata) if params.include_content: doc.content = record.text_content documents.append(doc) - + return ListResult( documents=documents, total_count=total_count, offset=params.offset, - limit=params.limit + limit=params.limit, ).model_dump() - async def _update_document(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _update_document(self, arguments: dict[str, Any]) -> dict[str, Any]: """Update an existing document.""" params = UpdateDocumentParams(**arguments) - + # Get existing document results = self.dataset.query(f"uuid = '{params.document_id}'", limit=1) if not results: raise DocumentNotFound(params.document_id) - + record = results[0] - + # Update fields updated = False if params.content is not None: record.text_content = params.content updated = True - + if params.metadata is not None: record.metadata.update(params.metadata) updated = True - + if not updated: raise InvalidParams("No updates provided") - + # Regenerate embedding if requested and content changed if params.regenerate_embedding and params.content: model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") api_key = os.environ.get("OPENAI_API_KEY") - + if api_key: try: provider = LiteLLMProvider(model, api_key=api_key) @@ -719,29 +739,27 @@ async def _update_document(self, arguments: Dict[str, Any]) -> Dict[str, Any]: record.vector = np.array(result.embeddings[0], dtype=np.float32) except Exception as e: logger.warning(f"Failed to regenerate embedding: {e}") - + # Update in dataset (atomic delete + add) self.dataset.delete(f"uuid = '{params.document_id}'") self.dataset.add([record]) - + return { "document": DocumentResult( - uuid=record.uuid, - content=record.text_content, - metadata=record.metadata + uuid=record.uuid, content=record.text_content, metadata=record.metadata ).model_dump() } - async def _delete_document(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _delete_document(self, arguments: dict[str, Any]) -> dict[str, Any]: """Delete a document.""" params = DeleteDocumentParams(**arguments) - + # Check document exists results = self.dataset.query(f"uuid = '{params.document_id}'", limit=1) if not results: raise DocumentNotFound(params.document_id) - + # Delete self.dataset.delete(f"uuid = '{params.document_id}'") - - return {"deleted": True, "document_id": params.document_id} \ No newline at end of file + + return {"deleted": True, "document_id": params.document_id} diff --git a/contextframe/mcp/transport.py b/contextframe/mcp/transport.py index a227a38..2e2ec6a 100644 --- a/contextframe/mcp/transport.py +++ b/contextframe/mcp/transport.py @@ -3,40 +3,40 @@ import asyncio import json import sys -from typing import Any, AsyncIterator, Dict, Optional - +from collections.abc import AsyncIterator from contextframe.mcp.errors import ParseError +from typing import Any, Dict, Optional class StdioTransport: """Handles stdio communication for MCP using JSON-RPC 2.0 protocol.""" def __init__(self): - self._reader: Optional[asyncio.StreamReader] = None - self._writer: Optional[asyncio.StreamWriter] = None + self._reader: asyncio.StreamReader | None = None + self._writer: asyncio.StreamWriter | None = None self._running = False async def connect(self) -> None: """Initialize stdio streams for async communication.""" loop = asyncio.get_event_loop() - + # Create async streams from stdin/stdout self._reader = asyncio.StreamReader() reader_protocol = asyncio.StreamReaderProtocol(self._reader) - + await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin) - + # For stdout, we'll use a transport/protocol pair w_transport, w_protocol = await loop.connect_write_pipe( lambda: asyncio.Protocol(), sys.stdout ) self._writer = asyncio.StreamWriter(w_transport, w_protocol, self._reader, loop) - + self._running = True - async def read_message(self) -> Dict[str, Any]: + async def read_message(self) -> dict[str, Any]: """Read and parse a JSON-RPC message from stdin. - + Messages are expected to be newline-delimited JSON. """ if not self._reader: @@ -47,26 +47,26 @@ async def read_message(self) -> Dict[str, Any]: line = await self._reader.readline() if not line: raise EOFError("Connection closed") - + # Decode and parse JSON message_str = line.decode('utf-8').strip() if not message_str: # Empty line, try again return await self.read_message() - + try: message = json.loads(message_str) except json.JSONDecodeError as e: raise ParseError({"error": str(e), "input": message_str}) - + return message - + except Exception as e: if isinstance(e, (ParseError, EOFError)): raise raise ParseError({"error": str(e)}) - async def send_message(self, message: Dict[str, Any]) -> None: + async def send_message(self, message: dict[str, Any]) -> None: """Send a JSON-RPC message to stdout.""" if not self._writer: raise RuntimeError("Transport not connected") @@ -74,26 +74,26 @@ async def send_message(self, message: Dict[str, Any]) -> None: try: # Serialize to JSON and add newline message_str = json.dumps(message, separators=(',', ':')) + '\n' - + # Write to stdout self._writer.write(message_str.encode('utf-8')) await self._writer.drain() - + except Exception as e: raise RuntimeError(f"Failed to send message: {e}") async def close(self) -> None: """Clean shutdown of transport.""" self._running = False - + if self._writer: self._writer.close() await self._writer.wait_closed() - + self._reader = None self._writer = None - async def __aiter__(self) -> AsyncIterator[Dict[str, Any]]: + async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: """Async iterator for reading messages.""" while self._running: try: @@ -109,4 +109,4 @@ async def __aiter__(self) -> AsyncIterator[Dict[str, Any]]: @property def is_connected(self) -> bool: """Check if transport is connected.""" - return self._reader is not None and self._writer is not None and self._running \ No newline at end of file + return self._reader is not None and self._writer is not None and self._running diff --git a/contextframe/mcp/transports/__init__.py b/contextframe/mcp/transports/__init__.py index 902ba37..e4688da 100644 --- a/contextframe/mcp/transports/__init__.py +++ b/contextframe/mcp/transports/__init__.py @@ -2,4 +2,4 @@ from contextframe.mcp.transports.stdio import StdioAdapter -__all__ = ["StdioAdapter"] \ No newline at end of file +__all__ = ["StdioAdapter"] diff --git a/contextframe/mcp/transports/stdio.py b/contextframe/mcp/transports/stdio.py index 72c6e2e..02aa2b4 100644 --- a/contextframe/mcp/transports/stdio.py +++ b/contextframe/mcp/transports/stdio.py @@ -3,47 +3,46 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, AsyncIterator, List - -from contextframe.mcp.core.transport import TransportAdapter, Progress, Subscription +from collections.abc import AsyncIterator from contextframe.mcp.core.streaming import BufferedStreamingAdapter +from contextframe.mcp.core.transport import Progress, Subscription, TransportAdapter from contextframe.mcp.transport import StdioTransport - +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) class StdioAdapter(TransportAdapter): """Stdio transport adapter using existing StdioTransport. - + This adapter wraps the existing stdio implementation to work with the new transport abstraction while maintaining backward compatibility. """ - + def __init__(self): super().__init__() self._transport = StdioTransport() self._streaming = BufferedStreamingAdapter() - self._current_progress: List[Progress] = [] - + self._current_progress: list[Progress] = [] + # Set up progress handler to collect progress self.add_progress_handler(self._collect_progress) - + async def _collect_progress(self, progress: Progress): """Collect progress updates for inclusion in response.""" self._current_progress.append(progress) - + async def initialize(self) -> None: """Initialize stdio streams.""" await self._transport.connect() logger.info("Stdio transport initialized") - + async def shutdown(self) -> None: """Close stdio streams.""" await self._transport.close() logger.info("Stdio transport shutdown") - - async def send_message(self, message: Dict[str, Any]) -> None: + + async def send_message(self, message: dict[str, Any]) -> None: """Send message via stdout.""" # Include any collected progress in the response if self._current_progress and "result" in message: @@ -55,47 +54,49 @@ async def send_message(self, message: Dict[str, Any]) -> None: "current": p.current, "total": p.total, "status": p.status, - "details": p.details + "details": p.details, } for p in self._current_progress ] self._current_progress.clear() - + await self._transport.send_message(message) - - async def receive_message(self) -> Optional[Dict[str, Any]]: + + async def receive_message(self) -> dict[str, Any] | None: """Receive message from stdin.""" return await self._transport.read_message() - + async def send_progress(self, progress: Progress) -> None: """For stdio, progress is collected and included in final response.""" await super().send_progress(progress) - - async def handle_subscription(self, subscription: Subscription) -> AsyncIterator[Dict[str, Any]]: + + async def handle_subscription( + self, subscription: Subscription + ) -> AsyncIterator[dict[str, Any]]: """Stdio uses polling-based subscriptions. - + Returns changes since last poll using change tokens. """ self._subscriptions[subscription.id] = subscription - + # For stdio, we don't actually stream - the client will poll # This is a placeholder that would be called by poll_changes tool yield { "subscription_id": subscription.id, "message": "Use poll_changes tool to check for updates", - "next_poll_token": subscription.last_poll or "initial" + "next_poll_token": subscription.last_poll or "initial", } - + @property def supports_streaming(self) -> bool: """Stdio doesn't support true streaming.""" return False - + @property def transport_type(self) -> str: """Transport type identifier.""" return "stdio" - + def get_streaming_adapter(self) -> BufferedStreamingAdapter: """Get the streaming adapter for this transport.""" - return self._streaming \ No newline at end of file + return self._streaming diff --git a/contextframe/scripts/__init__.py b/contextframe/scripts/__init__.py index 80f6919..2b1d0ea 100644 --- a/contextframe/scripts/__init__.py +++ b/contextframe/scripts/__init__.py @@ -3,4 +3,4 @@ These scripts provide bash-friendly wrappers around ContextFrame operations, designed for use with AI agents like Claude Code and manual CLI usage. -""" \ No newline at end of file +""" diff --git a/contextframe/scripts/add_impl.py b/contextframe/scripts/add_impl.py index fb675c8..5cca5aa 100644 --- a/contextframe/scripts/add_impl.py +++ b/contextframe/scripts/add_impl.py @@ -2,242 +2,275 @@ """Implementation of contextframe-add command.""" import argparse +import mimetypes import sys import uuid -from pathlib import Path -from typing import Optional, List -import mimetypes - -from contextframe.frame import FrameDataset, FrameRecord from contextframe.embed import LiteLLMProvider, create_frame_records_with_embeddings +from contextframe.frame import FrameDataset, FrameRecord +from pathlib import Path +from typing import List, Optional -def read_file_content(file_path: Path) -> tuple[str, Optional[bytes], Optional[str]]: +def read_file_content(file_path: Path) -> tuple[str, bytes | None, str | None]: """Read file content and determine if it's text or binary.""" # Guess MIME type mime_type, _ = mimetypes.guess_type(str(file_path)) - + # Common text file extensions - text_extensions = {'.txt', '.md', '.py', '.js', '.html', '.css', '.json', '.xml', - '.yaml', '.yml', '.toml', '.ini', '.cfg', '.conf', '.sh', '.bash', - '.rst', '.tex', '.csv', '.tsv', '.sql'} - + text_extensions = { + '.txt', + '.md', + '.py', + '.js', + '.html', + '.css', + '.json', + '.xml', + '.yaml', + '.yml', + '.toml', + '.ini', + '.cfg', + '.conf', + '.sh', + '.bash', + '.rst', + '.tex', + '.csv', + '.tsv', + '.sql', + } + # Check if it's a text file - is_text = (file_path.suffix.lower() in text_extensions or - (mime_type and mime_type.startswith('text/'))) - + is_text = file_path.suffix.lower() in text_extensions or ( + mime_type and mime_type.startswith('text/') + ) + if is_text: try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, encoding='utf-8') as f: content = f.read() return content, None, None except UnicodeDecodeError: # Fall back to binary pass - + # Read as binary with open(file_path, 'rb') as f: raw_data = f.read() - + # For binary files, we'll store a placeholder text text_content = f"Binary file: {file_path.name} ({mime_type or 'unknown type'})" - + return text_content, raw_data, mime_type -def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> List[str]: +def chunk_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]: """Split text into overlapping chunks.""" if chunk_size <= 0: return [text] - + chunks = [] start = 0 - + while start < len(text): end = start + chunk_size chunk = text[start:end] - + # Try to break at a sentence or paragraph boundary if end < len(text): # Look for sentence end last_period = chunk.rfind('. ') last_newline = chunk.rfind('\n') - + # Use the latest boundary found boundary = max(last_period, last_newline) if boundary > chunk_size * 0.5: # Only use if it's not too early - chunk = text[start:start + boundary + 1] + chunk = text[start : start + boundary + 1] end = start + boundary + 1 - + chunks.append(chunk.strip()) start = end - chunk_overlap - + return [c for c in chunks if c] # Filter out empty chunks -def create_record_metadata(file_path: Path, record_type: str, identifier: Optional[str], - collection: Optional[str], chunk_info: Optional[dict] = None) -> dict: +def create_record_metadata( + file_path: Path, + record_type: str, + identifier: str | None, + collection: str | None, + chunk_info: dict | None = None, +) -> dict: """Create metadata for a record.""" metadata = { "identifier": identifier or str(uuid.uuid4()), "record_type": record_type, "title": file_path.name, - "source": { - "type": "file", - "path": str(file_path.absolute()) - }, - "relationships": [] + "source": {"type": "file", "path": str(file_path.absolute())}, + "relationships": [], } - + # Add collection relationship if specified if collection: - metadata["relationships"].append({ - "relationship_type": "member_of", - "target_type": "collection", - "target_identifier": collection - }) - + metadata["relationships"].append( + { + "relationship_type": "member_of", + "target_type": "collection", + "target_identifier": collection, + } + ) + # Add chunk information if this is a chunked document if chunk_info: metadata["custom_metadata"] = { "chunk_index": chunk_info["index"], "chunk_total": chunk_info["total"], - "parent_document": chunk_info["parent_id"] + "parent_document": chunk_info["parent_id"], } - + # Add relationship to parent document - metadata["relationships"].append({ - "relationship_type": "child", - "target_type": "document", - "target_identifier": chunk_info["parent_id"] - }) - + metadata["relationships"].append( + { + "relationship_type": "child", + "target_type": "document", + "target_identifier": chunk_info["parent_id"], + } + ) + return metadata -def add_file(dataset: FrameDataset, file_path: Path, args) -> List[str]: +def add_file(dataset: FrameDataset, file_path: Path, args) -> list[str]: """Add a single file to the dataset. Returns list of added record IDs.""" print(f"Adding file: {file_path}") - + # Read file content text_content, raw_data, raw_data_type = read_file_content(file_path) - + # Handle chunking if requested if args.chunk_size and args.chunk_size > 0: chunks = chunk_text(text_content, args.chunk_size, args.chunk_overlap or 0) - + if len(chunks) > 1: print(f"Splitting into {len(chunks)} chunks") - + # Create parent document ID parent_id = args.identifier or str(uuid.uuid4()) added_ids = [] - + for i, chunk in enumerate(chunks): chunk_id = f"{parent_id}_chunk_{i}" - chunk_info = { - "index": i, - "total": len(chunks), - "parent_id": parent_id - } - + chunk_info = {"index": i, "total": len(chunks), "parent_id": parent_id} + metadata = create_record_metadata( file_path, "document", chunk_id, args.collection, chunk_info ) - + record = FrameRecord( text_content=chunk, metadata=metadata, - raw_data=raw_data if i == 0 else None, # Only store raw data with first chunk - raw_data_type=raw_data_type if i == 0 else None + raw_data=raw_data + if i == 0 + else None, # Only store raw data with first chunk + raw_data_type=raw_data_type if i == 0 else None, ) - + # Add to dataset dataset.add([record]) added_ids.append(chunk_id) - + return added_ids - + # Single document (no chunking) - metadata = create_record_metadata(file_path, args.type, args.identifier, args.collection) - + metadata = create_record_metadata( + file_path, args.type, args.identifier, args.collection + ) + record = FrameRecord( text_content=text_content, metadata=metadata, raw_data=raw_data, - raw_data_type=raw_data_type + raw_data_type=raw_data_type, ) - + # Generate embeddings if requested if args.embeddings: print("Generating embeddings...") import os + model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") - + try: # Use the contextframe embedding integration records_with_embeddings = create_frame_records_with_embeddings( - documents=[{"content": text_content, "metadata": metadata}], - model=model + documents=[{"content": text_content, "metadata": metadata}], model=model ) record = records_with_embeddings[0] except Exception as e: print(f"Warning: Failed to generate embeddings: {e}", file=sys.stderr) print("Adding document without embeddings.", file=sys.stderr) - + # Add to dataset dataset.add([record]) return [metadata["identifier"]] -def add_directory(dataset: FrameDataset, dir_path: Path, args) -> List[str]: +def add_directory(dataset: FrameDataset, dir_path: Path, args) -> list[str]: """Add all files in a directory to the dataset.""" added_ids = [] - + # Get all files recursively files = list(dir_path.rglob('*')) files = [f for f in files if f.is_file()] - + print(f"Found {len(files)} files in {dir_path}") - + for file_path in files: try: ids = add_file(dataset, file_path, args) added_ids.extend(ids) except Exception as e: print(f"Error adding {file_path}: {e}", file=sys.stderr) - + return added_ids def main(): """Main entry point for add command.""" - parser = argparse.ArgumentParser(description='Add documents to a ContextFrame dataset') + parser = argparse.ArgumentParser( + description='Add documents to a ContextFrame dataset' + ) parser.add_argument('dataset', help='Path to the dataset') parser.add_argument('input_path', help='File or directory to add') - parser.add_argument('--type', default='document', - choices=['document', 'collection_header', 'dataset_header'], - help='Record type') + parser.add_argument( + '--type', + default='document', + choices=['document', 'collection_header', 'dataset_header'], + help='Record type', + ) parser.add_argument('--collection', help='Add to collection with this name') parser.add_argument('--identifier', help='Custom identifier') - parser.add_argument('--embeddings', action='store_true', - help='Generate embeddings for documents') + parser.add_argument( + '--embeddings', action='store_true', help='Generate embeddings for documents' + ) parser.add_argument('--chunk-size', type=int, help='Split documents into chunks') - parser.add_argument('--chunk-overlap', type=int, default=0, - help='Overlap between chunks') - + parser.add_argument( + '--chunk-overlap', type=int, default=0, help='Overlap between chunks' + ) + args = parser.parse_args() - + # Open the dataset try: dataset = FrameDataset.open(args.dataset) except Exception as e: print(f"Error opening dataset: {e}", file=sys.stderr) sys.exit(1) - + # Convert input path to Path object input_path = Path(args.input_path) - + # Add file(s) try: if input_path.is_file(): @@ -245,17 +278,20 @@ def main(): elif input_path.is_dir(): added_ids = add_directory(dataset, input_path, args) else: - print(f"Error: {input_path} is neither a file nor a directory", file=sys.stderr) + print( + f"Error: {input_path} is neither a file nor a directory", + file=sys.stderr, + ) sys.exit(1) except Exception as e: print(f"Error adding documents: {e}", file=sys.stderr) sys.exit(1) - + # Report results print(f"\nSuccessfully added {len(added_ids)} record(s) to {args.dataset}") - + return 0 if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/contextframe/scripts/create_dataset.py b/contextframe/scripts/create_dataset.py index ea14288..71b0087 100644 --- a/contextframe/scripts/create_dataset.py +++ b/contextframe/scripts/create_dataset.py @@ -2,17 +2,17 @@ """Helper script to create a new ContextFrame dataset.""" import sys -from pathlib import Path from contextframe.frame import FrameDataset +from pathlib import Path def main(): if len(sys.argv) != 2: print("Usage: python -m contextframe.scripts.create_dataset ") sys.exit(1) - + dataset_path = Path(sys.argv[1]) - + try: # Create the dataset dataset = FrameDataset.create(dataset_path) @@ -23,4 +23,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/contextframe/scripts/get_impl.py b/contextframe/scripts/get_impl.py index d097829..b37395a 100644 --- a/contextframe/scripts/get_impl.py +++ b/contextframe/scripts/get_impl.py @@ -4,9 +4,8 @@ import argparse import json import sys -from pathlib import Path - from contextframe.frame import FrameDataset, FrameRecord +from pathlib import Path def format_as_json(record: FrameRecord) -> str: @@ -19,7 +18,7 @@ def format_as_json(record: FrameRecord) -> str: "metadata": record.metadata, "has_raw_data": record.raw_data is not None, "raw_data_type": record.raw_data_type, - "vector_dimension": len(record.vector) if record.vector is not None else 0 + "vector_dimension": len(record.vector) if record.vector is not None else 0, } return json.dumps(data, indent=2) @@ -27,66 +26,74 @@ def format_as_json(record: FrameRecord) -> str: def format_as_text(record: FrameRecord) -> str: """Format record as plain text.""" lines = [] - + # Basic info lines.append(f"Identifier: {record.metadata.get('identifier')}") lines.append(f"Type: {record.metadata.get('record_type', 'document')}") - + if 'title' in record.metadata: lines.append(f"Title: {record.metadata['title']}") - + # Source info if 'source' in record.metadata: source = record.metadata['source'] - lines.append(f"Source: {source.get('type', 'unknown')} - {source.get('path', 'N/A')}") - + lines.append( + f"Source: {source.get('type', 'unknown')} - {source.get('path', 'N/A')}" + ) + # Relationships relationships = record.metadata.get('relationships', []) if relationships: lines.append("\nRelationships:") for rel in relationships: - lines.append(f" - {rel.get('relationship_type')} -> {rel.get('target_type')}: {rel.get('target_identifier')}") - + lines.append( + f" - {rel.get('relationship_type')} -> {rel.get('target_type')}: {rel.get('target_identifier')}" + ) + # Custom metadata if 'custom_metadata' in record.metadata and record.metadata['custom_metadata']: lines.append("\nCustom Metadata:") for key, value in record.metadata['custom_metadata'].items(): lines.append(f" {key}: {value}") - + # Content lines.append("\nContent:") lines.append("-" * 60) lines.append(record.text_content) - + # Additional info if record.raw_data: - lines.append(f"\nRaw Data: {len(record.raw_data)} bytes ({record.raw_data_type or 'unknown type'})") - + lines.append( + f"\nRaw Data: {len(record.raw_data)} bytes ({record.raw_data_type or 'unknown type'})" + ) + if record.vector is not None and len(record.vector) > 0: lines.append(f"\nVector: {len(record.vector)} dimensions") - + return "\n".join(lines) def format_as_markdown(record: FrameRecord) -> str: """Format record as Markdown.""" lines = [] - + # Title title = record.metadata.get('title', record.metadata.get('identifier', 'Document')) lines.append(f"# {title}") lines.append("") - + # Metadata section lines.append("## Metadata") lines.append("") lines.append(f"- **Identifier**: `{record.metadata.get('identifier')}`") lines.append(f"- **Type**: {record.metadata.get('record_type', 'document')}") - + if 'source' in record.metadata: source = record.metadata['source'] - lines.append(f"- **Source**: {source.get('type', 'unknown')} - `{source.get('path', 'N/A')}`") - + lines.append( + f"- **Source**: {source.get('type', 'unknown')} - `{source.get('path', 'N/A')}`" + ) + # Relationships relationships = record.metadata.get('relationships', []) if relationships: @@ -98,7 +105,7 @@ def format_as_markdown(record: FrameRecord) -> str: target_type = rel.get('target_type', 'unknown') target_id = rel.get('target_identifier', 'unknown') lines.append(f"- **{rel_type}** → {target_type}: `{target_id}`") - + # Custom metadata if 'custom_metadata' in record.metadata and record.metadata['custom_metadata']: lines.append("") @@ -106,12 +113,12 @@ def format_as_markdown(record: FrameRecord) -> str: lines.append("") for key, value in record.metadata['custom_metadata'].items(): lines.append(f"- **{key}**: {value}") - + # Content section lines.append("") lines.append("## Content") lines.append("") - + # If it looks like markdown, include it directly if record.text_content.strip().startswith('#') or '\n#' in record.text_content: lines.append(record.text_content) @@ -120,54 +127,62 @@ def format_as_markdown(record: FrameRecord) -> str: lines.append("```") lines.append(record.text_content) lines.append("```") - + # Additional info if record.raw_data or record.vector is not None: lines.append("") lines.append("## Additional Information") lines.append("") - + if record.raw_data: - lines.append(f"- **Raw Data**: {len(record.raw_data)} bytes ({record.raw_data_type or 'unknown type'})") - + lines.append( + f"- **Raw Data**: {len(record.raw_data)} bytes ({record.raw_data_type or 'unknown type'})" + ) + if record.vector is not None and len(record.vector) > 0: lines.append(f"- **Vector**: {len(record.vector)} dimensions") - + return "\n".join(lines) def main(): """Main entry point for get command.""" - parser = argparse.ArgumentParser(description='Get a specific document from a ContextFrame dataset') + parser = argparse.ArgumentParser( + description='Get a specific document from a ContextFrame dataset' + ) parser.add_argument('dataset', help='Path to the dataset') parser.add_argument('identifier', help='Document identifier') - parser.add_argument('--format', choices=['json', 'text', 'markdown'], default='text', - help='Output format') - + parser.add_argument( + '--format', + choices=['json', 'text', 'markdown'], + default='text', + help='Output format', + ) + args = parser.parse_args() - + # Open the dataset try: dataset = FrameDataset.open(args.dataset) except Exception as e: print(f"Error opening dataset: {e}", file=sys.stderr) sys.exit(1) - + # Find the document try: # Use filter to find by identifier filter_expr = f"identifier = '{args.identifier}'" results = dataset.search(filter=filter_expr, limit=1) - + if not results: print(f"Error: Document not found: {args.identifier}", file=sys.stderr) sys.exit(1) - + record = results[0] except Exception as e: print(f"Error retrieving document: {e}", file=sys.stderr) sys.exit(1) - + # Format and output if args.format == 'json': output = format_as_json(record) @@ -175,11 +190,11 @@ def main(): output = format_as_markdown(record) else: # text output = format_as_text(record) - + print(output) - + return 0 if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/contextframe/scripts/list_impl.py b/contextframe/scripts/list_impl.py index 6b8ce3a..a529c9e 100644 --- a/contextframe/scripts/list_impl.py +++ b/contextframe/scripts/list_impl.py @@ -4,49 +4,48 @@ import argparse import json import sys -from pathlib import Path -from datetime import datetime - from contextframe.frame import FrameDataset +from datetime import datetime +from pathlib import Path def format_as_table(records: list) -> str: """Format records as a table.""" if not records: return "No records found." - + # Define columns and their widths columns = [ ("ID", 36), # UUID width ("Type", 15), ("Title", 40), ("Size", 10), - ("Updated", 20) + ("Updated", 20), ] - + # Print header header = "" separator = "" for col_name, width in columns: header += f"{col_name:<{width}} " separator += "-" * width + " " - + lines = [header, separator] - + # Print rows for record in records: row = "" - + # ID identifier = record.metadata.get("identifier", "N/A") if len(identifier) > 36: identifier = identifier[:33] + "..." row += f"{identifier:<36} " - + # Type record_type = record.metadata.get("record_type", "document") row += f"{record_type:<15} " - + # Title title = record.metadata.get("title", "") if not title and "source" in record.metadata: @@ -61,7 +60,7 @@ def format_as_table(records: list) -> str: else: title = title[:37] + "..." if len(title) > 40 else title row += f"{title:<40} " - + # Size size = len(record.text_content) if size < 1024: @@ -71,9 +70,11 @@ def format_as_table(records: list) -> str: else: size_str = f"{size / (1024 * 1024):.1f} MB" row += f"{size_str:<10} " - + # Updated (if available in metadata) - updated = record.metadata.get("updated_at", record.metadata.get("created_at", "")) + updated = record.metadata.get( + "updated_at", record.metadata.get("created_at", "") + ) if updated: try: # Try to parse and format the date @@ -87,13 +88,13 @@ def format_as_table(records: list) -> str: else: updated = "N/A" row += f"{updated:<20}" - + lines.append(row) - + # Add summary lines.append(separator) lines.append(f"Total: {len(records)} records") - + return "\n".join(lines) @@ -101,21 +102,22 @@ def format_as_json(records: list) -> str: """Format records as JSON.""" data = [] for record in records: - data.append({ - "identifier": record.metadata.get("identifier"), - "record_type": record.metadata.get("record_type", "document"), - "title": record.metadata.get("title"), - "content_preview": record.text_content[:200] + "..." if len(record.text_content) > 200 else record.text_content, - "metadata": record.metadata, - "content_size": len(record.text_content), - "has_vector": record.vector is not None and len(record.vector) > 0, - "has_raw_data": record.raw_data is not None - }) - - return json.dumps({ - "records": data, - "count": len(data) - }, indent=2) + data.append( + { + "identifier": record.metadata.get("identifier"), + "record_type": record.metadata.get("record_type", "document"), + "title": record.metadata.get("title"), + "content_preview": record.text_content[:200] + "..." + if len(record.text_content) > 200 + else record.text_content, + "metadata": record.metadata, + "content_size": len(record.text_content), + "has_vector": record.vector is not None and len(record.vector) > 0, + "has_raw_data": record.raw_data is not None, + } + ) + + return json.dumps({"records": data, "count": len(data)}, indent=2) def format_as_ids(records: list) -> str: @@ -126,29 +128,39 @@ def format_as_ids(records: list) -> str: def main(): """Main entry point for list command.""" - parser = argparse.ArgumentParser(description='List documents in a ContextFrame dataset') + parser = argparse.ArgumentParser( + description='List documents in a ContextFrame dataset' + ) parser.add_argument('dataset', help='Path to the dataset') - parser.add_argument('--limit', type=int, default=50, help='Number of records to return') - parser.add_argument('--filter', dest='filter_expr', help='Lance SQL filter expression') - parser.add_argument('--format', choices=['table', 'json', 'ids'], default='table', - help='Output format') - + parser.add_argument( + '--limit', type=int, default=50, help='Number of records to return' + ) + parser.add_argument( + '--filter', dest='filter_expr', help='Lance SQL filter expression' + ) + parser.add_argument( + '--format', + choices=['table', 'json', 'ids'], + default='table', + help='Output format', + ) + args = parser.parse_args() - + # Open the dataset try: dataset = FrameDataset.open(args.dataset) except Exception as e: print(f"Error opening dataset: {e}", file=sys.stderr) sys.exit(1) - + # Get records try: records = dataset.search(filter=args.filter_expr, limit=args.limit) except Exception as e: print(f"Error listing records: {e}", file=sys.stderr) sys.exit(1) - + # Format and output if args.format == 'json': output = format_as_json(records) @@ -156,11 +168,11 @@ def main(): output = format_as_ids(records) else: # table output = format_as_table(records) - + print(output) - + return 0 if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/contextframe/scripts/search_impl.py b/contextframe/scripts/search_impl.py index 6652032..80af60c 100644 --- a/contextframe/scripts/search_impl.py +++ b/contextframe/scripts/search_impl.py @@ -2,141 +2,152 @@ """Implementation of contextframe-search command.""" import argparse +import numpy as np import sys +from contextframe.embed import LiteLLMProvider +from contextframe.frame import FrameDataset from pathlib import Path from typing import Optional -import numpy as np - -from contextframe.frame import FrameDataset -from contextframe.embed import LiteLLMProvider - -def search_hybrid(dataset: FrameDataset, query: str, limit: int, filter_expr: Optional[str] = None) -> list: +def search_hybrid( + dataset: FrameDataset, query: str, limit: int, filter_expr: str | None = None +) -> list: """Perform hybrid search using both vector and text search.""" # Try vector search first if embeddings are available try: # Get embedding configuration from environment or use defaults import os + model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") - + # Create embedding provider provider = LiteLLMProvider(model) - + # Generate query embedding result = provider.embed(query) query_vector = np.array(result.embeddings[0], dtype=np.float32) - + # Perform vector search vector_results = dataset.knn_search( - query_vector=query_vector, - k=limit, - filter=filter_expr + query_vector=query_vector, k=limit, filter=filter_expr ) - + # If we got results, return them if vector_results: return vector_results - + except Exception as e: # Fall back to text search if vector search fails print(f"Vector search unavailable: {e}", file=sys.stderr) - + # Fall back to text search - return dataset.full_text_search( - query=query, - limit=limit, - filter=filter_expr - ) + return dataset.full_text_search(query=query, limit=limit, filter=filter_expr) -def search_vector(dataset: FrameDataset, query: str, limit: int, filter_expr: Optional[str] = None) -> list: +def search_vector( + dataset: FrameDataset, query: str, limit: int, filter_expr: str | None = None +) -> list: """Perform vector search using embeddings.""" # Get embedding configuration from environment or use defaults import os + model = os.environ.get("CONTEXTFRAME_EMBED_MODEL", "text-embedding-ada-002") api_key = os.environ.get("OPENAI_API_KEY") # or other provider keys - + # Create embedding provider provider = LiteLLMProvider(model, api_key=api_key) - + # Generate query embedding try: result = provider.embed(query) query_vector = np.array(result.embeddings[0], dtype=np.float32) except Exception as e: print(f"Error generating embedding: {e}", file=sys.stderr) - print("Make sure you have set up API credentials for your embedding provider.", file=sys.stderr) + print( + "Make sure you have set up API credentials for your embedding provider.", + file=sys.stderr, + ) print("For OpenAI: export OPENAI_API_KEY='your-key'", file=sys.stderr) - print("For other providers, see: https://docs.litellm.ai/docs/providers", file=sys.stderr) + print( + "For other providers, see: https://docs.litellm.ai/docs/providers", + file=sys.stderr, + ) sys.exit(1) - + # Perform vector search - return dataset.knn_search( - query_vector=query_vector, - k=limit, - filter=filter_expr - ) + return dataset.knn_search(query_vector=query_vector, k=limit, filter=filter_expr) -def search_text(dataset: FrameDataset, query: str, limit: int, filter_expr: Optional[str] = None) -> list: +def search_text( + dataset: FrameDataset, query: str, limit: int, filter_expr: str | None = None +) -> list: """Perform text search.""" - return dataset.full_text_search( - query=query, - limit=limit, - filter=filter_expr - ) + return dataset.full_text_search(query=query, limit=limit, filter=filter_expr) def format_result(record, index: int): """Format a single search result for display.""" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Result {index + 1}:") print(f"ID: {record.metadata.get('identifier', 'N/A')}") print(f"Type: {record.metadata.get('record_type', 'document')}") - + # Show title if available if 'title' in record.metadata: print(f"Title: {record.metadata['title']}") - + # Show collection if part of one relationships = record.metadata.get('relationships', []) for rel in relationships: - if rel.get('relationship_type') == 'member_of' and rel.get('target_type') == 'collection': + if ( + rel.get('relationship_type') == 'member_of' + and rel.get('target_type') == 'collection' + ): print(f"Collection: {rel.get('target_identifier', 'Unknown')}") - + # Show snippet of content content = record.text_content if len(content) > 200: content = content[:200] + "..." print(f"\nContent:\n{content}") - + # Show custom metadata if present if 'custom_metadata' in record.metadata and record.metadata['custom_metadata']: - print(f"\nCustom Metadata:") + print("\nCustom Metadata:") for key, value in record.metadata['custom_metadata'].items(): print(f" {key}: {value}") def main(): """Main entry point for search command.""" - parser = argparse.ArgumentParser(description='Search documents in a ContextFrame dataset') + parser = argparse.ArgumentParser( + description='Search documents in a ContextFrame dataset' + ) parser.add_argument('dataset', help='Path to the dataset') parser.add_argument('query', help='Search query') - parser.add_argument('--limit', type=int, default=10, help='Number of results to return') - parser.add_argument('--type', choices=['vector', 'text', 'hybrid'], default='hybrid', - help='Search type') - parser.add_argument('--filter', dest='filter_expr', help='Lance SQL filter expression') - + parser.add_argument( + '--limit', type=int, default=10, help='Number of results to return' + ) + parser.add_argument( + '--type', + choices=['vector', 'text', 'hybrid'], + default='hybrid', + help='Search type', + ) + parser.add_argument( + '--filter', dest='filter_expr', help='Lance SQL filter expression' + ) + args = parser.parse_args() - + # Open the dataset try: dataset = FrameDataset.open(args.dataset) except Exception as e: print(f"Error opening dataset: {e}", file=sys.stderr) sys.exit(1) - + # Perform search based on type try: if args.type == 'vector': @@ -148,7 +159,7 @@ def main(): except Exception as e: print(f"Error performing search: {e}", file=sys.stderr) sys.exit(1) - + # Display results if not results: print("No results found.") @@ -156,9 +167,9 @@ def main(): print(f"Found {len(results)} results:") for i, record in enumerate(results): format_result(record, i) - + return 0 if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/contextframe/templates/__init__.py b/contextframe/templates/__init__.py index 1d8f758..e2b4e15 100644 --- a/contextframe/templates/__init__.py +++ b/contextframe/templates/__init__.py @@ -12,10 +12,10 @@ Example: >>> from contextframe import FrameDataset >>> from contextframe.templates import SoftwareProjectTemplate - >>> + >>> >>> template = SoftwareProjectTemplate() >>> dataset = FrameDataset.create("my-project.lance") - >>> + >>> >>> # Apply template to import project >>> results = template.apply( ... source_path="~/my-project", @@ -32,11 +32,11 @@ __all__ = [ "ContextTemplate", - "TemplateResult", + "TemplateResult", "TemplateRegistry", "get_template", "list_templates", "SoftwareProjectTemplate", "ResearchTemplate", "BusinessTemplate", -] \ No newline at end of file +] diff --git a/contextframe/templates/base.py b/contextframe/templates/base.py index 30bf175..18846ad 100644 --- a/contextframe/templates/base.py +++ b/contextframe/templates/base.py @@ -13,7 +13,7 @@ @dataclass class TemplateResult: """Result of applying a template to a dataset.""" - + frames_created: int = 0 collections_created: int = 0 relationships_created: int = 0 @@ -23,10 +23,10 @@ class TemplateResult: collection_ids: list[str] = field(default_factory=list) -@dataclass +@dataclass class FileMapping: """Maps a file to frame metadata and configuration.""" - + path: Path title: str record_type: str = RecordType.DOCUMENT @@ -40,7 +40,7 @@ class FileMapping: @dataclass class CollectionDefinition: """Defines a collection structure.""" - + name: str title: str description: str @@ -52,7 +52,7 @@ class CollectionDefinition: @dataclass class EnrichmentSuggestion: """Suggested enrichment for a document type.""" - + file_pattern: str # glob pattern enhancement_config: dict[str, Any] = field(default_factory=dict) priority: int = 0 # higher = more important @@ -60,87 +60,89 @@ class EnrichmentSuggestion: class ContextTemplate(abc.ABC): """Abstract base class for Context Templates. - + Templates provide pre-configured patterns for importing and structuring documents into ContextFrame datasets. Subclasses implement domain-specific logic for categorizing, organizing, and enriching documents. """ - + def __init__(self, name: str, description: str): """Initialize the template. - + Args: name: Template name (e.g., "software_project") description: Human-readable description """ self.name = name self.description = description - + @abc.abstractmethod def scan(self, source_path: str | Path) -> list[FileMapping]: """Scan source directory and map files to frames. - + This method analyzes the directory structure and files to determine how they should be imported into ContextFrame. - + Args: source_path: Path to scan - + Returns: List of file mappings """ ... - + @abc.abstractmethod - def define_collections(self, file_mappings: list[FileMapping]) -> list[CollectionDefinition]: + def define_collections( + self, file_mappings: list[FileMapping] + ) -> list[CollectionDefinition]: """Define collection structure based on discovered files. - + Args: file_mappings: Files discovered during scan - + Returns: List of collections to create """ ... - + @abc.abstractmethod def discover_relationships( - self, - file_mappings: list[FileMapping], - dataset: FrameDataset + self, file_mappings: list[FileMapping], dataset: FrameDataset ) -> list[dict[str, Any]]: """Discover relationships between documents. - + Args: file_mappings: Files being imported dataset: Dataset being populated - + Returns: List of relationship dictionaries """ ... - + @abc.abstractmethod - def suggest_enrichments(self, file_mappings: list[FileMapping]) -> list[EnrichmentSuggestion]: + def suggest_enrichments( + self, file_mappings: list[FileMapping] + ) -> list[EnrichmentSuggestion]: """Suggest enrichments for imported documents. - + Args: file_mappings: Files being imported - + Returns: List of enrichment suggestions """ ... - + def validate_source(self, source_path: str | Path) -> Path: """Validate and normalize source path. - + Args: source_path: Path to validate - + Returns: Normalized Path object - + Raises: ValueError: If path is invalid """ @@ -150,7 +152,7 @@ def validate_source(self, source_path: str | Path) -> Path: if not path.is_dir(): raise ValueError(f"Source path must be a directory: {path}") return path - + def apply( self, source_path: str | Path, @@ -158,48 +160,50 @@ def apply( *, auto_enhance: bool = False, dry_run: bool = False, - progress_callback: callable | None = None + progress_callback: callable | None = None, ) -> TemplateResult: """Apply template to import documents into dataset. - + This is the main entry point that orchestrates the entire import process: 1. Scans source directory 2. Creates collections 3. Imports documents as frames 4. Establishes relationships 5. Optionally runs enrichments - + Args: source_path: Directory to import from dataset: Target FrameDataset auto_enhance: Whether to run suggested enrichments dry_run: If True, only simulate the import progress_callback: Optional callback for progress updates - + Returns: TemplateResult with import statistics """ result = TemplateResult() - + try: # Validate source source_path = self.validate_source(source_path) - + # Phase 1: Scan files if progress_callback: progress_callback("Scanning files...") file_mappings = self.scan(source_path) - + if dry_run: - result.warnings.append(f"DRY RUN: Would import {len(file_mappings)} files") + result.warnings.append( + f"DRY RUN: Would import {len(file_mappings)} files" + ) return result - + # Phase 2: Create collections if progress_callback: progress_callback("Creating collections...") collections = self.define_collections(file_mappings) collection_map = {} - + for coll_def in collections: try: coll_record = self._create_collection(coll_def, dataset) @@ -208,29 +212,32 @@ def apply( result.collections_created += 1 result.collection_ids.append(coll_record.uuid) except Exception as e: - result.errors.append(f"Failed to create collection {coll_def.name}: {e}") - + result.errors.append( + f"Failed to create collection {coll_def.name}: {e}" + ) + # Phase 3: Import documents if progress_callback: progress_callback("Importing documents...") - + from ..extract import extract_from_file - + frame_map = {} for mapping in file_mappings: if mapping.skip: continue - + try: # Extract content extraction = extract_from_file( - str(mapping.path), - **mapping.extract_config + str(mapping.path), **mapping.extract_config ) - + if extraction.error: - result.warnings.append(f"Extraction warning for {mapping.path}: {extraction.error}") - + result.warnings.append( + f"Extraction warning for {mapping.path}: {extraction.error}" + ) + # Create frame metadata = { "title": mapping.title, @@ -239,36 +246,37 @@ def apply( "tags": mapping.tags, "custom_metadata": mapping.custom_metadata, } - + if mapping.collection and mapping.collection in collection_map: metadata["collection"] = mapping.collection metadata["collection_id"] = collection_map[mapping.collection] metadata["collection_id_type"] = "uuid" - + # Add extraction metadata if extraction.metadata: - metadata["custom_metadata"].update({ - f"extract_{k}": str(v) - for k, v in extraction.metadata.items() - }) - + metadata["custom_metadata"].update( + { + f"extract_{k}": str(v) + for k, v in extraction.metadata.items() + } + ) + frame = FrameRecord( - text_content=extraction.content, - metadata=metadata + text_content=extraction.content, metadata=metadata ) - + dataset.add(frame) frame_map[str(mapping.path)] = frame.uuid result.frames_created += 1 result.frame_ids.append(frame.uuid) - + except Exception as e: result.errors.append(f"Failed to import {mapping.path}: {e}") - + # Phase 4: Discover relationships if progress_callback: progress_callback("Discovering relationships...") - + relationships = self.discover_relationships(file_mappings, dataset) for rel in relationships: try: @@ -277,22 +285,24 @@ def apply( result.relationships_created += 1 except Exception as e: result.warnings.append(f"Failed to create relationship: {e}") - + # Phase 5: Run enrichments if requested if auto_enhance: if progress_callback: progress_callback("Running enrichments...") - + suggestions = self.suggest_enrichments(file_mappings) # In practice, would integrate with enhancement module here result.warnings.append("Auto-enhancement not yet implemented") - + except Exception as e: result.errors.append(f"Template application failed: {e}") - + return result - - def _create_collection(self, definition: CollectionDefinition, dataset: FrameDataset) -> FrameRecord: + + def _create_collection( + self, definition: CollectionDefinition, dataset: FrameDataset + ) -> FrameRecord: """Create a collection header frame.""" metadata = { "title": definition.title, @@ -301,13 +311,10 @@ def _create_collection(self, definition: CollectionDefinition, dataset: FrameDat "position": definition.position, "custom_metadata": { "collection_name": definition.name, - } + }, } - + if definition.parent: metadata["custom_metadata"]["parent_collection"] = definition.parent - - return FrameRecord( - text_content=definition.description, - metadata=metadata - ) \ No newline at end of file + + return FrameRecord(text_content=definition.description, metadata=metadata) diff --git a/contextframe/templates/business.py b/contextframe/templates/business.py index e1485cf..86c7bc8 100644 --- a/contextframe/templates/business.py +++ b/contextframe/templates/business.py @@ -15,57 +15,63 @@ class BusinessTemplate(ContextTemplate): """Template for business documents and organizational content. - + Handles: - Meeting notes and minutes - Decision documents and proposals - Reports and analyses - Project documentation - Stakeholder communications - + Automatically: - Groups by project/initiative - Links decisions to related documents - Tracks document ownership - Suggests business-focused enrichments """ - + # Document patterns - MEETING_PATTERNS = ["*meeting*", "*minutes*", "*standup*", "*retro*", "*retrospective*"] + MEETING_PATTERNS = [ + "*meeting*", + "*minutes*", + "*standup*", + "*retro*", + "*retrospective*", + ] DECISION_PATTERNS = ["*decision*", "*proposal*", "*rfc*", "*adr*", "*design*"] REPORT_PATTERNS = ["*report*", "*analysis*", "*summary*", "*review*"] PROJECT_PATTERNS = ["*project*", "*plan*", "*roadmap*", "*strategy*"] - + # Common business directories MEETINGS_DIRS = {"meetings", "notes", "minutes"} DECISIONS_DIRS = {"decisions", "proposals", "rfcs", "adrs"} REPORTS_DIRS = {"reports", "analyses", "reviews"} PROJECTS_DIRS = {"projects", "initiatives", "programs"} - + # File extensions DOC_EXTENSIONS = {".md", ".docx", ".doc", ".pdf", ".txt", ".rtf"} SPREADSHEET_EXTENSIONS = {".xlsx", ".xls", ".csv", ".ods"} PRESENTATION_EXTENSIONS = {".pptx", ".ppt", ".key", ".odp"} - + def __init__(self): """Initialize the business template.""" super().__init__( name="business", - description="Template for business documents, meeting notes, and organizational content" + description="Template for business documents, meeting notes, and organizational content", ) - + def scan(self, source_path: str | Path) -> list[FileMapping]: """Scan business directory and map documents.""" source_path = self.validate_source(source_path) mappings = [] seen_paths = set() - + # Scan structured directories self._scan_meetings_directory(source_path, mappings, seen_paths) self._scan_decisions_directory(source_path, mappings, seen_paths) self._scan_reports_directory(source_path, mappings, seen_paths) self._scan_projects_directory(source_path, mappings, seen_paths) - + # Scan root for additional business documents for file_path in source_path.iterdir(): if file_path.is_file() and file_path not in seen_paths: @@ -77,10 +83,12 @@ def scan(self, source_path: str | Path) -> list[FileMapping]: elif file_path.suffix in self.SPREADSHEET_EXTENSIONS: mappings.append(self._create_spreadsheet_mapping(file_path)) seen_paths.add(file_path) - + return mappings - - def _scan_meetings_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_meetings_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for meeting documents.""" for dir_name in self.MEETINGS_DIRS: meetings_dir = base_path / dir_name @@ -89,14 +97,18 @@ def _scan_meetings_directory(self, base_path: Path, mappings: list[FileMapping], if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.DOC_EXTENSIONS: mapping = self._create_meeting_mapping(file_path) - + # Try to extract date from path or filename - date_match = re.search(r'(\d{4})[-_/](\d{1,2})[-_/](\d{1,2})', str(file_path)) + date_match = re.search( + r'(\d{4})[-_/](\d{1,2})[-_/](\d{1,2})', str(file_path) + ) if date_match: year, month, day = date_match.groups() - mapping.custom_metadata["meeting_date"] = f"{year}-{month:0>2}-{day:0>2}" + mapping.custom_metadata["meeting_date"] = ( + f"{year}-{month:0>2}-{day:0>2}" + ) mapping.tags.append(f"year:{year}") - + # Group by subdirectory (e.g., team meetings) rel_path = file_path.relative_to(meetings_dir) if len(rel_path.parts) > 1: @@ -105,11 +117,13 @@ def _scan_meetings_directory(self, base_path: Path, mappings: list[FileMapping], mapping.custom_metadata["team"] = team else: mapping.collection = "meetings" - + mappings.append(mapping) seen_paths.add(file_path) - - def _scan_decisions_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_decisions_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for decision documents.""" for dir_name in self.DECISIONS_DIRS: decisions_dir = base_path / dir_name @@ -118,17 +132,21 @@ def _scan_decisions_directory(self, base_path: Path, mappings: list[FileMapping] if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.DOC_EXTENSIONS: mapping = self._create_decision_mapping(file_path) - + # Extract decision number if present (e.g., ADR-001) num_match = re.search(r'(\d{3,4})', file_path.stem) if num_match: - mapping.custom_metadata["decision_number"] = num_match.group(1) - + mapping.custom_metadata["decision_number"] = ( + num_match.group(1) + ) + mapping.collection = "decisions" mappings.append(mapping) seen_paths.add(file_path) - - def _scan_reports_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_reports_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for reports and analyses.""" for dir_name in self.REPORTS_DIRS: reports_dir = base_path / dir_name @@ -137,16 +155,16 @@ def _scan_reports_directory(self, base_path: Path, mappings: list[FileMapping], if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.DOC_EXTENSIONS: mapping = self._create_report_mapping(file_path) - + # Categorize by report type from subdirectory - rel_path = file_path.relative_to(reports_dir) + rel_path = file_path.relative_to(reports_dir) if len(rel_path.parts) > 1: report_type = rel_path.parts[0] mapping.collection = f"reports/{report_type}" mapping.custom_metadata["report_type"] = report_type else: mapping.collection = "reports" - + mappings.append(mapping) seen_paths.add(file_path) elif file_path.suffix in self.SPREADSHEET_EXTENSIONS: @@ -155,8 +173,10 @@ def _scan_reports_directory(self, base_path: Path, mappings: list[FileMapping], mapping.tags.append("data-analysis") mappings.append(mapping) seen_paths.add(file_path) - - def _scan_projects_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_projects_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for project documentation.""" for dir_name in self.PROJECTS_DIRS: projects_dir = base_path / dir_name @@ -165,48 +185,58 @@ def _scan_projects_directory(self, base_path: Path, mappings: list[FileMapping], for project_dir in projects_dir.iterdir(): if project_dir.is_dir(): project_name = project_dir.name - + for file_path in project_dir.rglob("*"): if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.DOC_EXTENSIONS: - mapping = self._categorize_project_document(file_path, project_name) + mapping = self._categorize_project_document( + file_path, project_name + ) mapping.collection = f"projects/{project_name}" mappings.append(mapping) seen_paths.add(file_path) - + def _categorize_document(self, file_path: Path) -> FileMapping: """Categorize a business document by its name.""" name_lower = file_path.stem.lower() - + # Check meeting patterns - if any(pattern.replace("*", "") in name_lower for pattern in self.MEETING_PATTERNS): + if any( + pattern.replace("*", "") in name_lower for pattern in self.MEETING_PATTERNS + ): return self._create_meeting_mapping(file_path) - + # Check decision patterns - if any(pattern.replace("*", "") in name_lower for pattern in self.DECISION_PATTERNS): + if any( + pattern.replace("*", "") in name_lower for pattern in self.DECISION_PATTERNS + ): return self._create_decision_mapping(file_path) - + # Check report patterns - if any(pattern.replace("*", "") in name_lower for pattern in self.REPORT_PATTERNS): + if any( + pattern.replace("*", "") in name_lower for pattern in self.REPORT_PATTERNS + ): return self._create_report_mapping(file_path) - + # Check project patterns - if any(pattern.replace("*", "") in name_lower for pattern in self.PROJECT_PATTERNS): + if any( + pattern.replace("*", "") in name_lower for pattern in self.PROJECT_PATTERNS + ): return self._create_project_mapping(file_path, "general") - + # Default business document return FileMapping( path=file_path, title=f"Document - {file_path.stem}", tags=["business", "document", file_path.suffix[1:]], - custom_metadata={"document_type": "general"} + custom_metadata={"document_type": "general"}, ) - + def _create_meeting_mapping(self, file_path: Path) -> FileMapping: """Create mapping for a meeting document.""" meeting_type = "general" name_lower = file_path.stem.lower() - + if "standup" in name_lower or "daily" in name_lower: meeting_type = "standup" elif "retro" in name_lower or "retrospective" in name_lower: @@ -217,22 +247,19 @@ def _create_meeting_mapping(self, file_path: Path) -> FileMapping: meeting_type = "review" elif "1on1" in name_lower or "1-on-1" in name_lower: meeting_type = "one-on-one" - + return FileMapping( path=file_path, title=f"Meeting - {file_path.stem}", tags=["meeting", meeting_type, "discussion"], - custom_metadata={ - "meeting_type": meeting_type, - "document_type": "meeting" - } + custom_metadata={"meeting_type": meeting_type, "document_type": "meeting"}, ) - + def _create_decision_mapping(self, file_path: Path) -> FileMapping: """Create mapping for a decision document.""" decision_type = "general" name_lower = file_path.stem.lower() - + if "adr" in name_lower: decision_type = "adr" elif "rfc" in name_lower: @@ -241,7 +268,7 @@ def _create_decision_mapping(self, file_path: Path) -> FileMapping: decision_type = "proposal" elif "design" in name_lower: decision_type = "design" - + return FileMapping( path=file_path, title=f"Decision - {file_path.stem}", @@ -249,15 +276,15 @@ def _create_decision_mapping(self, file_path: Path) -> FileMapping: custom_metadata={ "decision_type": decision_type, "document_type": "decision", - "status": "draft" # Would be updated during enrichment - } + "status": "draft", # Would be updated during enrichment + }, ) - + def _create_report_mapping(self, file_path: Path) -> FileMapping: """Create mapping for a report.""" report_type = "general" name_lower = file_path.stem.lower() - + if "quarterly" in name_lower or "q1" in name_lower or "q2" in name_lower: report_type = "quarterly" elif "annual" in name_lower or "yearly" in name_lower: @@ -268,41 +295,39 @@ def _create_report_mapping(self, file_path: Path) -> FileMapping: report_type = "analysis" elif "summary" in name_lower: report_type = "summary" - + return FileMapping( path=file_path, title=f"Report - {file_path.stem}", tags=["report", report_type, "analysis"], - custom_metadata={ - "report_type": report_type, - "document_type": "report" - } + custom_metadata={"report_type": report_type, "document_type": "report"}, ) - - def _create_project_mapping(self, file_path: Path, project_name: str) -> FileMapping: + + def _create_project_mapping( + self, file_path: Path, project_name: str + ) -> FileMapping: """Create mapping for a project document.""" return FileMapping( path=file_path, title=f"Project - {file_path.stem}", tags=["project", project_name.lower(), "planning"], - custom_metadata={ - "project_name": project_name, - "document_type": "project" - } + custom_metadata={"project_name": project_name, "document_type": "project"}, ) - - def _categorize_project_document(self, file_path: Path, project_name: str) -> FileMapping: + + def _categorize_project_document( + self, file_path: Path, project_name: str + ) -> FileMapping: """Categorize a document within a project.""" mapping = self._categorize_document(file_path) mapping.tags.append(project_name.lower()) mapping.custom_metadata["project_name"] = project_name return mapping - + def _create_spreadsheet_mapping(self, file_path: Path) -> FileMapping: """Create mapping for spreadsheet files.""" name_lower = file_path.stem.lower() sheet_type = "data" - + if "budget" in name_lower: sheet_type = "budget" elif "forecast" in name_lower: @@ -311,43 +336,69 @@ def _create_spreadsheet_mapping(self, file_path: Path) -> FileMapping: sheet_type = "metrics" elif "tracker" in name_lower or "tracking" in name_lower: sheet_type = "tracker" - + return FileMapping( path=file_path, title=f"Spreadsheet - {file_path.stem}", tags=["spreadsheet", sheet_type, "data"], custom_metadata={ "spreadsheet_type": sheet_type, - "format": file_path.suffix[1:] - } + "format": file_path.suffix[1:], + }, ) - - def define_collections(self, file_mappings: list[FileMapping]) -> list[CollectionDefinition]: + + def define_collections( + self, file_mappings: list[FileMapping] + ) -> list[CollectionDefinition]: """Define collections for business documents.""" collections = [] seen_collections = set() - + # Core business collections collection_defs = [ - ("meetings", "Meeting Notes", "Team meetings and discussions", ["meetings", "collaboration"]), - ("decisions", "Decisions & Proposals", "Strategic decisions and proposals", ["decisions", "strategy"]), - ("reports", "Reports & Analyses", "Business reports and data analysis", ["reports", "insights"]), - ("projects", "Project Documentation", "Project plans and documentation", ["projects", "execution"]) + ( + "meetings", + "Meeting Notes", + "Team meetings and discussions", + ["meetings", "collaboration"], + ), + ( + "decisions", + "Decisions & Proposals", + "Strategic decisions and proposals", + ["decisions", "strategy"], + ), + ( + "reports", + "Reports & Analyses", + "Business reports and data analysis", + ["reports", "insights"], + ), + ( + "projects", + "Project Documentation", + "Project plans and documentation", + ["projects", "execution"], + ), ] - + position = 0 for name, title, desc, tags in collection_defs: - if any(m.collection and m.collection.startswith(name) for m in file_mappings): - collections.append(CollectionDefinition( - name=name, - title=title, - description=desc, - tags=tags, - position=position - )) + if any( + m.collection and m.collection.startswith(name) for m in file_mappings + ): + collections.append( + CollectionDefinition( + name=name, + title=title, + description=desc, + tags=tags, + position=position, + ) + ) seen_collections.add(name) position += 10 - + # Add sub-collections sub_collections = set() for mapping in file_mappings: @@ -357,29 +408,29 @@ def define_collections(self, file_mappings: list[FileMapping]) -> list[Collectio parent, child = parts if parent in seen_collections: sub_collections.add((parent, child)) - + for parent, child in sorted(sub_collections): coll_name = f"{parent}/{child}" - collections.append(CollectionDefinition( - name=coll_name, - title=f"{child.replace('-', ' ').title()}", - description=f"Documents for {child}", - tags=[parent, child.lower()], - parent=parent, - position=position - )) + collections.append( + CollectionDefinition( + name=coll_name, + title=f"{child.replace('-', ' ').title()}", + description=f"Documents for {child}", + tags=[parent, child.lower()], + parent=parent, + position=position, + ) + ) position += 1 - + return collections - + def discover_relationships( - self, - file_mappings: list[FileMapping], - dataset: FrameDataset + self, file_mappings: list[FileMapping], dataset: FrameDataset ) -> list[dict[str, Any]]: """Discover relationships between business documents.""" relationships = [] - + # Group documents by project project_docs = {} for mapping in file_mappings: @@ -388,7 +439,7 @@ def discover_relationships( if project not in project_docs: project_docs[project] = [] project_docs[project].append(mapping) - + # Link project documents for project, docs in project_docs.items(): # Find the main project doc @@ -397,103 +448,123 @@ def discover_relationships( if "plan" in doc.title.lower() or "overview" in doc.title.lower(): main_doc = doc break - + if main_doc: for doc in docs: if doc != main_doc: - relationships.append({ - "source": str(main_doc.path), - "target": str(doc.path), - "type": "contains", - "description": f"Project document for {project}" - }) - + relationships.append( + { + "source": str(main_doc.path), + "target": str(doc.path), + "type": "contains", + "description": f"Project document for {project}", + } + ) + # Link decisions to related documents decisions = [m for m in file_mappings if "decision" in m.tags] meetings = [m for m in file_mappings if "meeting" in m.tags] - + # Simple date-based matching for decisions and meetings for decision in decisions: decision_date = decision.custom_metadata.get("decision_date") if decision_date: for meeting in meetings: meeting_date = meeting.custom_metadata.get("meeting_date") - if meeting_date and abs(( - datetime.fromisoformat(decision_date) - - datetime.fromisoformat(meeting_date) - ).days) <= 7: - relationships.append({ - "source": str(decision.path), - "target": str(meeting.path), - "type": "discussed_in", - "description": "Decision discussed in meeting" - }) - + if ( + meeting_date + and abs( + ( + datetime.fromisoformat(decision_date) + - datetime.fromisoformat(meeting_date) + ).days + ) + <= 7 + ): + relationships.append( + { + "source": str(decision.path), + "target": str(meeting.path), + "type": "discussed_in", + "description": "Decision discussed in meeting", + } + ) + return relationships - - def suggest_enrichments(self, file_mappings: list[FileMapping]) -> list[EnrichmentSuggestion]: + + def suggest_enrichments( + self, file_mappings: list[FileMapping] + ) -> list[EnrichmentSuggestion]: """Suggest business-specific enrichments.""" suggestions = [] - + # Meeting note enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/meeting*.md", - enhancement_config={ - "enhancements": { - "context": "business_context", - "custom_metadata": "meeting_metadata" - } - }, - priority=10 - )) - + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/meeting*.md", + enhancement_config={ + "enhancements": { + "context": "business_context", + "custom_metadata": "meeting_metadata", + } + }, + priority=10, + ) + ) + # Decision document enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/decision*.md", - enhancement_config={ - "enhancements": { - "context": "Summarize the decision, its rationale, and impact", - "tags": "Extract: stakeholders, decision-type, impact-area", - "custom_metadata": { - "decision_status": "Status: proposed, approved, or rejected", - "stakeholders": "List key stakeholders", - "impact_assessment": "Business impact assessment" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/decision*.md", + enhancement_config={ + "enhancements": { + "context": "Summarize the decision, its rationale, and impact", + "tags": "Extract: stakeholders, decision-type, impact-area", + "custom_metadata": { + "decision_status": "Status: proposed, approved, or rejected", + "stakeholders": "List key stakeholders", + "impact_assessment": "Business impact assessment", + }, } - } - }, - priority=9 - )) - + }, + priority=9, + ) + ) + # Report enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/report*.md", - enhancement_config={ - "enhancements": { - "context": "Summarize key findings and recommendations", - "custom_metadata": { - "report_period": "Time period covered", - "key_metrics": "Main metrics or KPIs discussed", - "recommendations": "Key recommendations" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/report*.md", + enhancement_config={ + "enhancements": { + "context": "Summarize key findings and recommendations", + "custom_metadata": { + "report_period": "Time period covered", + "key_metrics": "Main metrics or KPIs discussed", + "recommendations": "Key recommendations", + }, } - } - }, - priority=7 - )) - + }, + priority=7, + ) + ) + # Spreadsheet enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.xlsx", - enhancement_config={ - "enhancements": { - "context": "Describe the data and its business purpose", - "custom_metadata": { - "data_description": "What data is tracked", - "update_frequency": "How often is this updated", - "business_use": "How is this data used" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.xlsx", + enhancement_config={ + "enhancements": { + "context": "Describe the data and its business purpose", + "custom_metadata": { + "data_description": "What data is tracked", + "update_frequency": "How often is this updated", + "business_use": "How is this data used", + }, } - } - }, - priority=5 - )) - - return suggestions \ No newline at end of file + }, + priority=5, + ) + ) + + return suggestions diff --git a/contextframe/templates/examples/__init__.py b/contextframe/templates/examples/__init__.py index 1a852a6..8edbae5 100644 --- a/contextframe/templates/examples/__init__.py +++ b/contextframe/templates/examples/__init__.py @@ -1 +1 @@ -"""Example Context Templates for custom use cases.""" \ No newline at end of file +"""Example Context Templates for custom use cases.""" diff --git a/contextframe/templates/registry.py b/contextframe/templates/registry.py index 4c0c586..e2cc6a3 100644 --- a/contextframe/templates/registry.py +++ b/contextframe/templates/registry.py @@ -9,19 +9,19 @@ class TemplateRegistry: """Registry for Context Templates. - + Manages built-in and custom templates, providing discovery and instantiation capabilities. """ - + def __init__(self): """Initialize the template registry.""" self._templates: dict[str, type[ContextTemplate]] = {} self._instances: dict[str, ContextTemplate] = {} - + # Register built-in templates self._register_builtin_templates() - + def _register_builtin_templates(self): """Register all built-in templates.""" builtin_templates = [ @@ -29,58 +29,58 @@ def _register_builtin_templates(self): ResearchTemplate, BusinessTemplate, ] - + for template_class in builtin_templates: # Create an instance to get name instance = template_class() self.register(instance.name, template_class) - + def register(self, name: str, template_class: type[ContextTemplate]): """Register a template class. - + Args: name: Template name template_class: Template class (must inherit from ContextTemplate) - + Raises: ValueError: If name already registered or invalid class """ if name in self._templates: raise ValueError(f"Template '{name}' is already registered") - + if not issubclass(template_class, ContextTemplate): raise ValueError( f"Template class must inherit from ContextTemplate, got {template_class}" ) - + self._templates[name] = template_class # Clear cached instance if exists self._instances.pop(name, None) - + def unregister(self, name: str): """Unregister a template. - + Args: name: Template name to unregister - + Raises: KeyError: If template not found """ if name not in self._templates: raise KeyError(f"Template '{name}' not found") - + del self._templates[name] self._instances.pop(name, None) - + def get(self, name: str) -> ContextTemplate: """Get a template instance by name. - + Args: name: Template name - + Returns: Template instance - + Raises: KeyError: If template not found """ @@ -89,89 +89,108 @@ def get(self, name: str) -> ContextTemplate: f"Template '{name}' not found. " f"Available templates: {', '.join(self.list_names())}" ) - + # Cache instances for reuse if name not in self._instances: self._instances[name] = self._templates[name]() - + return self._instances[name] - + def list_names(self) -> list[str]: """List all registered template names. - + Returns: List of template names """ return sorted(self._templates.keys()) - + def list_templates(self) -> list[dict[str, str]]: """List all registered templates with metadata. - + Returns: List of template info dictionaries """ templates = [] for name in self.list_names(): instance = self.get(name) - templates.append({ - "name": name, - "description": instance.description, - "class": instance.__class__.__name__ - }) + templates.append( + { + "name": name, + "description": instance.description, + "class": instance.__class__.__name__, + } + ) return templates - + def find_by_path(self, path: str) -> str | None: """Find the best template for a given path. - + This method attempts to identify the most appropriate template based on directory structure and file patterns. - + Args: path: Directory path to analyze - + Returns: Template name if found, None otherwise """ from pathlib import Path - + path_obj = Path(path) if not path_obj.exists() or not path_obj.is_dir(): return None - + # Check for software project indicators software_indicators = [ - "src", "lib", "tests", "test", "package.json", - "requirements.txt", "setup.py", "Cargo.toml", "go.mod" + "src", + "lib", + "tests", + "test", + "package.json", + "requirements.txt", + "setup.py", + "Cargo.toml", + "go.mod", ] if any((path_obj / indicator).exists() for indicator in software_indicators): return "software_project" - + # Check for research indicators research_indicators = [ - "papers", "data", "notebooks", "analysis", - "results", "experiments" + "papers", + "data", + "notebooks", + "analysis", + "results", + "experiments", ] if any((path_obj / indicator).exists() for indicator in research_indicators): # Additional check for .bib files if list(path_obj.glob("**/*.bib")) or list(path_obj.glob("**/*.ipynb")): return "research" - + # Check for business indicators business_indicators = [ - "meetings", "decisions", "reports", "projects", - "proposals", "minutes" + "meetings", + "decisions", + "reports", + "projects", + "proposals", + "minutes", ] if any((path_obj / indicator).exists() for indicator in business_indicators): return "business" - + # Check file patterns if no directory indicators files = list(path_obj.iterdir()) - + # Count file types - code_files = sum(1 for f in files if f.suffix in {".py", ".js", ".java", ".cpp"}) + code_files = sum( + 1 for f in files if f.suffix in {".py", ".js", ".java", ".cpp"} + ) doc_files = sum(1 for f in files if f.suffix in {".md", ".pdf", ".docx"}) data_files = sum(1 for f in files if f.suffix in {".csv", ".xlsx", ".json"}) - + # Heuristic based on file types if code_files > doc_files and code_files > data_files: return "software_project" @@ -179,7 +198,7 @@ def find_by_path(self, path: str) -> str | None: return "research" elif doc_files > code_files: return "business" - + return None @@ -189,13 +208,13 @@ def find_by_path(self, path: str) -> str | None: def get_template(name: str) -> ContextTemplate: """Get a template by name from the global registry. - + Args: name: Template name - + Returns: Template instance - + Raises: KeyError: If template not found """ @@ -204,7 +223,7 @@ def get_template(name: str) -> ContextTemplate: def list_templates() -> list[dict[str, str]]: """List all available templates from the global registry. - + Returns: List of template info dictionaries """ @@ -213,7 +232,7 @@ def list_templates() -> list[dict[str, str]]: def register_template(name: str, template_class: type[ContextTemplate]): """Register a custom template in the global registry. - + Args: name: Template name template_class: Template class @@ -223,11 +242,11 @@ def register_template(name: str, template_class: type[ContextTemplate]): def find_template_for_path(path: str) -> str | None: """Find the best template for a given directory path. - + Args: path: Directory path - + Returns: Template name if found, None otherwise """ - return _registry.find_by_path(path) \ No newline at end of file + return _registry.find_by_path(path) diff --git a/contextframe/templates/research.py b/contextframe/templates/research.py index 0387fc1..d83eae8 100644 --- a/contextframe/templates/research.py +++ b/contextframe/templates/research.py @@ -14,52 +14,52 @@ class ResearchTemplate(ContextTemplate): """Template for research papers and academic documents. - + Handles: - Research papers (PDF, LaTeX, Markdown) - Literature reviews and citations - Data files and notebooks - Presentation slides - Author information - + Automatically: - Extracts paper metadata (title, authors, abstract) - Creates citation relationships - Groups papers by topic/category - Suggests academic enrichments """ - + # Document patterns PAPER_EXTENSIONS = {".pdf", ".tex", ".md", ".docx", ".doc"} DATA_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".parquet", ".h5", ".hdf5"} NOTEBOOK_EXTENSIONS = {".ipynb", ".rmd"} PRESENTATION_EXTENSIONS = {".pptx", ".ppt", ".key", ".odp"} - + # Common research directories PAPERS_DIRS = {"papers", "articles", "publications", "manuscripts"} DATA_DIRS = {"data", "datasets", "raw_data", "processed_data"} FIGURES_DIRS = {"figures", "plots", "images", "graphics"} RESULTS_DIRS = {"results", "output", "analysis"} - + def __init__(self): """Initialize the research template.""" super().__init__( name="research", - description="Template for research papers, academic documents, and scientific data" + description="Template for research papers, academic documents, and scientific data", ) - + def scan(self, source_path: str | Path) -> list[FileMapping]: """Scan research directory and map documents.""" source_path = self.validate_source(source_path) mappings = [] seen_paths = set() - + # Look for common research structures self._scan_papers_directory(source_path, mappings, seen_paths) self._scan_data_directory(source_path, mappings, seen_paths) self._scan_notebooks(source_path, mappings, seen_paths) self._scan_bibliography(source_path, mappings, seen_paths) - + # Scan root for additional papers for file_path in source_path.iterdir(): if file_path.is_file() and file_path not in seen_paths: @@ -69,10 +69,12 @@ def scan(self, source_path: str | Path) -> list[FileMapping]: elif file_path.suffix in self.PRESENTATION_EXTENSIONS: mappings.append(self._create_presentation_mapping(file_path)) seen_paths.add(file_path) - + return mappings - - def _scan_papers_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_papers_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for research papers.""" for dir_name in self.PAPERS_DIRS: papers_dir = base_path / dir_name @@ -81,7 +83,7 @@ def _scan_papers_directory(self, base_path: Path, mappings: list[FileMapping], s if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.PAPER_EXTENSIONS: mapping = self._create_paper_mapping(file_path) - + # Try to categorize by subdirectory rel_path = file_path.relative_to(papers_dir) if len(rel_path.parts) > 1: @@ -90,11 +92,13 @@ def _scan_papers_directory(self, base_path: Path, mappings: list[FileMapping], s mapping.tags.append(category.lower()) else: mapping.collection = "papers" - + mappings.append(mapping) seen_paths.add(file_path) - - def _scan_data_directory(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_data_directory( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for data files.""" for dir_name in self.DATA_DIRS: data_dir = base_path / dir_name @@ -102,19 +106,25 @@ def _scan_data_directory(self, base_path: Path, mappings: list[FileMapping], see for file_path in data_dir.rglob("*"): if file_path.is_file() and file_path not in seen_paths: if file_path.suffix in self.DATA_EXTENSIONS: - mappings.append(FileMapping( - path=file_path, - title=f"Dataset - {file_path.stem}", - collection="data", - tags=["data", "dataset", file_path.suffix[1:]], - custom_metadata={ - "data_type": file_path.suffix[1:], - "data_category": self._categorize_data(file_path.name) - } - )) + mappings.append( + FileMapping( + path=file_path, + title=f"Dataset - {file_path.stem}", + collection="data", + tags=["data", "dataset", file_path.suffix[1:]], + custom_metadata={ + "data_type": file_path.suffix[1:], + "data_category": self._categorize_data( + file_path.name + ), + }, + ) + ) seen_paths.add(file_path) - - def _scan_notebooks(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_notebooks( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for computational notebooks.""" for file_path in base_path.rglob("*"): if file_path.is_file() and file_path not in seen_paths: @@ -128,57 +138,71 @@ def _scan_notebooks(self, base_path: Path, mappings: list[FileMapping], seen_pat notebook_type = "exploration" elif "model" in name_lower or "train" in name_lower: notebook_type = "modeling" - - mappings.append(FileMapping( - path=file_path, - title=f"Notebook - {file_path.stem}", - collection="notebooks", - tags=["notebook", notebook_type, "computational"], - custom_metadata={ - "notebook_type": notebook_type, - "format": file_path.suffix[1:] - } - )) + + mappings.append( + FileMapping( + path=file_path, + title=f"Notebook - {file_path.stem}", + collection="notebooks", + tags=["notebook", notebook_type, "computational"], + custom_metadata={ + "notebook_type": notebook_type, + "format": file_path.suffix[1:], + }, + ) + ) seen_paths.add(file_path) - - def _scan_bibliography(self, base_path: Path, mappings: list[FileMapping], seen_paths: set): + + def _scan_bibliography( + self, base_path: Path, mappings: list[FileMapping], seen_paths: set + ): """Scan for bibliography files.""" - bib_patterns = ["*.bib", "*.bibtex", "references.*", "bibliography.*", "citations.*"] - + bib_patterns = [ + "*.bib", + "*.bibtex", + "references.*", + "bibliography.*", + "citations.*", + ] + for pattern in bib_patterns: for file_path in base_path.rglob(pattern): if file_path.is_file() and file_path not in seen_paths: - mappings.append(FileMapping( - path=file_path, - title=f"Bibliography - {file_path.name}", - tags=["bibliography", "references", "citations"], - custom_metadata={ - "bib_type": "bibtex" if file_path.suffix in [".bib", ".bibtex"] else "other", - "priority": "high" - } - )) + mappings.append( + FileMapping( + path=file_path, + title=f"Bibliography - {file_path.name}", + tags=["bibliography", "references", "citations"], + custom_metadata={ + "bib_type": "bibtex" + if file_path.suffix in [".bib", ".bibtex"] + else "other", + "priority": "high", + }, + ) + ) seen_paths.add(file_path) - + def _create_paper_mapping(self, file_path: Path) -> FileMapping: """Create mapping for a research paper.""" # Try to extract metadata from filename tags = ["paper", "research", file_path.suffix[1:]] custom_meta = {"document_type": "paper"} - + # Common filename patterns name = file_path.stem year_match = re.search(r'(\d{4})', name) if year_match: custom_meta["year"] = year_match.group(1) tags.append(f"year:{year_match.group(1)}") - + # Check for common paper types in filename name_lower = name.lower() if "review" in name_lower: tags.append("review") custom_meta["paper_type"] = "review" elif "survey" in name_lower: - tags.append("survey") + tags.append("survey") custom_meta["paper_type"] = "survey" elif "thesis" in name_lower: tags.append("thesis") @@ -186,18 +210,18 @@ def _create_paper_mapping(self, file_path: Path) -> FileMapping: elif "dissertation" in name_lower: tags.append("dissertation") custom_meta["paper_type"] = "dissertation" - + return FileMapping( path=file_path, title=f"Paper - {name}", tags=tags, - custom_metadata=custom_meta + custom_metadata=custom_meta, ) - + def _create_presentation_mapping(self, file_path: Path) -> FileMapping: """Create mapping for a presentation.""" tags = ["presentation", "slides", file_path.suffix[1:]] - + # Check for presentation type name_lower = file_path.stem.lower() if "poster" in name_lower: @@ -211,7 +235,7 @@ def _create_presentation_mapping(self, file_path: Path) -> FileMapping: pres_type = "defense" else: pres_type = "general" - + return FileMapping( path=file_path, title=f"Presentation - {file_path.stem}", @@ -219,37 +243,63 @@ def _create_presentation_mapping(self, file_path: Path) -> FileMapping: tags=tags, custom_metadata={ "presentation_type": pres_type, - "format": file_path.suffix[1:] - } + "format": file_path.suffix[1:], + }, ) - - def define_collections(self, file_mappings: list[FileMapping]) -> list[CollectionDefinition]: + + def define_collections( + self, file_mappings: list[FileMapping] + ) -> list[CollectionDefinition]: """Define collections for research documents.""" collections = [] seen_collections = set() - + # Core research collections collection_defs = [ - ("papers", "Research Papers", "Published and draft research papers", ["papers", "research"]), + ( + "papers", + "Research Papers", + "Published and draft research papers", + ["papers", "research"], + ), ("data", "Datasets", "Research data and datasets", ["data", "datasets"]), - ("notebooks", "Computational Notebooks", "Analysis notebooks and experiments", ["notebooks", "analysis"]), - ("presentations", "Presentations", "Slides and poster presentations", ["presentations", "slides"]), - ("bibliography", "References", "Bibliography and citations", ["references", "citations"]) + ( + "notebooks", + "Computational Notebooks", + "Analysis notebooks and experiments", + ["notebooks", "analysis"], + ), + ( + "presentations", + "Presentations", + "Slides and poster presentations", + ["presentations", "slides"], + ), + ( + "bibliography", + "References", + "Bibliography and citations", + ["references", "citations"], + ), ] - + position = 0 for name, title, desc, tags in collection_defs: - if any(m.collection and m.collection.startswith(name) for m in file_mappings): - collections.append(CollectionDefinition( - name=name, - title=title, - description=desc, - tags=tags, - position=position - )) + if any( + m.collection and m.collection.startswith(name) for m in file_mappings + ): + collections.append( + CollectionDefinition( + name=name, + title=title, + description=desc, + tags=tags, + position=position, + ) + ) seen_collections.add(name) position += 10 - + # Add sub-collections for paper categories paper_subcollections = set() for mapping in file_mappings: @@ -257,52 +307,54 @@ def define_collections(self, file_mappings: list[FileMapping]) -> list[Collectio parts = mapping.collection.split("/") if len(parts) > 1: paper_subcollections.add(parts[1]) - + for subcoll in sorted(paper_subcollections): coll_name = f"papers/{subcoll}" if coll_name not in seen_collections: - collections.append(CollectionDefinition( - name=coll_name, - title=f"Papers - {subcoll.title()}", - description=f"Research papers in {subcoll} category", - tags=["papers", subcoll.lower()], - parent="papers", - position=position - )) + collections.append( + CollectionDefinition( + name=coll_name, + title=f"Papers - {subcoll.title()}", + description=f"Research papers in {subcoll} category", + tags=["papers", subcoll.lower()], + parent="papers", + position=position, + ) + ) position += 1 - + return collections - + def discover_relationships( - self, - file_mappings: list[FileMapping], - dataset: FrameDataset + self, file_mappings: list[FileMapping], dataset: FrameDataset ) -> list[dict[str, Any]]: """Discover citation and authorship relationships.""" relationships = [] - + # Find bibliography files bib_files = [m for m in file_mappings if "bibliography" in m.tags] - + # Find papers papers = [m for m in file_mappings if "paper" in m.tags] - + # Create relationships between papers and bibliography for paper in papers: for bib in bib_files: # Check if they're in the same directory area if paper.path.parent == bib.path.parent: - relationships.append({ - "source": str(paper.path), - "target": str(bib.path), - "type": "references", - "description": "Paper references bibliography" - }) - + relationships.append( + { + "source": str(paper.path), + "target": str(bib.path), + "type": "references", + "description": "Paper references bibliography", + } + ) + # Find related notebooks and data notebooks = [m for m in file_mappings if "notebook" in m.tags] data_files = [m for m in file_mappings if "data" in m.tags] - + # Match notebooks to data files by name similarity for notebook in notebooks: nb_stem = notebook.path.stem.lower() @@ -310,82 +362,94 @@ def discover_relationships( data_stem = data.path.stem.lower() # Simple matching - could be improved if data_stem in nb_stem or nb_stem in data_stem: - relationships.append({ - "source": str(notebook.path), - "target": str(data.path), - "type": "uses", - "description": "Notebook analyzes dataset" - }) - + relationships.append( + { + "source": str(notebook.path), + "target": str(data.path), + "type": "uses", + "description": "Notebook analyzes dataset", + } + ) + return relationships - - def suggest_enrichments(self, file_mappings: list[FileMapping]) -> list[EnrichmentSuggestion]: + + def suggest_enrichments( + self, file_mappings: list[FileMapping] + ) -> list[EnrichmentSuggestion]: """Suggest research-specific enrichments.""" suggestions = [] - + # Paper enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.pdf", - enhancement_config={ - "enhancements": { - "context": "research_context", - "tags": "Extract: authors, keywords, research area, methodology", - "custom_metadata": "research_metadata" - } - }, - priority=10 - )) - + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.pdf", + enhancement_config={ + "enhancements": { + "context": "research_context", + "tags": "Extract: authors, keywords, research area, methodology", + "custom_metadata": "research_metadata", + } + }, + priority=10, + ) + ) + # Notebook enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.ipynb", - enhancement_config={ - "enhancements": { - "context": "Summarize the analysis performed and key findings", - "custom_metadata": { - "libraries_used": "List main Python/R libraries used", - "analysis_type": "Type of analysis (statistical, ML, visualization)", - "key_findings": "Main results or insights" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.ipynb", + enhancement_config={ + "enhancements": { + "context": "Summarize the analysis performed and key findings", + "custom_metadata": { + "libraries_used": "List main Python/R libraries used", + "analysis_type": "Type of analysis (statistical, ML, visualization)", + "key_findings": "Main results or insights", + }, } - } - }, - priority=8 - )) - + }, + priority=8, + ) + ) + # Data file enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.csv", - enhancement_config={ - "enhancements": { - "context": "Describe the dataset structure and content", - "custom_metadata": { - "columns": "List main columns/features", - "row_count": "Number of records", - "data_source": "Origin of the data" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.csv", + enhancement_config={ + "enhancements": { + "context": "Describe the dataset structure and content", + "custom_metadata": { + "columns": "List main columns/features", + "row_count": "Number of records", + "data_source": "Origin of the data", + }, } - } - }, - priority=5 - )) - + }, + priority=5, + ) + ) + # Bibliography enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.bib", - enhancement_config={ - "enhancements": { - "context": "Summarize the types of references and research areas covered", - "custom_metadata": { - "entry_count": "Number of bibliography entries", - "publication_years": "Range of publication years", - "key_authors": "Most frequently cited authors" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.bib", + enhancement_config={ + "enhancements": { + "context": "Summarize the types of references and research areas covered", + "custom_metadata": { + "entry_count": "Number of bibliography entries", + "publication_years": "Range of publication years", + "key_authors": "Most frequently cited authors", + }, } - } - }, - priority=7 - )) - + }, + priority=7, + ) + ) + return suggestions - + def _categorize_data(self, filename: str) -> str: """Categorize data files.""" name_lower = filename.lower() @@ -398,4 +462,4 @@ def _categorize_data(self, filename: str) -> str: elif "meta" in name_lower: return "metadata" else: - return "general" \ No newline at end of file + return "general" diff --git a/contextframe/templates/software.py b/contextframe/templates/software.py index abd408c..116ee57 100644 --- a/contextframe/templates/software.py +++ b/contextframe/templates/software.py @@ -14,261 +14,325 @@ class SoftwareProjectTemplate(ContextTemplate): """Template for software development projects. - + Recognizes common project structures: - Source code (src/, lib/, app/) - Tests (tests/, test/, spec/) - Documentation (docs/, doc/) - Configuration files - Build/deployment files - + Automatically: - Groups files by module/package - Links tests to source files - Identifies dependencies - Suggests code-specific enrichments """ - + # Common source directories SOURCE_DIRS = {"src", "lib", "app", "source", "sources"} TEST_DIRS = {"tests", "test", "spec", "specs", "__tests__"} DOC_DIRS = {"docs", "doc", "documentation"} - + # File patterns CODE_EXTENSIONS = { - ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".cpp", ".c", ".h", - ".hpp", ".cs", ".rb", ".go", ".rs", ".swift", ".kt", ".scala", - ".php", ".r", ".m", ".mm", ".f90", ".jl", ".lua", ".dart" + ".py", + ".js", + ".ts", + ".jsx", + ".tsx", + ".java", + ".cpp", + ".c", + ".h", + ".hpp", + ".cs", + ".rb", + ".go", + ".rs", + ".swift", + ".kt", + ".scala", + ".php", + ".r", + ".m", + ".mm", + ".f90", + ".jl", + ".lua", + ".dart", } - + CONFIG_FILES = { - "package.json", "requirements.txt", "setup.py", "pyproject.toml", - "Cargo.toml", "go.mod", "pom.xml", "build.gradle", "CMakeLists.txt", - "Makefile", "Dockerfile", ".gitignore", ".dockerignore" + "package.json", + "requirements.txt", + "setup.py", + "pyproject.toml", + "Cargo.toml", + "go.mod", + "pom.xml", + "build.gradle", + "CMakeLists.txt", + "Makefile", + "Dockerfile", + ".gitignore", + ".dockerignore", } - + DOC_EXTENSIONS = {".md", ".rst", ".txt", ".adoc"} - + def __init__(self): """Initialize the software project template.""" super().__init__( name="software_project", - description="Template for software development projects with source code, tests, and documentation" + description="Template for software development projects with source code, tests, and documentation", ) - + def scan(self, source_path: str | Path) -> list[FileMapping]: """Scan project directory and map files.""" source_path = self.validate_source(source_path) mappings = [] - + # Track seen files to avoid duplicates seen_paths = set() - + # Scan for README first (project overview) for readme in ["README.md", "README.rst", "README.txt", "readme.md"]: readme_path = source_path / readme if readme_path.exists() and readme_path not in seen_paths: - mappings.append(FileMapping( - path=readme_path, - title=f"Project Overview - {source_path.name}", - tags=["readme", "overview", "documentation"], - custom_metadata={"priority": "high"} - )) + mappings.append( + FileMapping( + path=readme_path, + title=f"Project Overview - {source_path.name}", + tags=["readme", "overview", "documentation"], + custom_metadata={"priority": "high"}, + ) + ) seen_paths.add(readme_path) - + # Scan source directories for dir_name in self.SOURCE_DIRS: src_dir = source_path / dir_name if src_dir.exists() and src_dir.is_dir(): - mappings.extend(self._scan_code_directory(src_dir, "source", seen_paths)) - + mappings.extend( + self._scan_code_directory(src_dir, "source", seen_paths) + ) + # Scan test directories for dir_name in self.TEST_DIRS: test_dir = source_path / dir_name if test_dir.exists() and test_dir.is_dir(): mappings.extend(self._scan_code_directory(test_dir, "test", seen_paths)) - + # Scan documentation for dir_name in self.DOC_DIRS: doc_dir = source_path / dir_name if doc_dir.exists() and doc_dir.is_dir(): mappings.extend(self._scan_doc_directory(doc_dir, seen_paths)) - + # Scan root-level config files for config_file in self.CONFIG_FILES: config_path = source_path / config_file if config_path.exists() and config_path not in seen_paths: - mappings.append(FileMapping( - path=config_path, - title=f"Configuration - {config_file}", - tags=["configuration", self._get_config_type(config_file)], - custom_metadata={"config_type": self._get_config_type(config_file)} - )) + mappings.append( + FileMapping( + path=config_path, + title=f"Configuration - {config_file}", + tags=["configuration", self._get_config_type(config_file)], + custom_metadata={ + "config_type": self._get_config_type(config_file) + }, + ) + ) seen_paths.add(config_path) - + # Scan for other code files in root for file_path in source_path.iterdir(): if file_path.is_file() and file_path.suffix in self.CODE_EXTENSIONS: if file_path not in seen_paths: - mappings.append(FileMapping( - path=file_path, - title=f"Code - {file_path.name}", - collection="root", - tags=["code", self._get_language_tag(file_path.suffix)], - custom_metadata={"code_type": "script"} - )) + mappings.append( + FileMapping( + path=file_path, + title=f"Code - {file_path.name}", + collection="root", + tags=["code", self._get_language_tag(file_path.suffix)], + custom_metadata={"code_type": "script"}, + ) + ) seen_paths.add(file_path) - + return mappings - - def _scan_code_directory(self, directory: Path, category: str, seen_paths: set) -> list[FileMapping]: + + def _scan_code_directory( + self, directory: Path, category: str, seen_paths: set + ) -> list[FileMapping]: """Scan a code directory recursively.""" mappings = [] - + for file_path in directory.rglob("*"): if file_path in seen_paths or not file_path.is_file(): continue - + # Skip hidden files and __pycache__ - if any(part.startswith(".") or part == "__pycache__" for part in file_path.parts): + if any( + part.startswith(".") or part == "__pycache__" + for part in file_path.parts + ): continue - + if file_path.suffix in self.CODE_EXTENSIONS: # Determine module/package structure rel_path = file_path.relative_to(directory) module_parts = list(rel_path.parts[:-1]) - + # Create collection name from module path collection = None if module_parts: collection = "/".join(module_parts) - + # Determine if it's a test file - is_test = (category == "test" or - any(part in file_path.name.lower() for part in ["test", "spec"])) - - mappings.append(FileMapping( - path=file_path, - title=f"{'Test' if is_test else 'Code'} - {file_path.stem}", - collection=collection, - tags=[ - category, - self._get_language_tag(file_path.suffix), - "test" if is_test else "implementation" - ], - custom_metadata={ - "module": "/".join(module_parts) if module_parts else "root", - "language": self._get_language_tag(file_path.suffix), - "file_type": "test" if is_test else "source" - } - )) + is_test = category == "test" or any( + part in file_path.name.lower() for part in ["test", "spec"] + ) + + mappings.append( + FileMapping( + path=file_path, + title=f"{'Test' if is_test else 'Code'} - {file_path.stem}", + collection=collection, + tags=[ + category, + self._get_language_tag(file_path.suffix), + "test" if is_test else "implementation", + ], + custom_metadata={ + "module": "/".join(module_parts) + if module_parts + else "root", + "language": self._get_language_tag(file_path.suffix), + "file_type": "test" if is_test else "source", + }, + ) + ) seen_paths.add(file_path) - + return mappings - - def _scan_doc_directory(self, directory: Path, seen_paths: set) -> list[FileMapping]: + + def _scan_doc_directory( + self, directory: Path, seen_paths: set + ) -> list[FileMapping]: """Scan documentation directory.""" mappings = [] - + for file_path in directory.rglob("*"): if file_path in seen_paths or not file_path.is_file(): continue - + if file_path.suffix in self.DOC_EXTENSIONS: rel_path = file_path.relative_to(directory) - - mappings.append(FileMapping( - path=file_path, - title=f"Documentation - {file_path.stem}", - collection="documentation", - tags=["documentation", file_path.suffix[1:]], - custom_metadata={ - "doc_path": str(rel_path), - "doc_type": self._categorize_doc(file_path.name) - } - )) + + mappings.append( + FileMapping( + path=file_path, + title=f"Documentation - {file_path.stem}", + collection="documentation", + tags=["documentation", file_path.suffix[1:]], + custom_metadata={ + "doc_path": str(rel_path), + "doc_type": self._categorize_doc(file_path.name), + }, + ) + ) seen_paths.add(file_path) - + return mappings - - def define_collections(self, file_mappings: list[FileMapping]) -> list[CollectionDefinition]: + + def define_collections( + self, file_mappings: list[FileMapping] + ) -> list[CollectionDefinition]: """Define collections based on project structure.""" collections = [] seen_collections = set() - + # Main project collection - collections.append(CollectionDefinition( - name="project", - title="Project Overview", - description="Top-level project information and configuration", - tags=["project", "overview"], - position=0 - )) - + collections.append( + CollectionDefinition( + name="project", + title="Project Overview", + description="Top-level project information and configuration", + tags=["project", "overview"], + position=0, + ) + ) + # Extract unique collections from mappings for mapping in file_mappings: if mapping.collection and mapping.collection not in seen_collections: # Determine collection type if mapping.collection == "documentation": - collections.append(CollectionDefinition( - name="documentation", - title="Documentation", - description="Project documentation and guides", - tags=["documentation"], - position=1 - )) + collections.append( + CollectionDefinition( + name="documentation", + title="Documentation", + description="Project documentation and guides", + tags=["documentation"], + position=1, + ) + ) elif "/" in mapping.collection: # Module/package collection parts = mapping.collection.split("/") parent = None - + # Create nested collections for module hierarchy for i, part in enumerate(parts): - coll_name = "/".join(parts[:i+1]) + coll_name = "/".join(parts[: i + 1]) if coll_name not in seen_collections: - collections.append(CollectionDefinition( - name=coll_name, - title=f"Module: {part}", - description=f"Code module {coll_name}", - tags=["module", "code"], - parent=parent, - position=10 + i - )) + collections.append( + CollectionDefinition( + name=coll_name, + title=f"Module: {part}", + description=f"Code module {coll_name}", + tags=["module", "code"], + parent=parent, + position=10 + i, + ) + ) seen_collections.add(coll_name) parent = coll_name - + seen_collections.add(mapping.collection) - + # Add test collection if we have tests if any("test" in m.tags for m in file_mappings): - collections.append(CollectionDefinition( - name="tests", - title="Test Suite", - description="Project test files and specifications", - tags=["tests", "quality"], - position=20 - )) - + collections.append( + CollectionDefinition( + name="tests", + title="Test Suite", + description="Project test files and specifications", + tags=["tests", "quality"], + position=20, + ) + ) + return collections - + def discover_relationships( - self, - file_mappings: list[FileMapping], - dataset: FrameDataset + self, file_mappings: list[FileMapping], dataset: FrameDataset ) -> list[dict[str, Any]]: """Discover relationships between source files and tests.""" relationships = [] - + # Create lookup maps source_files = {} test_files = {} - + for mapping in file_mappings: if "test" in mapping.tags: test_files[mapping.path.stem] = mapping elif "implementation" in mapping.custom_metadata.get("file_type", ""): source_files[mapping.path.stem] = mapping - + # Match tests to source files for test_name, test_mapping in test_files.items(): # Common patterns: test_foo.py -> foo.py, foo_test.py -> foo.py @@ -279,85 +343,97 @@ def discover_relationships( source_name = test_name[:-5] elif test_name.endswith("Test") or test_name.endswith("Spec"): source_name = test_name[:-4] - + if source_name and source_name in source_files: - relationships.append({ - "source": str(test_mapping.path), - "target": str(source_files[source_name].path), - "type": "tests", - "description": f"Tests for {source_name}" - }) - + relationships.append( + { + "source": str(test_mapping.path), + "target": str(source_files[source_name].path), + "type": "tests", + "description": f"Tests for {source_name}", + } + ) + # Discover import relationships (simplified - would need AST parsing) # This is a placeholder for more sophisticated analysis - + return relationships - - def suggest_enrichments(self, file_mappings: list[FileMapping]) -> list[EnrichmentSuggestion]: + + def suggest_enrichments( + self, file_mappings: list[FileMapping] + ) -> list[EnrichmentSuggestion]: """Suggest code-specific enrichments.""" suggestions = [] - + # Source code enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.py", - enhancement_config={ - "enhancements": { - "context": "technical_summary", - "tags": "technical_tags", - "custom_metadata": "code_metadata" - } - }, - priority=10 - )) - + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.py", + enhancement_config={ + "enhancements": { + "context": "technical_summary", + "tags": "technical_tags", + "custom_metadata": "code_metadata", + } + }, + priority=10, + ) + ) + # Test file enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/test_*.py", - enhancement_config={ - "enhancements": { - "context": "Explain what this test validates and why it's important", - "custom_metadata": { - "test_type": "Identify test type: unit, integration, or e2e", - "coverage": "What code areas does this test cover?" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/test_*.py", + enhancement_config={ + "enhancements": { + "context": "Explain what this test validates and why it's important", + "custom_metadata": { + "test_type": "Identify test type: unit, integration, or e2e", + "coverage": "What code areas does this test cover?", + }, } - } - }, - priority=8 - )) - + }, + priority=8, + ) + ) + # Documentation enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/*.md", - enhancement_config={ - "enhancements": { - "context": "tutorial_context", - "tags": "topic_tags" - } - }, - priority=5 - )) - + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/*.md", + enhancement_config={ + "enhancements": { + "context": "tutorial_context", + "tags": "topic_tags", + } + }, + priority=5, + ) + ) + # Config file enrichments - suggestions.append(EnrichmentSuggestion( - file_pattern="**/package.json", - enhancement_config={ - "enhancements": { - "custom_metadata": { - "dependencies": "List key dependencies", - "scripts": "List available npm scripts" + suggestions.append( + EnrichmentSuggestion( + file_pattern="**/package.json", + enhancement_config={ + "enhancements": { + "custom_metadata": { + "dependencies": "List key dependencies", + "scripts": "List available npm scripts", + } } - } - }, - priority=3 - )) - + }, + priority=3, + ) + ) + return suggestions - + def _get_language_tag(self, extension: str) -> str: """Map file extension to language tag.""" lang_map = { ".py": "python", - ".js": "javascript", + ".js": "javascript", ".ts": "typescript", ".java": "java", ".cpp": "cpp", @@ -371,10 +447,10 @@ def _get_language_tag(self, extension: str) -> str: ".kt": "kotlin", ".scala": "scala", ".r": "r", - ".jl": "julia" + ".jl": "julia", } return lang_map.get(extension, extension[1:]) - + def _get_config_type(self, filename: str) -> str: """Categorize configuration file.""" config_types = { @@ -388,10 +464,10 @@ def _get_config_type(self, filename: str) -> str: "build.gradle": "gradle", "CMakeLists.txt": "cmake", "Makefile": "make", - "Dockerfile": "docker" + "Dockerfile": "docker", } return config_types.get(filename, "config") - + def _categorize_doc(self, filename: str) -> str: """Categorize documentation file.""" name_lower = filename.lower() @@ -406,4 +482,4 @@ def _categorize_doc(self, filename: str) -> str: elif "contributing" in name_lower: return "contributing" else: - return "general" \ No newline at end of file + return "general" diff --git a/contextframe/tests/test_mcp/__init__.py b/contextframe/tests/test_mcp/__init__.py index 4b786a7..97de4de 100644 --- a/contextframe/tests/test_mcp/__init__.py +++ b/contextframe/tests/test_mcp/__init__.py @@ -1 +1 @@ -"""Tests for MCP server implementation.""" \ No newline at end of file +"""Tests for MCP server implementation.""" diff --git a/contextframe/tests/test_mcp/test_batch_handler.py b/contextframe/tests/test_mcp/test_batch_handler.py index 79eef59..393a45f 100644 --- a/contextframe/tests/test_mcp/test_batch_handler.py +++ b/contextframe/tests/test_mcp/test_batch_handler.py @@ -1,37 +1,36 @@ """Tests for batch operation handler.""" -import pytest import asyncio -from typing import Any, Dict, List - +import pytest from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.batch.handler import BatchOperationHandler, execute_parallel -from contextframe.mcp.core.transport import TransportAdapter, Progress +from contextframe.mcp.core.transport import Progress, TransportAdapter +from typing import Any, Dict, List class MockTransportAdapter(TransportAdapter): """Mock transport adapter for testing.""" - + def __init__(self): super().__init__() - self.progress_updates: List[Progress] = [] - self.messages_sent: List[Dict[str, Any]] = [] - + self.progress_updates: list[Progress] = [] + self.messages_sent: list[dict[str, Any]] = [] + # Add progress handler to capture updates self.add_progress_handler(self._capture_progress) - + async def _capture_progress(self, progress: Progress): self.progress_updates.append(progress) - + async def initialize(self) -> None: pass - + async def shutdown(self) -> None: pass - - async def send_message(self, message: Dict[str, Any]) -> None: + + async def send_message(self, message: dict[str, Any]) -> None: self.messages_sent.append(message) - + async def receive_message(self) -> None: return None @@ -58,49 +57,44 @@ def batch_handler(test_dataset, mock_transport): class TestBatchOperationHandler: """Test batch operation handler functionality.""" - + @pytest.mark.asyncio async def test_execute_batch_success(self, batch_handler, mock_transport): """Test successful batch execution.""" items = [1, 2, 3, 4, 5] - + async def processor(item: int) -> int: return item * 2 - + result = await batch_handler.execute_batch( - operation="test_multiply", - items=items, - processor=processor + operation="test_multiply", items=items, processor=processor ) - + # Check results assert result.total_processed == 5 assert result.total_errors == 0 assert result.results == [2, 4, 6, 8, 10] assert result.operation == "test_multiply" - + # Check progress updates assert len(mock_transport.progress_updates) == 5 for i, progress in enumerate(mock_transport.progress_updates): assert progress.operation == "test_multiply" assert progress.current == i + 1 assert progress.total == 5 - + @pytest.mark.asyncio async def test_execute_batch_with_errors(self, batch_handler): """Test batch execution with some errors.""" items = [1, 2, 0, 4, 5] # 0 will cause division error - + def processor(item: int) -> float: return 10 / item # Division by zero for item=0 - + result = await batch_handler.execute_batch( - operation="test_divide", - items=items, - processor=processor, - max_errors=2 + operation="test_divide", items=items, processor=processor, max_errors=2 ) - + # Check results assert result.total_processed == 4 assert result.total_errors == 1 @@ -108,42 +102,36 @@ def processor(item: int) -> float: assert len(result.errors) == 1 assert result.errors[0]["item_index"] == 2 assert "division by zero" in result.errors[0]["error"] - + @pytest.mark.asyncio async def test_execute_batch_atomic_failure(self, batch_handler): """Test atomic batch execution that fails.""" items = [1, 2, 0, 4, 5] - + def processor(item: int) -> float: return 10 / item - + with pytest.raises(Exception) as excinfo: await batch_handler.execute_batch( - operation="test_atomic", - items=items, - processor=processor, - atomic=True + operation="test_atomic", items=items, processor=processor, atomic=True ) - + assert "Atomic operation failed at item 2" in str(excinfo.value) - + @pytest.mark.asyncio async def test_execute_batch_max_errors(self, batch_handler): """Test stopping after max errors.""" items = list(range(10)) # [0, 1, 2, ..., 9] - + def processor(item: int) -> float: if item < 5: raise ValueError(f"Item {item} too small") return item - + result = await batch_handler.execute_batch( - operation="test_max_errors", - items=items, - processor=processor, - max_errors=3 + operation="test_max_errors", items=items, processor=processor, max_errors=3 ) - + # Should stop after 3 errors assert result.total_errors == 3 # Only items 0, 1, 2 should have been processed (all failed) @@ -152,30 +140,32 @@ def processor(item: int) -> float: class TestExecuteParallel: """Test parallel execution utility.""" - + @pytest.mark.asyncio async def test_execute_parallel_basic(self): """Test basic parallel execution.""" + async def task(i: int) -> int: await asyncio.sleep(0.01) # Simulate work return i * 2 - + tasks = [lambda i=i: task(i) for i in range(10)] - + results = await execute_parallel(tasks, max_parallel=3) - + assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] - + @pytest.mark.asyncio async def test_execute_parallel_with_errors(self): """Test parallel execution with some errors.""" + async def task(i: int) -> int: if i == 5: raise ValueError("Test error") return i * 2 - + tasks = [lambda i=i: task(i) for i in range(10)] - + # execute_parallel doesn't handle errors - they propagate with pytest.raises(ValueError): - await execute_parallel(tasks, max_parallel=3) \ No newline at end of file + await execute_parallel(tasks, max_parallel=3) diff --git a/contextframe/tests/test_mcp/test_batch_tools.py b/contextframe/tests/test_mcp/test_batch_tools.py index d655f1e..0d7f9c2 100644 --- a/contextframe/tests/test_mcp/test_batch_tools.py +++ b/contextframe/tests/test_mcp/test_batch_tools.py @@ -1,39 +1,38 @@ """Tests for MCP batch operation tools.""" -import pytest import asyncio -from uuid import uuid4 -from typing import Any, Dict, List - +import pytest from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.batch import BatchTools -from contextframe.mcp.core.transport import TransportAdapter, Progress +from contextframe.mcp.core.transport import Progress, TransportAdapter from contextframe.mcp.tools import ToolRegistry +from typing import Any, Dict, List +from uuid import uuid4 class MockTransportAdapter(TransportAdapter): """Mock transport adapter for testing.""" - + def __init__(self): super().__init__() - self.progress_updates: List[Progress] = [] - self.messages_sent: List[Dict[str, Any]] = [] - + self.progress_updates: list[Progress] = [] + self.messages_sent: list[dict[str, Any]] = [] + # Add progress handler to capture updates self.add_progress_handler(self._capture_progress) - + async def _capture_progress(self, progress: Progress): self.progress_updates.append(progress) - + async def initialize(self) -> None: pass - + async def shutdown(self) -> None: pass - - async def send_message(self, message: Dict[str, Any]) -> None: + + async def send_message(self, message: dict[str, Any]) -> None: self.messages_sent.append(message) - + async def receive_message(self) -> None: return None @@ -43,14 +42,14 @@ async def test_dataset(tmp_path): """Create a test dataset with sample documents.""" dataset_path = tmp_path / "test_batch.lance" dataset = FrameDataset.create(str(dataset_path)) - + # Add test documents - simple approach for i in range(10): record = FrameRecord( text_content=f"Test document {i}: This is test content about topic {i % 3}" ) dataset.add(record) - + yield dataset @@ -59,21 +58,21 @@ async def batch_tools(test_dataset): """Create batch tools with test dataset and transport.""" transport = MockTransportAdapter() await transport.initialize() - + # Create tool registry with document tools tool_registry = ToolRegistry(test_dataset, transport) - + # Create batch tools batch_tools = BatchTools(test_dataset, transport, tool_registry) - + yield batch_tools - + await transport.shutdown() class TestBatchSearch: """Test batch search functionality.""" - + @pytest.mark.asyncio async def test_batch_search_basic(self, batch_tools): """Test basic batch search with multiple queries.""" @@ -81,17 +80,17 @@ async def test_batch_search_basic(self, batch_tools): "queries": [ {"query": "topic 0", "search_type": "text", "limit": 5}, {"query": "topic 1", "search_type": "text", "limit": 5}, - {"query": "topic 2", "search_type": "text", "limit": 5} + {"query": "topic 2", "search_type": "text", "limit": 5}, ], - "max_parallel": 3 + "max_parallel": 3, } - + result = await batch_tools.batch_search(params) - + assert result["searches_completed"] == 3 assert result["searches_failed"] == 0 assert len(result["results"]) == 3 - + # Each query should find documents for search_result in result["results"]: assert search_result["success"] @@ -102,7 +101,7 @@ async def test_batch_search_basic(self, batch_tools): class TestBatchAdd: """Test batch add functionality.""" - + @pytest.mark.asyncio async def test_batch_add_atomic(self, batch_tools): """Test atomic batch add.""" @@ -110,44 +109,42 @@ async def test_batch_add_atomic(self, batch_tools): "documents": [ {"content": "New document 1", "metadata": {"x_type": "test"}}, {"content": "New document 2", "metadata": {"x_type": "test"}}, - {"content": "New document 3", "metadata": {"x_type": "test"}} + {"content": "New document 3", "metadata": {"x_type": "test"}}, ], "shared_settings": { "generate_embeddings": False, - "collection": "batch_test" + "collection": "batch_test", }, - "atomic": True + "atomic": True, } - + # Get initial count initial_count = batch_tools.dataset._dataset.count_rows() - + result = await batch_tools.batch_add(params) - + assert result["success"] assert result["documents_added"] == 3 assert result["atomic"] - + # Verify documents were added final_count = batch_tools.dataset._dataset.count_rows() assert final_count == initial_count + 3 - + @pytest.mark.asyncio async def test_batch_add_non_atomic(self, batch_tools): """Test non-atomic batch add.""" params = { "documents": [ {"content": "Doc A", "metadata": {"x_idx": 1}}, - {"content": "Doc B", "metadata": {"x_idx": 2}} + {"content": "Doc B", "metadata": {"x_idx": 2}}, ], - "shared_settings": { - "generate_embeddings": False - }, - "atomic": False + "shared_settings": {"generate_embeddings": False}, + "atomic": False, } - + result = await batch_tools.batch_add(params) - + assert result["documents_added"] == 2 assert result["documents_failed"] == 0 assert not result["atomic"] @@ -155,20 +152,18 @@ async def test_batch_add_non_atomic(self, batch_tools): class TestBatchUpdate: """Test batch update functionality.""" - - @pytest.mark.asyncio + + @pytest.mark.asyncio async def test_batch_update_by_filter(self, batch_tools): """Test updating documents by filter.""" params = { "filter": "text_content LIKE '%topic 0%'", - "updates": { - "metadata_updates": {"x_updated": True, "x_version": 2} - }, - "max_documents": 10 + "updates": {"metadata_updates": {"x_updated": True, "x_version": 2}}, + "max_documents": 10, } - + result = await batch_tools.batch_update(params) - + assert "documents_updated" in result assert result["documents_updated"] > 0 assert result["documents_failed"] == 0 @@ -176,52 +171,49 @@ async def test_batch_update_by_filter(self, batch_tools): class TestBatchDelete: """Test batch delete functionality.""" - + @pytest.mark.asyncio async def test_batch_delete_dry_run(self, batch_tools): """Test batch delete with dry run.""" - params = { - "filter": "text_content LIKE '%topic 1%'", - "dry_run": True - } - + params = {"filter": "text_content LIKE '%topic 1%'", "dry_run": True} + result = await batch_tools.batch_delete(params) - + assert result["success"] assert result["dry_run"] assert result["documents_to_delete"] > 0 assert "document_ids" in result - + @pytest.mark.asyncio async def test_batch_delete_with_confirm(self, batch_tools): """Test batch delete with confirmation count.""" # First do a dry run to get count dry_run_params = { "filter": "text_content LIKE '%document 0%' OR text_content LIKE '%document 1%' OR text_content LIKE '%document 2%'", - "dry_run": True + "dry_run": True, } - + dry_run_result = await batch_tools.batch_delete(dry_run_params) count = dry_run_result["documents_to_delete"] - + # Now delete with wrong confirm count wrong_params = { "filter": "text_content LIKE '%document 0%' OR text_content LIKE '%document 1%' OR text_content LIKE '%document 2%'", "dry_run": False, - "confirm_count": count + 1 + "confirm_count": count + 1, } - + wrong_result = await batch_tools.batch_delete(wrong_params) assert not wrong_result["success"] assert "Expected" in wrong_result["error"] - + # Delete with correct confirm count correct_params = { "filter": "text_content LIKE '%document 0%' OR text_content LIKE '%document 1%' OR text_content LIKE '%document 2%'", "dry_run": False, - "confirm_count": count + "confirm_count": count, } - + correct_result = await batch_tools.batch_delete(correct_params) assert correct_result["success"] - assert correct_result["documents_deleted"] == count \ No newline at end of file + assert correct_result["documents_deleted"] == count diff --git a/contextframe/tests/test_mcp/test_collection_tools.py b/contextframe/tests/test_mcp/test_collection_tools.py index ef72e5d..d7633d8 100644 --- a/contextframe/tests/test_mcp/test_collection_tools.py +++ b/contextframe/tests/test_mcp/test_collection_tools.py @@ -1,39 +1,38 @@ """Tests for MCP collection management tools.""" -import pytest import asyncio -from uuid import uuid4, UUID -from typing import Any, Dict, List - +import pytest from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.collections import CollectionTools from contextframe.mcp.collections.templates import TemplateRegistry -from contextframe.mcp.core.transport import TransportAdapter, Progress +from contextframe.mcp.core.transport import Progress, TransportAdapter +from typing import Any, Dict, List +from uuid import UUID, uuid4 class MockTransportAdapter(TransportAdapter): """Mock transport adapter for testing.""" - + def __init__(self): super().__init__() - self.progress_updates: List[Progress] = [] - self.messages_sent: List[Dict[str, Any]] = [] - + self.progress_updates: list[Progress] = [] + self.messages_sent: list[dict[str, Any]] = [] + # Add progress handler to capture updates self.add_progress_handler(self._capture_progress) - + async def _capture_progress(self, progress: Progress): self.progress_updates.append(progress) - + async def initialize(self) -> None: pass - + async def shutdown(self) -> None: pass - - async def send_message(self, message: Dict[str, Any]) -> None: + + async def send_message(self, message: dict[str, Any]) -> None: self.messages_sent.append(message) - + async def receive_message(self) -> None: return None @@ -43,7 +42,7 @@ async def test_dataset(tmp_path): """Create a test dataset with sample documents.""" dataset_path = tmp_path / "test_collections.lance" dataset = FrameDataset.create(str(dataset_path)) - + # Add test documents docs = [] for i in range(15): @@ -52,12 +51,12 @@ async def test_dataset(tmp_path): metadata={ "title": f"Document {i}", "tags": [f"tag{i % 3}", f"category{i % 2}"], - "created_at": f"2024-01-{(i % 30) + 1:02d}" - } + "created_at": f"2024-01-{(i % 30) + 1:02d}", + }, ) dataset.add(record) docs.append(record) - + yield dataset, docs @@ -67,266 +66,247 @@ async def collection_tools(test_dataset): dataset, _ = test_dataset transport = MockTransportAdapter() await transport.initialize() - + template_registry = TemplateRegistry() collection_tools = CollectionTools(dataset, transport, template_registry) - + yield collection_tools - + await transport.shutdown() class TestCollectionCreation: """Test collection creation functionality.""" - + @pytest.mark.asyncio async def test_create_basic_collection(self, collection_tools): """Test creating a basic collection.""" params = { "name": "Test Collection", "description": "A test collection for unit tests", - "metadata": {"x_purpose": "testing", "x_version": "1.0"} + "metadata": {"x_purpose": "testing", "x_version": "1.0"}, } - + result = await collection_tools.create_collection(params) - + assert result["name"] == "Test Collection" assert result["member_count"] == 0 assert result["metadata"]["x_purpose"] == "testing" assert "collection_id" in result assert "header_id" in result assert "created_at" in result - + @pytest.mark.asyncio async def test_create_collection_with_members(self, collection_tools, test_dataset): """Test creating a collection with initial members.""" dataset, docs = test_dataset - + # Get some document IDs doc_ids = [str(doc.uuid) for doc in docs[:3]] - + params = { "name": "Project Docs", "description": "Project documentation collection", - "initial_members": doc_ids + "initial_members": doc_ids, } - + result = await collection_tools.create_collection(params) - + assert result["member_count"] == 3 - + @pytest.mark.asyncio async def test_create_collection_with_template(self, collection_tools): """Test creating a collection with a template.""" params = { "name": "My Project", "template": "project", - "metadata": {"x_team": "engineering"} + "metadata": {"x_team": "engineering"}, } - + result = await collection_tools.create_collection(params) - + assert result["name"] == "My Project" assert "collection_id" in result - + @pytest.mark.asyncio async def test_create_hierarchical_collection(self, collection_tools): """Test creating a collection hierarchy.""" # Create parent collection - parent_params = { - "name": "Parent Collection", - "description": "The parent" - } + parent_params = {"name": "Parent Collection", "description": "The parent"} parent_result = await collection_tools.create_collection(parent_params) - + # Create child collection child_params = { "name": "Child Collection", "description": "The child", - "parent_collection": parent_result["collection_id"] + "parent_collection": parent_result["collection_id"], } child_result = await collection_tools.create_collection(child_params) - + assert child_result["name"] == "Child Collection" # The relationships should be established class TestCollectionUpdate: """Test collection update functionality.""" - + @pytest.mark.asyncio async def test_update_collection_metadata(self, collection_tools): """Test updating collection metadata.""" # Create collection - create_params = { - "name": "Original Name", - "description": "Original description" - } + create_params = {"name": "Original Name", "description": "Original description"} create_result = await collection_tools.create_collection(create_params) - + # Update collection update_params = { "collection_id": create_result["collection_id"], "name": "Updated Name", "description": "Updated description", - "metadata_updates": {"x_status": "active"} + "metadata_updates": {"x_status": "active"}, } update_result = await collection_tools.update_collection(update_params) - + assert update_result["updated"] is True - + @pytest.mark.asyncio async def test_add_remove_members(self, collection_tools, test_dataset): """Test adding and removing collection members.""" dataset, docs = test_dataset - + # Create collection create_result = await collection_tools.create_collection({"name": "Test"}) collection_id = create_result["collection_id"] - + # Add members add_params = { "collection_id": collection_id, - "add_members": [str(docs[0].uuid), str(docs[1].uuid), str(docs[2].uuid)] + "add_members": [str(docs[0].uuid), str(docs[1].uuid), str(docs[2].uuid)], } add_result = await collection_tools.update_collection(add_params) - + assert add_result["members_added"] == 3 assert add_result["total_members"] == 3 - + # Remove members remove_params = { "collection_id": collection_id, - "remove_members": [str(docs[0].uuid)] + "remove_members": [str(docs[0].uuid)], } remove_result = await collection_tools.update_collection(remove_params) - + assert remove_result["members_removed"] == 1 assert remove_result["total_members"] == 2 class TestCollectionDeletion: """Test collection deletion functionality.""" - + @pytest.mark.asyncio async def test_delete_collection_only(self, collection_tools, test_dataset): """Test deleting collection without members.""" dataset, docs = test_dataset - + # Create collection with members create_params = { "name": "To Delete", - "initial_members": [str(docs[0].uuid), str(docs[1].uuid)] + "initial_members": [str(docs[0].uuid), str(docs[1].uuid)], } create_result = await collection_tools.create_collection(create_params) - + # Delete collection only delete_params = { "collection_id": create_result["collection_id"], - "delete_members": False + "delete_members": False, } delete_result = await collection_tools.delete_collection(delete_params) - + assert delete_result["total_collections_deleted"] == 1 assert delete_result["total_members_deleted"] == 0 - + # Members should still exist assert dataset.get_by_uuid(str(docs[0].uuid)) is not None assert dataset.get_by_uuid(str(docs[1].uuid)) is not None - + @pytest.mark.asyncio async def test_delete_collection_with_members(self, collection_tools, test_dataset): """Test deleting collection with its members.""" dataset, docs = test_dataset - + # Create collection with members create_params = { "name": "To Delete With Members", - "initial_members": [str(docs[0].uuid), str(docs[1].uuid)] + "initial_members": [str(docs[0].uuid), str(docs[1].uuid)], } create_result = await collection_tools.create_collection(create_params) - + # Delete collection and members delete_params = { "collection_id": create_result["collection_id"], - "delete_members": True + "delete_members": True, } delete_result = await collection_tools.delete_collection(delete_params) - + assert delete_result["total_collections_deleted"] == 1 assert delete_result["total_members_deleted"] == 2 - + @pytest.mark.asyncio async def test_recursive_deletion(self, collection_tools): """Test recursive deletion of collection hierarchy.""" # Create parent parent = await collection_tools.create_collection({"name": "Parent"}) - + # Create children - child1 = await collection_tools.create_collection({ - "name": "Child1", - "parent_collection": parent["collection_id"] - }) - child2 = await collection_tools.create_collection({ - "name": "Child2", - "parent_collection": parent["collection_id"] - }) - + child1 = await collection_tools.create_collection( + {"name": "Child1", "parent_collection": parent["collection_id"]} + ) + child2 = await collection_tools.create_collection( + {"name": "Child2", "parent_collection": parent["collection_id"]} + ) + # Delete recursively - delete_params = { - "collection_id": parent["collection_id"], - "recursive": True - } + delete_params = {"collection_id": parent["collection_id"], "recursive": True} delete_result = await collection_tools.delete_collection(delete_params) - + assert delete_result["total_collections_deleted"] == 3 # Parent + 2 children class TestCollectionListing: """Test collection listing functionality.""" - + @pytest.mark.asyncio async def test_list_all_collections(self, collection_tools): """Test listing all collections.""" # Create some collections for i in range(5): - await collection_tools.create_collection({ - "name": f"Collection {i}", - "metadata": {"x_index": i} - }) - + await collection_tools.create_collection( + {"name": f"Collection {i}", "metadata": {"x_index": i}} + ) + # List all - list_params = { - "limit": 10, - "include_stats": False - } + list_params = {"limit": 10, "include_stats": False} list_result = await collection_tools.list_collections(list_params) - + assert list_result["total_count"] >= 5 assert len(list_result["collections"]) >= 5 - + @pytest.mark.asyncio async def test_list_with_filters(self, collection_tools): """Test listing collections with filters.""" # Create parent parent = await collection_tools.create_collection({"name": "Parent"}) - + # Create children for i in range(3): - await collection_tools.create_collection({ - "name": f"Child {i}", - "parent_collection": parent["collection_id"] - }) - + await collection_tools.create_collection( + {"name": f"Child {i}", "parent_collection": parent["collection_id"]} + ) + # List only children - list_params = { - "parent_id": parent["collection_id"], - "include_stats": False - } + list_params = {"parent_id": parent["collection_id"], "include_stats": False} list_result = await collection_tools.list_collections(list_params) - + assert list_result["total_count"] == 3 - + @pytest.mark.asyncio async def test_list_with_sorting(self, collection_tools): """Test listing collections with different sort orders.""" @@ -334,130 +314,149 @@ async def test_list_with_sorting(self, collection_tools): names = ["Zebra", "Alpha", "Beta"] for name in names: await collection_tools.create_collection({"name": name}) - + # Sort by name - list_params = { - "sort_by": "name", - "limit": 10 - } + list_params = {"sort_by": "name", "limit": 10} list_result = await collection_tools.list_collections(list_params) - + # Extract names from results # Debug: check the structure if list_result["collections"]: first_item = list_result["collections"][0] if isinstance(first_item, dict) and "collection" in first_item: # It's wrapped with stats - collection_names = [c["collection"]["name"] for c in list_result["collections"]] + collection_names = [ + c["collection"]["name"] for c in list_result["collections"] + ] else: # Direct collection dicts collection_names = [c["name"] for c in list_result["collections"]] else: collection_names = [] - + # Check alphabetical order assert collection_names[:3] == ["Alpha", "Beta", "Zebra"] class TestDocumentMovement: """Test moving documents between collections.""" - + @pytest.mark.asyncio - async def test_move_documents_between_collections(self, collection_tools, test_dataset): + async def test_move_documents_between_collections( + self, collection_tools, test_dataset + ): """Test moving documents from one collection to another.""" dataset, docs = test_dataset - + # Create two collections - source = await collection_tools.create_collection({ - "name": "Source", - "initial_members": [str(docs[0].uuid), str(docs[1].uuid), str(docs[2].uuid)] - }) + source = await collection_tools.create_collection( + { + "name": "Source", + "initial_members": [ + str(docs[0].uuid), + str(docs[1].uuid), + str(docs[2].uuid), + ], + } + ) target = await collection_tools.create_collection({"name": "Target"}) - + # Move documents move_params = { "document_ids": [str(docs[0].uuid), str(docs[1].uuid)], "source_collection": source["collection_id"], - "target_collection": target["collection_id"] + "target_collection": target["collection_id"], } move_result = await collection_tools.move_documents(move_params) - + assert move_result["moved_count"] == 2 assert move_result["failed_count"] == 0 - + @pytest.mark.asyncio async def test_remove_from_collection(self, collection_tools, test_dataset): """Test removing documents from a collection.""" dataset, docs = test_dataset - + # Create collection with members - collection = await collection_tools.create_collection({ - "name": "Collection", - "initial_members": [str(docs[0].uuid), str(docs[1].uuid)] - }) - + collection = await collection_tools.create_collection( + { + "name": "Collection", + "initial_members": [str(docs[0].uuid), str(docs[1].uuid)], + } + ) + # Remove from collection (no target) move_params = { "document_ids": [str(docs[0].uuid)], "source_collection": collection["collection_id"], - "target_collection": None + "target_collection": None, } move_result = await collection_tools.move_documents(move_params) - + assert move_result["moved_count"] == 1 class TestCollectionStatistics: """Test collection statistics functionality.""" - + @pytest.mark.asyncio async def test_basic_statistics(self, collection_tools, test_dataset): """Test getting basic collection statistics.""" dataset, docs = test_dataset - + # Create collection with members - collection = await collection_tools.create_collection({ - "name": "Stats Test", - "initial_members": [str(doc.uuid) for doc in docs[:5]] - }) - + collection = await collection_tools.create_collection( + { + "name": "Stats Test", + "initial_members": [str(doc.uuid) for doc in docs[:5]], + } + ) + # Get stats stats_params = { "collection_id": collection["collection_id"], - "include_member_details": False + "include_member_details": False, } stats_result = await collection_tools.get_collection_stats(stats_params) - + assert stats_result["name"] == "Stats Test" assert stats_result["statistics"]["direct_members"] == 5 assert stats_result["statistics"]["total_members"] == 5 assert len(stats_result["statistics"]["unique_tags"]) > 0 - + @pytest.mark.asyncio async def test_hierarchical_statistics(self, collection_tools, test_dataset): """Test statistics for hierarchical collections.""" dataset, docs = test_dataset - + # Create parent with members - parent = await collection_tools.create_collection({ - "name": "Parent", - "initial_members": [str(docs[0].uuid), str(docs[1].uuid)] - }) - + parent = await collection_tools.create_collection( + { + "name": "Parent", + "initial_members": [str(docs[0].uuid), str(docs[1].uuid)], + } + ) + # Create child with members - child = await collection_tools.create_collection({ - "name": "Child", - "parent_collection": parent["collection_id"], - "initial_members": [str(docs[2].uuid), str(docs[3].uuid), str(docs[4].uuid)] - }) - + child = await collection_tools.create_collection( + { + "name": "Child", + "parent_collection": parent["collection_id"], + "initial_members": [ + str(docs[2].uuid), + str(docs[3].uuid), + str(docs[4].uuid), + ], + } + ) + # Get parent stats with subcollections stats_params = { "collection_id": parent["collection_id"], - "include_subcollections": True + "include_subcollections": True, } stats_result = await collection_tools.get_collection_stats(stats_params) - + assert stats_result["statistics"]["direct_members"] == 2 assert stats_result["statistics"]["subcollection_members"] == 3 assert stats_result["statistics"]["total_members"] == 5 @@ -465,29 +464,29 @@ async def test_hierarchical_statistics(self, collection_tools, test_dataset): class TestCollectionTemplates: """Test collection template functionality.""" - + @pytest.mark.asyncio async def test_create_from_project_template(self, collection_tools): """Test creating a collection from project template.""" params = { "name": "My Software Project", "template": "project", - "metadata": {"x_language": "python"} + "metadata": {"x_language": "python"}, } - + result = await collection_tools.create_collection(params) - + assert result["name"] == "My Software Project" # Template metadata should be applied - + @pytest.mark.asyncio async def test_available_templates(self, collection_tools): """Test that built-in templates are available.""" registry = collection_tools.template_registry templates = registry.list_templates() - + template_names = [t["name"] for t in templates] assert "project" in template_names assert "research" in template_names assert "knowledge_base" in template_names - assert "dataset" in template_names \ No newline at end of file + assert "dataset" in template_names diff --git a/contextframe/tests/test_mcp/test_protocol.py b/contextframe/tests/test_mcp/test_protocol.py index 3ed6648..6f2ac0e 100644 --- a/contextframe/tests/test_mcp/test_protocol.py +++ b/contextframe/tests/test_mcp/test_protocol.py @@ -1,17 +1,16 @@ """Test MCP protocol compliance.""" -import pytest import asyncio import json +import pytest import tempfile -from unittest.mock import AsyncMock, MagicMock, patch - -from contextframe.mcp.server import ContextFrameMCPServer, MCPConfig -from contextframe.mcp.transport import StdioTransport +from contextframe.frame import FrameDataset, FrameRecord from contextframe.mcp.handlers import MessageHandler -from contextframe.mcp.tools import ToolRegistry from contextframe.mcp.resources import ResourceRegistry -from contextframe.frame import FrameDataset, FrameRecord +from contextframe.mcp.server import ContextFrameMCPServer, MCPConfig +from contextframe.mcp.tools import ToolRegistry +from contextframe.mcp.transport import StdioTransport +from unittest.mock import AsyncMock, MagicMock, patch @pytest.fixture @@ -19,20 +18,20 @@ async def test_dataset(tmp_path): """Create a test dataset.""" dataset_path = tmp_path / "test.lance" dataset = FrameDataset.create(str(dataset_path)) - + # Add some test documents records = [ FrameRecord( text_content="Test document 1", - metadata={"title": "Doc 1", "collection": "test"} + metadata={"title": "Doc 1", "collection": "test"}, ), FrameRecord( text_content="Test document 2", - metadata={"title": "Doc 2", "collection": "test"} - ) + metadata={"title": "Doc 2", "collection": "test"}, + ), ] dataset.add_many(records) - + return str(dataset_path) @@ -59,16 +58,13 @@ async def test_initialization_handshake(self, mcp_server): request = { "jsonrpc": "2.0", "method": "initialize", - "params": { - "protocolVersion": "0.1.0", - "capabilities": {} - }, - "id": 1 + "params": {"protocolVersion": "0.1.0", "capabilities": {}}, + "id": 1, } - + # Handle request response = await mcp_server.handler.handle(request) - + # Verify response assert response["jsonrpc"] == "2.0" assert response["id"] == 1 @@ -81,15 +77,10 @@ async def test_initialization_handshake(self, mcp_server): @pytest.mark.asyncio async def test_method_not_found(self, mcp_server): """Test handling of unknown methods.""" - request = { - "jsonrpc": "2.0", - "method": "unknown_method", - "params": {}, - "id": 2 - } - + request = {"jsonrpc": "2.0", "method": "unknown_method", "params": {}, "id": 2} + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 2 assert "error" in response @@ -100,14 +91,10 @@ async def test_method_not_found(self, mcp_server): async def test_invalid_request(self, mcp_server): """Test handling of invalid requests.""" # Missing jsonrpc field - request = { - "method": "initialize", - "params": {}, - "id": 3 - } - + request = {"method": "initialize", "params": {}, "id": 3} + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 3 assert "error" in response @@ -117,28 +104,25 @@ async def test_invalid_request(self, mcp_server): async def test_tools_list(self, mcp_server): """Test listing available tools.""" # Initialize first - await mcp_server.handler.handle({ - "jsonrpc": "2.0", - "method": "initialize", - "params": {"protocolVersion": "0.1.0", "capabilities": {}}, - "id": 1 - }) - + await mcp_server.handler.handle( + { + "jsonrpc": "2.0", + "method": "initialize", + "params": {"protocolVersion": "0.1.0", "capabilities": {}}, + "id": 1, + } + ) + # List tools - request = { - "jsonrpc": "2.0", - "method": "tools/list", - "params": {}, - "id": 4 - } - + request = {"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 4} + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 4 assert "result" in response assert "tools" in response["result"] - + # Verify expected tools tool_names = {tool["name"] for tool in response["result"]["tools"]} expected_tools = { @@ -147,27 +131,22 @@ async def test_tools_list(self, mcp_server): "get_document", "list_documents", "update_document", - "delete_document" + "delete_document", } assert expected_tools.issubset(tool_names) @pytest.mark.asyncio async def test_resources_list(self, mcp_server): """Test listing available resources.""" - request = { - "jsonrpc": "2.0", - "method": "resources/list", - "params": {}, - "id": 5 - } - + request = {"jsonrpc": "2.0", "method": "resources/list", "params": {}, "id": 5} + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 5 assert "result" in response assert "resources" in response["result"] - + # Verify expected resources resource_uris = {res["uri"] for res in response["result"]["resources"]} expected_resources = { @@ -175,7 +154,7 @@ async def test_resources_list(self, mcp_server): "contextframe://dataset/schema", "contextframe://dataset/stats", "contextframe://collections", - "contextframe://relationships" + "contextframe://relationships", } assert expected_resources.issubset(resource_uris) @@ -183,14 +162,10 @@ async def test_resources_list(self, mcp_server): async def test_notification_no_response(self, mcp_server): """Test that notifications don't return responses.""" # Notifications have no ID - request = { - "jsonrpc": "2.0", - "method": "initialized", - "params": {} - } - + request = {"jsonrpc": "2.0", "method": "initialized", "params": {}} + response = await mcp_server.handler.handle(request) - + # Notifications should return None (no response sent) assert response is None @@ -206,17 +181,13 @@ async def test_search_documents_tool(self, mcp_server): "method": "tools/call", "params": { "name": "search_documents", - "arguments": { - "query": "test", - "search_type": "hybrid", - "limit": 5 - } + "arguments": {"query": "test", "search_type": "hybrid", "limit": 5}, }, - "id": 10 + "id": 10, } - + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 10 assert "result" in response @@ -234,21 +205,21 @@ async def test_add_document_tool(self, mcp_server): "arguments": { "content": "New test document", "metadata": {"title": "New Doc"}, - "generate_embedding": False - } + "generate_embedding": False, + }, }, - "id": 11 + "id": 11, } - + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 11 assert "result" in response assert "document" in response["result"] assert response["result"]["document"]["content"] == "New test document" - @pytest.mark.asyncio + @pytest.mark.asyncio async def test_invalid_tool_params(self, mcp_server): """Test tool execution with invalid parameters.""" request = { @@ -259,13 +230,13 @@ async def test_invalid_tool_params(self, mcp_server): "arguments": { # Missing required 'query' parameter "search_type": "text" - } + }, }, - "id": 12 + "id": 12, } - + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 12 assert "error" in response @@ -281,28 +252,26 @@ async def test_read_dataset_info(self, mcp_server): request = { "jsonrpc": "2.0", "method": "resources/read", - "params": { - "uri": "contextframe://dataset/info" - }, - "id": 20 + "params": {"uri": "contextframe://dataset/info"}, + "id": 20, } - + response = await mcp_server.handler.handle(request) - + assert response["jsonrpc"] == "2.0" assert response["id"] == 20 assert "result" in response assert "contents" in response["result"] assert len(response["result"]["contents"]) > 0 - + # Verify content structure content = response["result"]["contents"][0] assert content["uri"] == "contextframe://dataset/info" assert content["mimeType"] == "application/json" assert "text" in content - + # Parse JSON content info = json.loads(content["text"]) assert "dataset_path" in info assert "storage_format" in info - assert info["storage_format"] == "lance" \ No newline at end of file + assert info["storage_format"] == "lance" diff --git a/contextframe/tests/test_mcp/test_subscription_tools.py b/contextframe/tests/test_mcp/test_subscription_tools.py index 622ffaf..41be165 100644 --- a/contextframe/tests/test_mcp/test_subscription_tools.py +++ b/contextframe/tests/test_mcp/test_subscription_tools.py @@ -2,28 +2,27 @@ import asyncio import pytest -from datetime import datetime, timezone -from unittest.mock import Mock, AsyncMock, patch -from uuid import uuid4 - from contextframe.frame import FrameDataset, FrameRecord +from contextframe.mcp.schemas import ( + GetSubscriptionsParams, + PollChangesParams, + SubscribeChangesParams, + UnsubscribeParams, +) from contextframe.mcp.subscriptions.manager import ( - SubscriptionManager, - SubscriptionState, - Change + Change, + SubscriptionManager, + SubscriptionState, ) from contextframe.mcp.subscriptions.tools import ( - subscribe_changes, + get_subscriptions, poll_changes, + subscribe_changes, unsubscribe, - get_subscriptions -) -from contextframe.mcp.schemas import ( - SubscribeChangesParams, - PollChangesParams, - UnsubscribeParams, - GetSubscriptionsParams ) +from datetime import UTC, datetime, timezone +from unittest.mock import AsyncMock, Mock, patch +from uuid import uuid4 @pytest.fixture @@ -44,73 +43,73 @@ def subscription_manager(mock_dataset): class TestSubscriptionManager: """Test subscription manager functionality.""" - + @pytest.mark.asyncio async def test_create_subscription(self, subscription_manager): """Test creating a subscription.""" sub_id = await subscription_manager.create_subscription( resource_type="documents", filters={"collection_id": "test"}, - options={"polling_interval": 10} + options={"polling_interval": 10}, ) - + assert sub_id in subscription_manager.subscriptions subscription = subscription_manager.subscriptions[sub_id] assert subscription.resource_type == "documents" assert subscription.filters == {"collection_id": "test"} assert subscription.options["polling_interval"] == 10 - + @pytest.mark.asyncio async def test_poll_subscription_no_changes(self, subscription_manager): """Test polling with no changes.""" # Create subscription sub_id = await subscription_manager.create_subscription("all") - + # Poll immediately (no changes) result = await subscription_manager.poll_subscription(sub_id, timeout=0) - + assert result["changes"] == [] assert result["subscription_active"] is True assert result["has_more"] is False - + @pytest.mark.asyncio async def test_poll_subscription_with_changes(self, subscription_manager): """Test polling with buffered changes.""" # Create subscription sub_id = await subscription_manager.create_subscription("documents") subscription = subscription_manager.subscriptions[sub_id] - + # Add changes to buffer change = Change( type="created", resource_type="document", resource_id="doc-123", version=2, - timestamp=datetime.now(timezone.utc) + timestamp=datetime.now(UTC), ) subscription.change_buffer.append(change) - + # Poll for changes result = await subscription_manager.poll_subscription(sub_id) - + assert len(result["changes"]) == 1 assert result["changes"][0]["type"] == "created" assert result["changes"][0]["resource_id"] == "doc-123" - + @pytest.mark.asyncio async def test_cancel_subscription(self, subscription_manager): """Test cancelling a subscription.""" # Create subscription sub_id = await subscription_manager.create_subscription("all") - + # Cancel it cancelled = await subscription_manager.cancel_subscription(sub_id) assert cancelled is True - + # Verify it's inactive subscription = subscription_manager.subscriptions[sub_id] assert subscription.is_active is False - + def test_get_subscriptions(self, subscription_manager): """Test listing subscriptions.""" # Create multiple subscriptions manually @@ -118,30 +117,30 @@ def test_get_subscriptions(self, subscription_manager): id="sub1", resource_type="documents", filters={}, - created_at=datetime.now(timezone.utc), + created_at=datetime.now(UTC), last_version=1, - last_poll_token="sub1:0" + last_poll_token="sub1:0", ) sub2 = SubscriptionState( id="sub2", resource_type="collections", filters={}, - created_at=datetime.now(timezone.utc), + created_at=datetime.now(UTC), last_version=1, - last_poll_token="sub2:0" + last_poll_token="sub2:0", ) - + subscription_manager.subscriptions = {"sub1": sub1, "sub2": sub2} - + # Get all subscriptions all_subs = subscription_manager.get_subscriptions() assert len(all_subs) == 2 - + # Filter by type doc_subs = subscription_manager.get_subscriptions("documents") assert len(doc_subs) == 1 assert doc_subs[0]["resource_type"] == "documents" - + @pytest.mark.asyncio async def test_detect_changes(self, subscription_manager, mock_dataset): """Test change detection between versions.""" @@ -149,31 +148,31 @@ async def test_detect_changes(self, subscription_manager, mock_dataset): old_dataset = Mock() new_dataset = Mock() mock_dataset.checkout_version.side_effect = [old_dataset, new_dataset] - + # Mock UUID retrieval and has_changed check with patch.object( subscription_manager, '_get_version_uuids', side_effect=[ {"doc1", "doc2"}, # Old version - {"doc2", "doc3"} # New version - ] + {"doc2", "doc3"}, # New version + ], ): # Mock _has_changed to avoid calling search with patch.object( subscription_manager, '_has_changed', - return_value=False # doc2 hasn't changed + return_value=False, # doc2 hasn't changed ): changes = await subscription_manager._detect_changes(1, 2) - + assert len(changes) == 2 - + # Check for created document created = [c for c in changes if c.type == "created"] assert len(created) == 1 assert created[0].resource_id == "doc3" - + # Check for deleted document deleted = [c for c in changes if c.type == "deleted"] assert len(deleted) == 1 @@ -182,116 +181,108 @@ async def test_detect_changes(self, subscription_manager, mock_dataset): class TestSubscriptionTools: """Test subscription tool functions.""" - + @pytest.mark.asyncio async def test_subscribe_changes(self, mock_dataset): """Test subscribe_changes tool.""" params = SubscribeChangesParams( resource_type="documents", filters={"collection_id": "test"}, - options={"polling_interval": 10} + options={"polling_interval": 10}, ) - + result = await subscribe_changes(params, mock_dataset) - + assert "subscription_id" in result assert result["polling_interval"] == 10 assert "poll_token" in result - + @pytest.mark.asyncio async def test_poll_changes(self, mock_dataset): """Test poll_changes tool.""" # First create a subscription sub_params = SubscribeChangesParams(resource_type="all") sub_result = await subscribe_changes(sub_params, mock_dataset) - + # Then poll it poll_params = PollChangesParams( subscription_id=sub_result["subscription_id"], poll_token=sub_result["poll_token"], - timeout=0 + timeout=0, ) - + result = await poll_changes(poll_params, mock_dataset) - + assert "changes" in result assert "poll_token" in result assert "has_more" in result assert result["subscription_active"] is True - + @pytest.mark.asyncio async def test_unsubscribe(self, mock_dataset): """Test unsubscribe tool.""" # Create subscription sub_params = SubscribeChangesParams(resource_type="all") sub_result = await subscribe_changes(sub_params, mock_dataset) - + # Unsubscribe - unsub_params = UnsubscribeParams( - subscription_id=sub_result["subscription_id"] - ) - + unsub_params = UnsubscribeParams(subscription_id=sub_result["subscription_id"]) + result = await unsubscribe(unsub_params, mock_dataset) - + assert result["cancelled"] is True assert result["final_poll_token"] is not None - + @pytest.mark.asyncio async def test_get_subscriptions(self, mock_dataset): """Test get_subscriptions tool.""" # Create some subscriptions await subscribe_changes( - SubscribeChangesParams(resource_type="documents"), - mock_dataset + SubscribeChangesParams(resource_type="documents"), mock_dataset ) await subscribe_changes( - SubscribeChangesParams(resource_type="collections"), - mock_dataset + SubscribeChangesParams(resource_type="collections"), mock_dataset ) - + # Get all subscriptions params = GetSubscriptionsParams() result = await get_subscriptions(params, mock_dataset) - + assert result["total_count"] == 2 assert len(result["subscriptions"]) == 2 - + # Filter by type params = GetSubscriptionsParams(resource_type="documents") result = await get_subscriptions(params, mock_dataset) - + assert result["total_count"] == 1 assert result["subscriptions"][0]["resource_type"] == "documents" - + @pytest.mark.asyncio async def test_subscription_lifecycle(self, mock_dataset): """Test complete subscription lifecycle.""" # 1. Create subscription sub_params = SubscribeChangesParams( - resource_type="all", - options={"include_data": True} + resource_type="all", options={"include_data": True} ) sub_result = await subscribe_changes(sub_params, mock_dataset) sub_id = sub_result["subscription_id"] - + # 2. Poll for changes (should be empty) - poll_params = PollChangesParams( - subscription_id=sub_id, - timeout=0 - ) + poll_params = PollChangesParams(subscription_id=sub_id, timeout=0) poll_result = await poll_changes(poll_params, mock_dataset) assert len(poll_result["changes"]) == 0 - + # 3. List subscriptions list_params = GetSubscriptionsParams() list_result = await get_subscriptions(list_params, mock_dataset) assert list_result["total_count"] == 1 - + # 4. Unsubscribe unsub_params = UnsubscribeParams(subscription_id=sub_id) unsub_result = await unsubscribe(unsub_params, mock_dataset) assert unsub_result["cancelled"] is True - + # 5. Poll should show inactive final_poll = await poll_changes(poll_params, mock_dataset) - assert final_poll["subscription_active"] is False \ No newline at end of file + assert final_poll["subscription_active"] is False diff --git a/examples/embedding_providers_demo.py b/examples/embedding_providers_demo.py index ca8fcf5..8743df5 100644 --- a/examples/embedding_providers_demo.py +++ b/examples/embedding_providers_demo.py @@ -6,27 +6,26 @@ Requirements: pip install contextframe[embed,extract] - + # For specific providers: export OPENAI_API_KEY="your-key" export COHERE_API_KEY="your-key" export VOYAGE_API_KEY="your-key" - + # For local embeddings: # Install Ollama: https://ollama.ai # Pull model: ollama pull nomic-embed-text """ +import logging import os import time -import logging -from typing import List, Dict, Any -from pathlib import Path - from contextframe import FrameDataset, FrameRecord -from contextframe.embed import embed_frames, LiteLLMProvider +from contextframe.embed import LiteLLMProvider, embed_frames from contextframe.extract import DirectoryExtractor from contextframe.extract.chunking import semantic_splitter +from pathlib import Path +from typing import Any, Dict, List logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -34,27 +33,27 @@ def compare_embedding_providers(): """Compare different embedding providers on the same content.""" - + # Sample documents test_frames = [ FrameRecord( uri="doc1.txt", content="Artificial intelligence is transforming how we work and live.", - metadata={"type": "statement"} + metadata={"type": "statement"}, ), FrameRecord( uri="doc2.txt", content="Machine learning models require large amounts of data for training.", - metadata={"type": "technical"} + metadata={"type": "technical"}, ), FrameRecord( uri="doc3.py", content="""def calculate_similarity(vec1, vec2): return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))""", - metadata={"type": "code"} + metadata={"type": "code"}, ), ] - + providers = [ { "name": "OpenAI Small", @@ -83,70 +82,68 @@ def compare_embedding_providers(): "api_base": "http://localhost:11434", }, ] - + results = [] - + for provider_info in providers: # Skip if required API key not available if "requires_api_key" in provider_info: if not os.getenv(provider_info["requires_api_key"]): logger.warning(f"Skipping {provider_info['name']}: No API key") continue - + try: logger.info(f"\nTesting {provider_info['name']}...") start_time = time.time() - + # Create provider kwargs = {} if "api_base" in provider_info: kwargs["api_base"] = provider_info["api_base"] - - provider = LiteLLMProvider( - model=provider_info["model"], - **kwargs - ) - + + provider = LiteLLMProvider(model=provider_info["model"], **kwargs) + # Embed frames embedded_frames = embed_frames( - test_frames, - model=provider_info["model"], - **kwargs + test_frames, model=provider_info["model"], **kwargs ) - + elapsed = time.time() - start_time - + # Verify embeddings for frame in embedded_frames: if frame.embedding is None: raise ValueError(f"No embedding for {frame.uri}") - + result = { "provider": provider_info["name"], "model": provider_info["model"], "success": True, "time": elapsed, "dimensions": len(embedded_frames[0].embedding), - "cost_estimate": provider_info["cost_per_million"] * 0.001, # Rough estimate + "cost_estimate": provider_info["cost_per_million"] + * 0.001, # Rough estimate } results.append(result) - + logger.info(f"✓ Success: {elapsed:.2f}s, {result['dimensions']} dimensions") - + except Exception as e: logger.error(f"✗ Failed: {e}") - results.append({ - "provider": provider_info["name"], - "model": provider_info["model"], - "success": False, - "error": str(e), - }) - + results.append( + { + "provider": provider_info["name"], + "model": provider_info["model"], + "success": False, + "error": str(e), + } + ) + # Print comparison - print("\n" + "="*60) + print("\n" + "=" * 60) print("EMBEDDING PROVIDER COMPARISON") - print("="*60) - + print("=" * 60) + for result in results: if result["success"]: print(f"\n{result['provider']}:") @@ -160,31 +157,43 @@ def compare_embedding_providers(): def demonstrate_multilingual_embedding(): """Show multilingual embedding capabilities.""" - + multilingual_frames = [ - FrameRecord(uri="en.txt", content="Hello, how are you?", metadata={"lang": "en"}), - FrameRecord(uri="es.txt", content="Hola, ¿cómo estás?", metadata={"lang": "es"}), - FrameRecord(uri="fr.txt", content="Bonjour, comment allez-vous?", metadata={"lang": "fr"}), - FrameRecord(uri="de.txt", content="Hallo, wie geht es dir?", metadata={"lang": "de"}), + FrameRecord( + uri="en.txt", content="Hello, how are you?", metadata={"lang": "en"} + ), + FrameRecord( + uri="es.txt", content="Hola, ¿cómo estás?", metadata={"lang": "es"} + ), + FrameRecord( + uri="fr.txt", + content="Bonjour, comment allez-vous?", + metadata={"lang": "fr"}, + ), + FrameRecord( + uri="de.txt", content="Hallo, wie geht es dir?", metadata={"lang": "de"} + ), FrameRecord(uri="zh.txt", content="你好,你好吗?", metadata={"lang": "zh"}), - FrameRecord(uri="ja.txt", content="こんにちは、元気ですか?", metadata={"lang": "ja"}), + FrameRecord( + uri="ja.txt", content="こんにちは、元気ですか?", metadata={"lang": "ja"} + ), ] - + # Use multilingual model if os.getenv("COHERE_API_KEY"): logger.info("Using Cohere multilingual embeddings...") embedded = embed_frames( multilingual_frames, model="cohere/embed-multilingual-v3.0", - input_type="search_document" + input_type="search_document", ) - + # Calculate cross-language similarities import numpy as np - + print("\nCross-language similarity (all asking 'Hello, how are you?'):") en_embedding = embedded[0].embedding - + for frame in embedded[1:]: similarity = np.dot(en_embedding, frame.embedding) / ( np.linalg.norm(en_embedding) * np.linalg.norm(frame.embedding) @@ -193,14 +202,13 @@ def demonstrate_multilingual_embedding(): else: logger.info("Cohere API key not found, using OpenAI...") embedded = embed_frames( - multilingual_frames, - model="openai/text-embedding-3-small" + multilingual_frames, model="openai/text-embedding-3-small" ) def demonstrate_code_embeddings(): """Show code-optimized embeddings.""" - + code_samples = [ FrameRecord( uri="python_func.py", @@ -215,7 +223,7 @@ def demonstrate_code_embeddings(): else: right = mid - 1 return -1""", - metadata={"language": "python", "algorithm": "binary_search"} + metadata={"language": "python", "algorithm": "binary_search"}, ), FrameRecord( uri="js_func.js", @@ -229,7 +237,7 @@ def demonstrate_code_embeddings(): } return -1; }""", - metadata={"language": "javascript", "algorithm": "binary_search"} + metadata={"language": "javascript", "algorithm": "binary_search"}, ), FrameRecord( uri="bubble_sort.py", @@ -240,10 +248,10 @@ def demonstrate_code_embeddings(): if arr[j] > arr[j+1]: arr[j], arr[j+1] = arr[j+1], arr[j] return arr""", - metadata={"language": "python", "algorithm": "bubble_sort"} + metadata={"language": "python", "algorithm": "bubble_sort"}, ), ] - + # Use code-optimized model if available if os.getenv("VOYAGE_API_KEY"): logger.info("Using Voyage code embeddings...") @@ -251,25 +259,19 @@ def demonstrate_code_embeddings(): else: logger.info("Using OpenAI embeddings for code...") model = "openai/text-embedding-3-small" - + embedded_code = embed_frames(code_samples, model=model) - + # Find similar code across languages dataset = FrameDataset("code_embeddings.lance") dataset.add(embedded_code) - + # Search for similar algorithms - query_frame = FrameRecord( - uri="query", - content="binary search implementation" - ) + query_frame = FrameRecord(uri="query", content="binary search implementation") query_embedded = embed_frames([query_frame], model=model)[0] - - results = dataset.search( - query_embedding=query_embedded.embedding, - limit=3 - ) - + + results = dataset.search(query_embedding=query_embedded.embedding, limit=3) + print("\nCode similarity search for 'binary search implementation':") for result in results: print(f" {result.uri} (score: {result.score:.3f})") @@ -279,71 +281,71 @@ def demonstrate_code_embeddings(): def demonstrate_batch_processing(): """Show efficient batch processing strategies.""" - + # Generate many documents documents = [] for i in range(500): - documents.append(FrameRecord( - uri=f"doc_{i}.txt", - content=f"This is document number {i}. " * 20, # ~100 tokens - metadata={"index": i, "batch": i // 100} - )) - + documents.append( + FrameRecord( + uri=f"doc_{i}.txt", + content=f"This is document number {i}. " * 20, # ~100 tokens + metadata={"index": i, "batch": i // 100}, + ) + ) + logger.info(f"Processing {len(documents)} documents in batches...") - + # Method 1: Automatic batching start = time.time() embedded_auto = embed_frames( documents, model="openai/text-embedding-3-small", batch_size=100, # Process 100 at a time - show_progress=True + show_progress=True, ) auto_time = time.time() - start - + logger.info(f"Automatic batching completed in {auto_time:.2f}s") - + # Method 2: Manual batching with different models start = time.time() all_embedded = [] - + for batch_idx in range(5): batch_docs = [d for d in documents if d.metadata["batch"] == batch_idx] - + # Use different strategies for different batches if batch_idx == 0: # High priority - use best model embedded_batch = embed_frames( - batch_docs, - model="openai/text-embedding-3-large" + batch_docs, model="openai/text-embedding-3-large" ) else: # Lower priority - use cheaper model embedded_batch = embed_frames( - batch_docs, - model="openai/text-embedding-3-small" + batch_docs, model="openai/text-embedding-3-small" ) - + all_embedded.extend(embedded_batch) - + manual_time = time.time() - start logger.info(f"Manual batching completed in {manual_time:.2f}s") def demonstrate_fallback_strategy(): """Show robust embedding with fallbacks.""" - + frames = [ FrameRecord(uri="important.txt", content="Critical business document content") ] - + # Define fallback chain models = [ ("openai/text-embedding-3-small", {"api_key": os.getenv("OPENAI_API_KEY")}), ("cohere/embed-english-v3.0", {"api_key": os.getenv("COHERE_API_KEY")}), ("ollama/nomic-embed-text", {"api_base": "http://localhost:11434"}), ] - + embedded = None for model, kwargs in models: try: @@ -354,9 +356,9 @@ def demonstrate_fallback_strategy(): except Exception as e: logger.warning(f"Failed with {model}: {e}") continue - + if embedded: - print(f"\nSuccessfully embedded with fallback strategy") + print("\nSuccessfully embedded with fallback strategy") print(f"Embedding dimensions: {len(embedded[0].embedding)}") else: print("\nAll embedding methods failed!") @@ -364,58 +366,55 @@ def demonstrate_fallback_strategy(): def demonstrate_search_vs_document_embeddings(): """Show the difference between search query and document embeddings.""" - + if not os.getenv("COHERE_API_KEY"): logger.info("Cohere API key not found, skipping search vs document demo") return - + # Document content documents = [ FrameRecord( uri="physics.txt", content="Quantum mechanics is the mathematical description of the motion and interaction of subatomic particles.", - metadata={"subject": "physics"} + metadata={"subject": "physics"}, ), FrameRecord( uri="biology.txt", content="DNA replication is the biological process of producing two identical replicas from one original DNA molecule.", - metadata={"subject": "biology"} + metadata={"subject": "biology"}, ), ] - + # Embed as documents doc_embedded = embed_frames( documents, model="cohere/embed-english-v3.0", - input_type="search_document" # Optimize for storage + input_type="search_document", # Optimize for storage ) - + # Create dataset dataset = FrameDataset("search_demo.lance") dataset.add(doc_embedded) - + # Search queries queries = [ "what is quantum physics?", "how does DNA copying work?", "subatomic particle behavior", ] - + for query in queries: # Create query embedding query_frame = FrameRecord(uri="query", content=query) query_embedded = embed_frames( [query_frame], model="cohere/embed-english-v3.0", - input_type="search_query" # Optimize for search + input_type="search_query", # Optimize for search )[0] - + # Search - results = dataset.search( - query_embedding=query_embedded.embedding, - limit=2 - ) - + results = dataset.search(query_embedding=query_embedded.embedding, limit=2) + print(f"\nQuery: '{query}'") for result in results: print(f" → {result.uri} (score: {result.score:.3f})") @@ -423,9 +422,9 @@ def demonstrate_search_vs_document_embeddings(): def main(): """Run all demonstrations.""" - + print("🚀 ContextFrame Embedding Providers Demo\n") - + demos = [ ("Provider Comparison", compare_embedding_providers), ("Multilingual Embeddings", demonstrate_multilingual_embedding), @@ -434,19 +433,19 @@ def main(): ("Fallback Strategy", demonstrate_fallback_strategy), ("Search vs Document", demonstrate_search_vs_document_embeddings), ] - + for name, demo_func in demos: - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Demo: {name}") - print(f"{'='*60}") - + print(f"{'=' * 60}") + try: demo_func() except Exception as e: logger.error(f"Demo failed: {e}") - + time.sleep(1) # Brief pause between demos - + print("\n✅ All demos completed!") print("\nKey takeaways:") print("1. LiteLLM provides unified access to 100+ embedding providers") @@ -457,4 +456,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/enrichment_demo.py b/examples/enrichment_demo.py index 21607cc..742562d 100644 --- a/examples/enrichment_demo.py +++ b/examples/enrichment_demo.py @@ -10,20 +10,20 @@ """ import os -from pathlib import Path from contextframe import FrameDataset, FrameRecord +from contextframe.embed import embed_frames from contextframe.enrich import ( ContextEnricher, get_prompt_template, list_available_prompts, ) -from contextframe.embed import embed_frames +from pathlib import Path def basic_enrichment_example(): """Show basic enrichment of context and tags.""" print("\n=== Basic Enrichment Example ===\n") - + # Create sample documents frames = [ FrameRecord( @@ -55,20 +55,22 @@ def search(self, query_vec, k=10): """, ), ] - + # Create dataset dataset = FrameDataset.create("enrichment_demo.lance", exist_ok=True) - + # Add frames with embeddings embedded_frames = embed_frames(frames, model="openai/text-embedding-3-small") dataset.add_many(embedded_frames) - + # Enrich with context and tags - dataset.enrich({ - "context": "Explain in 2-3 sentences what this document teaches and why it matters for AI developers", - "tags": "Extract 5-7 technical tags covering languages, concepts, and tools mentioned", - }) - + dataset.enrich( + { + "context": "Explain in 2-3 sentences what this document teaches and why it matters for AI developers", + "tags": "Extract 5-7 technical tags covering languages, concepts, and tools mentioned", + } + ) + # Show results for frame in dataset.iter_records(): print(f"Document: {frame.title}") @@ -80,7 +82,7 @@ def search(self, query_vec, k=10): def custom_metadata_example(): """Show extraction of custom metadata using prompts.""" print("\n=== Custom Metadata Extraction ===\n") - + # Create a code documentation frame frame = FrameRecord( uri="api_docs.md", @@ -119,11 +121,11 @@ def custom_metadata_example(): - 500: Server error """, ) - + dataset = FrameDataset.create("metadata_demo.lance", exist_ok=True) embedded = embed_frames([frame], model="openai/text-embedding-3-small")[0] dataset.add(embedded) - + # Extract API metadata enricher = ContextEnricher() enricher.enrich_dataset( @@ -139,11 +141,11 @@ def custom_metadata_example(): - response_fields (list of field names in response) - error_codes (object mapping code to description) """, - "format": "json" + "format": "json", } - } + }, ) - + # Show extracted metadata record = list(dataset.iter_records())[0] print(f"Document: {record.title}") @@ -155,7 +157,7 @@ def custom_metadata_example(): def relationship_discovery_example(): """Show how to find relationships between documents.""" print("\n=== Relationship Discovery ===\n") - + # Create related documents frames = [ FrameRecord( @@ -174,28 +176,28 @@ def relationship_discovery_example(): text_content="Effective prompts are crucial for LLM performance...", ), ] - + dataset = FrameDataset.create("relationships_demo.lance", exist_ok=True) embedded = embed_frames(frames, model="openai/text-embedding-3-small") dataset.add_many(embedded) - + enricher = ContextEnricher() - + # Find relationships between documents for i, source in enumerate(embedded): # Get other documents as candidates candidates = [f for j, f in enumerate(embedded) if i != j] - + relationships = enricher.find_relationships( source_doc=source, candidate_docs=candidates, - prompt=get_prompt_template("relationships", "topic_relationships") + prompt=get_prompt_template("relationships", "topic_relationships"), ) - + # Update the frame with relationships source.relationships = relationships dataset.update_record(source) - + # Display relationships for frame in dataset.iter_records(): print(f"\nDocument: {frame.title}") @@ -209,53 +211,52 @@ def relationship_discovery_example(): def mcp_tool_example(): """Show how agents can use enrichment as tools.""" print("\n=== MCP Tool Interface Example ===\n") - + from contextframe.enrich import EnrichmentTools, list_available_tools - + # Create enricher and tools enricher = ContextEnricher() tools = EnrichmentTools(enricher) - + # Show available tools print("Available enrichment tools:") for tool_name in list_available_tools(): print(f" - {tool_name}") - + # Example: Agent using the enrich_context tool content = """ FastAPI is a modern web framework for building APIs with Python 3.7+ based on standard Python type hints. It's designed to be easy to use while providing high performance through Starlette and Pydantic. """ - + # Agent calls the tool context = tools.enrich_context( - content=content, - purpose="building REST APIs with Python" + content=content, purpose="building REST APIs with Python" ) - + print(f"\nGenerated context: {context}") - + # Extract metadata tool metadata = tools.extract_metadata( content=content, - schema="Extract: framework_name, python_version, key_dependencies, main_features" + schema="Extract: framework_name, python_version, key_dependencies, main_features", ) - + print(f"\nExtracted metadata: {metadata}") def template_showcase(): """Show available prompt templates.""" print("\n=== Available Prompt Templates ===\n") - + templates = list_available_prompts() - + for category, template_names in templates.items(): print(f"{category.upper()}:") for name in template_names: print(f" - {name}") - + # Example using a template print("\n\nExample template (technical_context):") template = get_prompt_template("context", "technical_context") @@ -265,7 +266,7 @@ def template_showcase(): def purpose_driven_enrichment(): """Show enrichment for specific purposes.""" print("\n=== Purpose-Driven Enrichment ===\n") - + # Document about testing frame = FrameRecord( uri="testing_guide.md", @@ -285,14 +286,14 @@ def test_addition(): assert 1 + 1 == 2 """, ) - + dataset = FrameDataset.create("purpose_demo.lance", exist_ok=True) embedded = embed_frames([frame], model="openai/text-embedding-3-small")[0] dataset.add(embedded) - + # Enrich for different purposes enricher = ContextEnricher() - + # For RAG system rag_prompt = get_prompt_template("purpose", "rag_optimization") enricher.enrich_dataset( @@ -302,11 +303,11 @@ def test_addition(): "tags": rag_prompt.split("TAGS:")[1].split("3.")[0].strip(), "custom_metadata": { "prompt": rag_prompt.split("METADATA:")[1].strip(), - "format": "json" - } - } + "format": "json", + }, + }, ) - + # Show enriched document record = list(dataset.iter_records())[0] print(f"Document: {record.title}") @@ -317,13 +318,13 @@ def test_addition(): def main(): """Run all examples.""" - + # Check for API key if not os.getenv("OPENAI_API_KEY"): print("Please set OPENAI_API_KEY environment variable") print("Or modify examples to use a different model (e.g., ollama/mistral)") return - + examples = [ ("Basic Enrichment", basic_enrichment_example), ("Custom Metadata", custom_metadata_example), @@ -332,24 +333,25 @@ def main(): ("Template Showcase", template_showcase), ("Purpose-Driven", purpose_driven_enrichment), ] - + for name, func in examples: - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f" {name}") - print(f"{'='*60}") + print(f"{'=' * 60}") try: func() except Exception as e: print(f"Error in {name}: {e}") - + # Cleanup print("\n\nCleaning up demo datasets...") for lance_dir in Path(".").glob("*_demo.lance"): import shutil + shutil.rmtree(lance_dir) - + print("\nDemo complete!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/external_tools/chunkr_pipeline.py b/examples/external_tools/chunkr_pipeline.py index 7b52dd6..c7ca85e 100644 --- a/examples/external_tools/chunkr_pipeline.py +++ b/examples/external_tools/chunkr_pipeline.py @@ -12,18 +12,17 @@ Requirements: pip install chunkr-ai contextframe[extract,embed] - + Get API key: https://chunkr.ai """ import logging -from pathlib import Path -from typing import List, Dict, Any, Optional import os import time - from contextframe import FrameDataset, FrameRecord from contextframe.embed import embed_frames +from pathlib import Path +from typing import Any, Dict, List, Optional logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -33,19 +32,19 @@ def extract_with_chunkr( file_path: str, api_key: str, max_wait_time: int = 180, - target_chunk_length: Optional[int] = None, + target_chunk_length: int | None = None, ocr_strategy: str = "auto", -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Extract and chunk content using Chunkr API. - + Args: file_path: Path to document api_key: Chunkr API key max_wait_time: Maximum time to wait for processing (seconds) target_chunk_length: Target chunk length (Chunkr handles this intelligently) ocr_strategy: OCR strategy - "auto", "force", or "off" - + Returns: Dictionary with chunks and metadata """ @@ -53,15 +52,14 @@ def extract_with_chunkr( from chunkr import Chunkr except ImportError: raise ImportError( - "chunkr-ai package is required. " - "Install with: pip install chunkr-ai" + "chunkr-ai package is required. Install with: pip install chunkr-ai" ) - + # Initialize client client = Chunkr(api_key=api_key) - + logger.info(f"Uploading {file_path} to Chunkr...") - + # Upload and process with open(file_path, "rb") as f: task = client.upload( @@ -69,29 +67,29 @@ def extract_with_chunkr( target_chunk_length=target_chunk_length, ocr_strategy=ocr_strategy, ) - + # Poll for completion logger.info(f"Processing document (task ID: {task.task_id})...") start_time = time.time() - + while task.status in ["Pending", "Processing"]: if time.time() - start_time > max_wait_time: raise TimeoutError(f"Processing exceeded {max_wait_time} seconds") - + time.sleep(2) task = client.get_task(task.task_id) logger.info(f"Status: {task.status}") - + if task.status == "Failed": raise Exception(f"Chunkr processing failed: {task.error}") - + # Get results output = task.output - + # Extract segments and chunks segments = output.get("segments", []) chunks = output.get("chunks", []) - + # Build metadata metadata = { "source": "chunkr", @@ -102,12 +100,14 @@ def extract_with_chunkr( "segment_types": {}, "processing_time": time.time() - start_time, } - + # Count segment types for segment in segments: seg_type = segment.get("type", "unknown") - metadata["segment_types"][seg_type] = metadata["segment_types"].get(seg_type, 0) + 1 - + metadata["segment_types"][seg_type] = ( + metadata["segment_types"].get(seg_type, 0) + 1 + ) + return { "chunks": chunks, "segments": segments, @@ -119,48 +119,50 @@ def extract_with_chunkr( def create_frames_from_chunkr( file_path: str, api_key: str, - collection_uri: Optional[str] = None, + collection_uri: str | None = None, include_segment_frames: bool = False, - **chunkr_kwargs -) -> List[FrameRecord]: + **chunkr_kwargs, +) -> list[FrameRecord]: """ Create FrameRecords from Chunkr output. - + Args: file_path: Path to document api_key: Chunkr API key collection_uri: Optional collection URI include_segment_frames: Whether to create frames for segments too **chunkr_kwargs: Additional Chunkr arguments - + Returns: List of FrameRecord objects """ # Extract with Chunkr result = extract_with_chunkr(file_path, api_key, **chunkr_kwargs) - + frames = [] base_uri = file_path - + # Create frames from chunks (primary output) for i, chunk in enumerate(result["chunks"]): # Chunkr chunks include rich metadata chunk_metadata = result["metadata"].copy() - chunk_metadata.update({ - "chunk_index": i, - "total_chunks": len(result["chunks"]), - "chunk_id": chunk.get("chunk_id"), - "segment_ids": chunk.get("segment_ids", []), - "confidence": chunk.get("confidence"), - }) - + chunk_metadata.update( + { + "chunk_index": i, + "total_chunks": len(result["chunks"]), + "chunk_id": chunk.get("chunk_id"), + "segment_ids": chunk.get("segment_ids", []), + "confidence": chunk.get("confidence"), + } + ) + # Handle different content types content = chunk.get("content", "") if chunk.get("type") == "table": # Tables might have structured data if "table_data" in chunk: content = f"Table:\n{chunk['table_data']}" - + frame = FrameRecord( uri=f"{base_uri}#chunk-{i}", title=Path(file_path).stem, @@ -170,7 +172,7 @@ def create_frames_from_chunkr( collection_uri=collection_uri, ) frames.append(frame) - + # Optionally create frames for segments if include_segment_frames: for i, segment in enumerate(result["segments"]): @@ -182,7 +184,7 @@ def create_frames_from_chunkr( "bbox": segment.get("bbox"), # Bounding box if available "page": segment.get("page"), } - + frame = FrameRecord( uri=f"{base_uri}#segment-{i}", title=f"{Path(file_path).stem} - {segment.get('type', 'segment')}", @@ -192,7 +194,7 @@ def create_frames_from_chunkr( parent_uri=base_uri, ) frames.append(frame) - + return frames @@ -204,7 +206,7 @@ def process_multimodal_document( ): """ Process a document with tables, images, and mixed content. - + Args: file_path: Path to document api_key: Chunkr API key @@ -218,12 +220,12 @@ def process_multimodal_document( target_chunk_length=512, # Smaller chunks for multimodal ocr_strategy="auto", ) - + # Separate chunks by type text_chunks = [] table_chunks = [] image_chunks = [] - + for chunk in result["chunks"]: chunk_type = chunk.get("type", "text") if chunk_type == "table": @@ -232,12 +234,14 @@ def process_multimodal_document( image_chunks.append(chunk) else: text_chunks.append(chunk) - - logger.info(f"Found: {len(text_chunks)} text, {len(table_chunks)} tables, {len(image_chunks)} images") - + + logger.info( + f"Found: {len(text_chunks)} text, {len(table_chunks)} tables, {len(image_chunks)} images" + ) + # Create specialized frames frames = [] - + # Text frames for i, chunk in enumerate(text_chunks): frame = FrameRecord( @@ -251,15 +255,15 @@ def process_multimodal_document( record_type="document", ) frames.append(frame) - + # Table frames (with special handling) for i, chunk in enumerate(table_chunks): # Tables might need different embedding strategy table_content = chunk.get("content", "") if "table_data" in chunk: # Convert structured table data to text - table_content = f"Table {i+1}:\n{chunk['table_data']}" - + table_content = f"Table {i + 1}:\n{chunk['table_data']}" + frame = FrameRecord( uri=f"{file_path}#table-{i}", content=table_content, @@ -272,12 +276,12 @@ def process_multimodal_document( record_type="document", ) frames.append(frame) - + # Image frames (descriptions or OCR text) for i, chunk in enumerate(image_chunks): frame = FrameRecord( uri=f"{file_path}#image-{i}", - content=chunk.get("content", chunk.get("description", f"Image {i+1}")), + content=chunk.get("content", chunk.get("description", f"Image {i + 1}")), metadata={ "type": "image", "image_index": i, @@ -287,12 +291,12 @@ def process_multimodal_document( record_type="document", ) frames.append(frame) - + # Embed and store dataset = FrameDataset(dataset_path) embedded_frames = embed_frames(frames, model=embed_model) dataset.add(embedded_frames) - + return dataset @@ -300,12 +304,12 @@ def batch_process_with_chunkr( folder_path: str, dataset_path: str, api_key: str, - file_patterns: List[str] = None, + file_patterns: list[str] = None, embed_model: str = "openai/text-embedding-3-small", ): """ Process multiple documents with Chunkr. - + Args: folder_path: Folder containing documents dataset_path: ContextFrame dataset path @@ -315,39 +319,39 @@ def batch_process_with_chunkr( """ if file_patterns is None: file_patterns = ["*.pdf", "*.docx", "*.pptx"] - + # Find files all_files = [] folder = Path(folder_path) for pattern in file_patterns: all_files.extend(folder.glob(f"**/{pattern}")) - + logger.info(f"Found {len(all_files)} files to process") - + # Initialize dataset dataset = FrameDataset(dataset_path) - + # Process each file for file_path in all_files: try: logger.info(f"Processing: {file_path}") - + # Create frames with Chunkr frames = create_frames_from_chunkr( str(file_path), api_key=api_key, collection_uri=f"documents/{file_path.parent.name}", ) - + # Embed and store embedded_frames = embed_frames(frames, model=embed_model) dataset.add(embedded_frames) - + logger.info(f"Added {len(frames)} frames from {file_path.name}") - + except Exception as e: logger.error(f"Failed to process {file_path}: {e}") - + return dataset @@ -359,9 +363,9 @@ def compare_chunking_strategies(): if not api_key: print("Set CHUNKR_API_KEY environment variable") return - + file_path = "complex_document.pdf" - + # 1. Chunkr's intelligent chunking print("Using Chunkr's intelligent chunking...") chunkr_frames = create_frames_from_chunkr( @@ -369,17 +373,17 @@ def compare_chunking_strategies(): api_key=api_key, target_chunk_length=500, ) - + print(f"\nChunkr created {len(chunkr_frames)} chunks") for i, frame in enumerate(chunkr_frames[:3]): - print(f"\nChunk {i+1}:") + print(f"\nChunk {i + 1}:") print(frame.content[:200] + "...") print(f"Metadata: {frame.metadata.get('segment_ids')}") - + # 2. Basic character splitting (for comparison) from contextframe.extract import TextFileExtractor from contextframe.extract.chunking import ChunkingMixin - + print("\n\nUsing basic character splitting...") # This would need to extract PDF first # Just showing the concept @@ -388,7 +392,7 @@ def compare_chunking_strategies(): chunk_size=500, splitter_type="text", ) - + print(f"Basic splitting created {len(basic_chunks)} chunks") print("\nKey differences:") print("- Chunkr preserves document structure") @@ -408,7 +412,7 @@ def compare_chunking_strategies(): print(f"Created {len(frames)} frames with Chunkr") else: print("Set CHUNKR_API_KEY environment variable") - + # Example 2: Process multimodal document (commented out) # if api_key: # dataset = process_multimodal_document( @@ -416,7 +420,7 @@ def compare_chunking_strategies(): # api_key=api_key, # dataset_path="./multimodal_docs.lance", # ) - + # Example 3: Batch processing (commented out) # if api_key: # dataset = batch_process_with_chunkr( @@ -424,6 +428,6 @@ def compare_chunking_strategies(): # dataset_path="./chunkr_docs.lance", # api_key=api_key, # ) - + # Example 4: Compare strategies (commented out) - # compare_chunking_strategies() \ No newline at end of file + # compare_chunking_strategies() diff --git a/examples/external_tools/docling_pdf_pipeline.py b/examples/external_tools/docling_pdf_pipeline.py index 78c6d5e..d59bfca 100644 --- a/examples/external_tools/docling_pdf_pipeline.py +++ b/examples/external_tools/docling_pdf_pipeline.py @@ -12,20 +12,20 @@ """ import logging -from pathlib import Path -from typing import List, Optional, Dict, Any - from contextframe import FrameDataset, FrameRecord from contextframe.embed import embed_frames +from pathlib import Path +from typing import Any, Dict, List, Optional # Docling imports (only if needed) try: - from docling.document_converter import DocumentConverter from docling.datamodel.pipeline_options import ( - PipelineOptions, EasyOcrOptions, + PipelineOptions, TableFormerMode, ) + from docling.document_converter import DocumentConverter + DOCLING_AVAILABLE = True except ImportError: DOCLING_AVAILABLE = False @@ -41,44 +41,43 @@ def extract_pdf_with_docling( extract_tables: bool = True, extract_images: bool = True, table_mode: str = "accurate", -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Extract content from PDF using Docling. - + Args: pdf_path: Path to PDF file use_ocr: Enable OCR for scanned documents extract_tables: Extract table structures extract_images: Extract and classify images table_mode: "fast" or "accurate" for table extraction - + Returns: Dictionary with extracted content and metadata """ if not DOCLING_AVAILABLE: raise ImportError("Docling is required for PDF extraction") - + # Configure pipeline options pipeline_options = PipelineOptions() pipeline_options.do_ocr = use_ocr pipeline_options.do_table_structure = extract_tables - + if use_ocr: pipeline_options.ocr_options = EasyOcrOptions() - + if extract_tables: pipeline_options.table_structure_options.mode = ( - TableFormerMode.FAST if table_mode == "fast" - else TableFormerMode.ACCURATE + TableFormerMode.FAST if table_mode == "fast" else TableFormerMode.ACCURATE ) - + # Create converter converter = DocumentConverter() - + # Convert document logger.info(f"Processing PDF: {pdf_path}") result = converter.convert(pdf_path) - + # Extract metadata metadata = { "source": "docling", @@ -88,15 +87,15 @@ def extract_pdf_with_docling( "tables_extracted": extract_tables, "images_extracted": extract_images, "table_mode": table_mode if extract_tables else None, - } + }, } - + # Get the document doc = result.document - + # Export to markdown (includes tables, formulas, etc.) markdown_content = doc.export_to_markdown() - + # Count elements if hasattr(doc, 'tables'): metadata["num_tables"] = len(doc.tables) @@ -104,7 +103,7 @@ def extract_pdf_with_docling( metadata["num_images"] = len(doc.pictures) if hasattr(doc, 'pages'): metadata["num_pages"] = len(doc.pages) - + # Extract title if available title = Path(pdf_path).stem if hasattr(doc, 'texts') and doc.texts: @@ -112,7 +111,7 @@ def extract_pdf_with_docling( first_text = doc.texts[0] if hasattr(first_text, 'text') and len(first_text.text) < 200: title = first_text.text - + return { "content": markdown_content, "metadata": metadata, @@ -123,25 +122,25 @@ def extract_pdf_with_docling( def create_frame_from_pdf( pdf_path: str, - collection_uri: Optional[str] = None, - chunk_size: Optional[int] = None, - **docling_kwargs -) -> List[FrameRecord]: + collection_uri: str | None = None, + chunk_size: int | None = None, + **docling_kwargs, +) -> list[FrameRecord]: """ Create FrameRecord(s) from a PDF file. - + Args: pdf_path: Path to PDF file collection_uri: Optional collection to add document to chunk_size: If specified, chunk the content **docling_kwargs: Additional arguments for Docling - + Returns: List of FrameRecord objects """ # Extract content extracted = extract_pdf_with_docling(pdf_path, **docling_kwargs) - + # Create base frame data base_frame = { "uri": pdf_path, @@ -149,39 +148,41 @@ def create_frame_from_pdf( "metadata": extracted["metadata"], "record_type": "document", } - + # Handle chunking if requested if chunk_size: # Use ContextFrame's chunking capabilities from contextframe.extract.chunking import ChunkingMixin - + # Detect if content is markdown splitter_type = "markdown" if extracted["content"].startswith("#") else "text" - + chunks = ChunkingMixin.chunk_text( extracted["content"], chunk_size=chunk_size, chunk_overlap=100, splitter_type=splitter_type, ) - + frames = [] for i, chunk in enumerate(chunks): frame_data = base_frame.copy() - frame_data.update({ - "uri": f"{pdf_path}#chunk-{i}", - "content": chunk, - "metadata": { - **frame_data["metadata"], - "chunk_index": i, - "total_chunks": len(chunks), - }, - "parent_uri": pdf_path if i > 0 else None, - }) + frame_data.update( + { + "uri": f"{pdf_path}#chunk-{i}", + "content": chunk, + "metadata": { + **frame_data["metadata"], + "chunk_index": i, + "total_chunks": len(chunks), + }, + "parent_uri": pdf_path if i > 0 else None, + } + ) if collection_uri: frame_data["collection_uri"] = collection_uri frames.append(FrameRecord(**frame_data)) - + return frames else: # Single document @@ -195,13 +196,13 @@ def process_pdf_folder( folder_path: str, dataset_path: str, embed_model: str = "openai/text-embedding-3-small", - chunk_size: Optional[int] = 1000, + chunk_size: int | None = 1000, batch_size: int = 50, - **docling_kwargs + **docling_kwargs, ): """ Process all PDFs in a folder and store in ContextFrame. - + Args: folder_path: Path to folder containing PDFs dataset_path: Path for ContextFrame dataset @@ -212,14 +213,14 @@ def process_pdf_folder( """ # Initialize dataset dataset = FrameDataset(dataset_path) - + # Find all PDFs pdf_files = list(Path(folder_path).glob("**/*.pdf")) logger.info(f"Found {len(pdf_files)} PDF files") - + # Process PDFs in batches all_frames = [] - + for pdf_path in pdf_files: try: logger.info(f"Processing: {pdf_path}") @@ -227,29 +228,29 @@ def process_pdf_folder( str(pdf_path), collection_uri=f"pdfs/{pdf_path.parent.name}", chunk_size=chunk_size, - **docling_kwargs + **docling_kwargs, ) all_frames.extend(frames) logger.info(f"Created {len(frames)} frames from {pdf_path.name}") - + # Process batch if we've accumulated enough frames if len(all_frames) >= batch_size: logger.info(f"Embedding batch of {len(all_frames)} frames...") embedded_frames = embed_frames(all_frames, model=embed_model) dataset.add(embedded_frames) all_frames = [] - + except Exception as e: logger.error(f"Failed to process {pdf_path}: {e}") - + # Process remaining frames if all_frames: logger.info(f"Embedding final batch of {len(all_frames)} frames...") embedded_frames = embed_frames(all_frames, model=embed_model) dataset.add(embedded_frames) - + # Print summary - print(f"\nProcessing complete!") + print("\nProcessing complete!") print(f"Total PDFs processed: {len(pdf_files)}") print(f"Dataset location: {dataset_path}") @@ -261,31 +262,33 @@ def search_pdf_content( ): """ Search PDF content in ContextFrame dataset. - + Args: dataset_path: Path to ContextFrame dataset query: Search query limit: Number of results to return """ dataset = FrameDataset(dataset_path) - + # Search with embeddings results = dataset.search( query=query, limit=limit, search_type="hybrid", # Combines vector and keyword search ) - + print(f"\nSearch results for: '{query}'") print("-" * 50) - + for i, result in enumerate(results, 1): print(f"\n{i}. {result.title or result.uri}") print(f" Score: {result.score:.3f}") print(f" Source: {result.uri}") if result.metadata.get("chunk_index") is not None: - print(f" Chunk: {result.metadata['chunk_index'] + 1}/{result.metadata['total_chunks']}") - print(f"\n Content preview:") + print( + f" Chunk: {result.metadata['chunk_index'] + 1}/{result.metadata['total_chunks']}" + ) + print("\n Content preview:") print(f" {result.content[:200]}...") @@ -296,27 +299,27 @@ def advanced_docling_example(): if not DOCLING_AVAILABLE: print("This example requires Docling to be installed") return - + try: from docling.chunking import HybridChunker except ImportError: print("Advanced chunking requires newer Docling version") return - + # Convert document converter = DocumentConverter() result = converter.convert("https://arxiv.org/pdf/2408.09869") doc = result.document - + # Use Docling's hybrid chunker chunker = HybridChunker( tokenizer="sentence-transformers/all-MiniLM-L6-v2", max_tokens=512, ) - + # Chunk the document chunks = list(chunker.chunk(doc)) - + # Convert Docling chunks to FrameRecords frames = [] for i, chunk in enumerate(chunks): @@ -333,7 +336,7 @@ def advanced_docling_example(): record_type="document", ) frames.append(frame) - + return frames @@ -341,7 +344,7 @@ def advanced_docling_example(): # Example 1: Process a single PDF if DOCLING_AVAILABLE: print("Example 1: Processing a single PDF") - + # Extract a research paper frames = create_frame_from_pdf( "research_paper.pdf", @@ -351,12 +354,12 @@ def advanced_docling_example(): table_mode="accurate", ) print(f"Created {len(frames)} frames from PDF") - + # Store with embeddings embedded = embed_frames(frames, model="openai/text-embedding-3-small") dataset = FrameDataset("research_papers.lance") dataset.add(embedded) - + # Example 2: Process folder of PDFs (commented out) # process_pdf_folder( # folder_path="./documents/pdfs", @@ -367,15 +370,15 @@ def advanced_docling_example(): # extract_tables=True, # table_mode="accurate", # ) - + # Example 3: Search processed PDFs (commented out) # search_pdf_content( # dataset_path="./pdf_knowledge_base.lance", # query="machine learning optimization techniques", # limit=5, # ) - + # Example 4: Advanced Docling chunking (commented out) # frames = advanced_docling_example() # if frames: - # print(f"Created {len(frames)} frames using Docling's chunker") \ No newline at end of file + # print(f"Created {len(frames)} frames using Docling's chunker") diff --git a/examples/external_tools/ollama_local_embedding.py b/examples/external_tools/ollama_local_embedding.py index cd9de60..d1486aa 100644 --- a/examples/external_tools/ollama_local_embedding.py +++ b/examples/external_tools/ollama_local_embedding.py @@ -16,12 +16,11 @@ """ import logging -from pathlib import Path -from typing import List, Optional - from contextframe import FrameDataset, FrameRecord from contextframe.embed import embed_frames from contextframe.extract import BatchExtractor +from pathlib import Path +from typing import List, Optional logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -31,15 +30,20 @@ def test_ollama_connection(): """Test if Ollama is running and accessible.""" try: import requests + response = requests.get("http://localhost:11434/api/tags") if response.status_code == 200: models = response.json().get("models", []) embedding_models = [m for m in models if "embed" in m["name"]] if embedding_models: - logger.info(f"Found embedding models: {[m['name'] for m in embedding_models]}") + logger.info( + f"Found embedding models: {[m['name'] for m in embedding_models]}" + ) return True else: - logger.warning("No embedding models found. Run: ollama pull nomic-embed-text") + logger.warning( + "No embedding models found. Run: ollama pull nomic-embed-text" + ) return False except Exception as e: logger.error(f"Cannot connect to Ollama: {e}") @@ -56,7 +60,7 @@ def process_local_documents( ): """ Process documents using local Ollama embeddings. - + Args: folder_path: Path to folder containing documents dataset_path: Path for ContextFrame dataset @@ -67,10 +71,10 @@ def process_local_documents( # Check Ollama connection if not test_ollama_connection(): return - + # Initialize dataset dataset = FrameDataset(dataset_path) - + # Set up batch extractor extractor = BatchExtractor( patterns=["*.txt", "*.md", "*.json"], @@ -79,11 +83,11 @@ def process_local_documents( use_threads=True, max_workers=4, ) - + # Extract documents logger.info(f"Extracting documents from {folder_path}...") results = extractor.extract_folder(folder_path) - + # Convert to FrameRecords frames = [] for result in results: @@ -113,18 +117,18 @@ def process_local_documents( record_type="document", ) frames.append(frame) - + logger.info(f"Created {len(frames)} frames from {len(results)} documents") - + # Embed frames in batches total_embedded = 0 for i in range(0, len(frames), batch_size): - batch = frames[i:i + batch_size] - logger.info(f"Embedding batch {i//batch_size + 1} ({len(batch)} frames)...") - + batch = frames[i : i + batch_size] + logger.info(f"Embedding batch {i // batch_size + 1} ({len(batch)} frames)...") + try: embedded_batch = embed_frames( - batch, + batch, model=model, show_progress=True, ) @@ -132,7 +136,7 @@ def process_local_documents( total_embedded += len(embedded_batch) except Exception as e: logger.error(f"Failed to embed batch: {e}") - + logger.info(f"Successfully embedded {total_embedded} frames") return dataset @@ -142,16 +146,16 @@ def compare_embedding_models( ): """ Compare different local embedding models. - + Args: test_text: Text to embed with different models """ models = [ "ollama/nomic-embed-text", - "ollama/mxbai-embed-large", + "ollama/mxbai-embed-large", "ollama/all-minilm", ] - + for model in models: try: frame = FrameRecord( @@ -159,14 +163,14 @@ def compare_embedding_models( content=test_text, record_type="document", ) - + embedded = embed_frames([frame], model=model) embedding = embedded[0].embedding - + print(f"\n{model}:") print(f" Dimensions: {len(embedding)}") print(f" First 5 values: {embedding[:5]}") - + except Exception as e: print(f"\n{model}: Failed - {e}") @@ -174,28 +178,28 @@ def compare_embedding_models( def semantic_search_example(dataset_path: str = "local_docs.lance"): """ Example of semantic search using local embeddings. - + Args: dataset_path: Path to existing dataset """ dataset = FrameDataset(dataset_path) - + queries = [ "How to install Python packages", "Machine learning algorithms", "Database optimization techniques", ] - + for query in queries: print(f"\n\nSearching for: '{query}'") print("-" * 50) - + results = dataset.search( query=query, limit=3, search_type="vector", # Pure vector search ) - + for i, result in enumerate(results, 1): print(f"\n{i}. {result.title or result.uri}") print(f" Score: {result.score:.3f}") @@ -208,23 +212,23 @@ def process_with_custom_chunking( ): """ Example using semantic text splitting with token counting. - + Args: file_path: Path to document model: Embedding model """ from contextframe.extract import TextFileExtractor from contextframe.extract.chunking import ChunkingMixin - + # Extract with semantic chunking extractor = TextFileExtractor() - + # Mix in chunking capability class ChunkingTextExtractor(TextFileExtractor, ChunkingMixin): pass - + chunking_extractor = ChunkingTextExtractor() - + # Extract with token-based chunking result = chunking_extractor.extract_with_chunking( file_path, @@ -232,7 +236,7 @@ class ChunkingTextExtractor(TextFileExtractor, ChunkingMixin): tokenizer_model="gpt-3.5-turbo", # Use tiktoken for counting splitter_type="markdown" if file_path.endswith(".md") else "text", ) - + if result.success and result.chunks: frames = [] for i, chunk in enumerate(result.chunks): @@ -248,13 +252,13 @@ class ChunkingTextExtractor(TextFileExtractor, ChunkingMixin): record_type="document", ) frames.append(frame) - + # Embed with local model embedded = embed_frames(frames, model=model) - + print(f"Created {len(embedded)} embedded chunks") print(f"First chunk preview: {embedded[0].content[:100]}...") - + return embedded @@ -262,22 +266,22 @@ class ChunkingTextExtractor(TextFileExtractor, ChunkingMixin): # Example 1: Test Ollama connection print("Testing Ollama connection...") test_ollama_connection() - + # Example 2: Compare embedding models (commented out) # compare_embedding_models() - + # Example 3: Process a folder of documents (commented out) # dataset = process_local_documents( # folder_path="./documents", # model="ollama/nomic-embed-text", # chunk_size=1000, # ) - + # Example 4: Search embedded documents (commented out) # semantic_search_example("local_docs.lance") - + # Example 5: Custom chunking with token counting (commented out) # process_with_custom_chunking( # "README.md", # model="ollama/nomic-embed-text", - # ) \ No newline at end of file + # ) diff --git a/examples/external_tools/reducto_pipeline.py b/examples/external_tools/reducto_pipeline.py index 8bceb7f..bf612b3 100644 --- a/examples/external_tools/reducto_pipeline.py +++ b/examples/external_tools/reducto_pipeline.py @@ -13,22 +13,21 @@ Requirements: pip install requests contextframe[extract,embed] - + Get API key: https://reducto.ai API Docs: https://docs.reducto.ai """ +import json import logging -from pathlib import Path -from typing import List, Dict, Any, Optional import os -import json import requests import time - from contextframe import FrameDataset, FrameRecord from contextframe.embed import embed_frames from contextframe.extract.chunking import ChunkingMixin +from pathlib import Path +from typing import Any, Dict, List, Optional logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -37,20 +36,20 @@ def extract_with_reducto( file_path: str, api_key: str, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, wait_for_completion: bool = True, max_wait_time: int = 300, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Extract content using Reducto API. - + Args: file_path: Path to document api_key: Reducto API key options: Processing options (see Reducto docs) wait_for_completion: Whether to wait for async processing max_wait_time: Maximum time to wait (seconds) - + Returns: Dictionary with extracted content and metadata """ @@ -63,69 +62,69 @@ def extract_with_reducto( "output_format": "markdown", "chunk_size": None, # Let Reducto handle chunking } - + # Upload file logger.info(f"Uploading {file_path} to Reducto...") - + with open(file_path, "rb") as f: files = {"file": (Path(file_path).name, f)} headers = {"Authorization": f"Bearer {api_key}"} - + response = requests.post( "https://api.reducto.ai/v1/documents", files=files, headers=headers, data={"options": json.dumps(options)}, ) - + if response.status_code != 200: raise Exception(f"Upload failed: {response.status_code} - {response.text}") - + result = response.json() document_id = result["document_id"] - + logger.info(f"Document ID: {document_id}") - + # Wait for processing if requested if wait_for_completion: logger.info("Waiting for processing to complete...") start_time = time.time() - + while True: # Check status status_response = requests.get( f"https://api.reducto.ai/v1/documents/{document_id}/status", headers=headers, ) - + if status_response.status_code != 200: raise Exception(f"Status check failed: {status_response.text}") - + status = status_response.json() - + if status["status"] == "completed": break elif status["status"] == "failed": raise Exception(f"Processing failed: {status.get('error')}") - + if time.time() - start_time > max_wait_time: raise TimeoutError(f"Processing exceeded {max_wait_time} seconds") - + time.sleep(2) - + # Get results logger.info("Retrieving extraction results...") - + results_response = requests.get( f"https://api.reducto.ai/v1/documents/{document_id}/content", headers=headers, ) - + if results_response.status_code != 200: raise Exception(f"Failed to get results: {results_response.text}") - + content_data = results_response.json() - + # Build response metadata = { "source": "reducto", @@ -134,7 +133,7 @@ def extract_with_reducto( "language": content_data.get("language"), "confidence": content_data.get("confidence"), } - + # Count extracted elements if "elements" in content_data: element_types = {} @@ -142,7 +141,7 @@ def extract_with_reducto( elem_type = element.get("type", "unknown") element_types[elem_type] = element_types.get(elem_type, 0) + 1 metadata["element_types"] = element_types - + return { "content": content_data.get("content", ""), "metadata": metadata, @@ -157,33 +156,33 @@ def extract_with_reducto( def create_frames_from_reducto( file_path: str, api_key: str, - chunk_size: Optional[int] = None, + chunk_size: int | None = None, include_specialized_frames: bool = False, - **reducto_options -) -> List[FrameRecord]: + **reducto_options, +) -> list[FrameRecord]: """ Create FrameRecords from Reducto extraction. - + Args: file_path: Path to document api_key: Reducto API key chunk_size: Optional chunk size for additional splitting include_specialized_frames: Create separate frames for tables/figures **reducto_options: Additional Reducto options - + Returns: List of FrameRecord objects """ # Extract with Reducto result = extract_with_reducto(file_path, api_key, options=reducto_options) - + frames = [] base_uri = file_path - + # Main content frame(s) main_content = result["content"] base_metadata = result["metadata"].copy() - + if chunk_size and len(main_content) > chunk_size: # Additional chunking if needed chunks = ChunkingMixin.chunk_text( @@ -192,14 +191,16 @@ def create_frames_from_reducto( chunk_overlap=50, splitter_type="markdown", # Reducto usually outputs markdown ) - + for i, chunk in enumerate(chunks): frame_metadata = base_metadata.copy() - frame_metadata.update({ - "chunk_index": i, - "total_chunks": len(chunks), - }) - + frame_metadata.update( + { + "chunk_index": i, + "total_chunks": len(chunks), + } + ) + frame = FrameRecord( uri=f"{base_uri}#chunk-{i}", title=Path(file_path).stem, @@ -218,7 +219,7 @@ def create_frames_from_reducto( record_type="document", ) frames.append(frame) - + # Specialized frames for tables, figures, equations if include_specialized_frames: # Table frames @@ -230,23 +231,23 @@ def create_frames_from_reducto( "page": table.get("page"), "confidence": table.get("confidence"), } - + # Format table content table_content = table.get("markdown", "") if not table_content and "data" in table: # Convert structured data to markdown table table_content = _format_table_as_markdown(table["data"]) - + frame = FrameRecord( uri=f"{base_uri}#table-{i}", - title=f"Table {i+1} - {Path(file_path).stem}", + title=f"Table {i + 1} - {Path(file_path).stem}", content=table_content, metadata=table_metadata, record_type="document", parent_uri=base_uri, ) frames.append(frame) - + # Figure frames for i, figure in enumerate(result.get("figures", [])): figure_metadata = { @@ -257,17 +258,17 @@ def create_frames_from_reducto( "caption": figure.get("caption"), "figure_type": figure.get("figure_type"), } - + frame = FrameRecord( uri=f"{base_uri}#figure-{i}", - title=f"Figure {i+1} - {Path(file_path).stem}", - content=figure.get("caption", f"Figure {i+1}"), + title=f"Figure {i + 1} - {Path(file_path).stem}", + content=figure.get("caption", f"Figure {i + 1}"), metadata=figure_metadata, record_type="document", parent_uri=base_uri, ) frames.append(frame) - + # Equation frames for i, equation in enumerate(result.get("equations", [])): equation_metadata = { @@ -277,33 +278,35 @@ def create_frames_from_reducto( "page": equation.get("page"), "latex": equation.get("latex"), } - + frame = FrameRecord( uri=f"{base_uri}#equation-{i}", - title=f"Equation {i+1} - {Path(file_path).stem}", - content=equation.get("latex", equation.get("text", f"Equation {i+1}")), + title=f"Equation {i + 1} - {Path(file_path).stem}", + content=equation.get( + "latex", equation.get("text", f"Equation {i + 1}") + ), metadata=equation_metadata, record_type="document", parent_uri=base_uri, ) frames.append(frame) - + return frames -def _format_table_as_markdown(table_data: List[List[str]]) -> str: +def _format_table_as_markdown(table_data: list[list[str]]) -> str: """Convert table data to markdown format.""" if not table_data or not table_data[0]: return "" - + # Header markdown = "| " + " | ".join(str(cell) for cell in table_data[0]) + " |\n" markdown += "|" + "|".join(["---"] * len(table_data[0])) + "|\n" - + # Rows for row in table_data[1:]: markdown += "| " + " | ".join(str(cell) for cell in row) + " |\n" - + return markdown @@ -315,7 +318,7 @@ def process_scientific_documents( ): """ Process scientific documents with equations and figures. - + Args: folder_path: Folder containing documents dataset_path: ContextFrame dataset path @@ -325,17 +328,17 @@ def process_scientific_documents( # Find scientific documents folder = Path(folder_path) pdf_files = list(folder.glob("**/*.pdf")) - + logger.info(f"Found {len(pdf_files)} PDF files") - + # Initialize dataset dataset = FrameDataset(dataset_path) - + # Process with specialized handling for pdf_path in pdf_files: try: logger.info(f"Processing scientific document: {pdf_path}") - + # Extract with equation and figure parsing frames = create_frames_from_reducto( str(pdf_path), @@ -346,27 +349,29 @@ def process_scientific_documents( parse_equations=True, output_format="markdown", ) - + # Separate frames by type for different embedding strategies text_frames = [f for f in frames if f.metadata.get("type") != "equation"] - equation_frames = [f for f in frames if f.metadata.get("type") == "equation"] - + equation_frames = [ + f for f in frames if f.metadata.get("type") == "equation" + ] + # Embed text content normally if text_frames: embedded_text = embed_frames(text_frames, model=embed_model) dataset.add(embedded_text) - + # Equations might need special handling if equation_frames: # Could use a different model or preprocessing embedded_equations = embed_frames(equation_frames, model=embed_model) dataset.add(embedded_equations) - + logger.info(f"Added {len(frames)} frames from {pdf_path.name}") - + except Exception as e: logger.error(f"Failed to process {pdf_path}: {e}") - + return dataset @@ -375,17 +380,17 @@ def compare_with_native_extraction(): Compare Reducto extraction with native extraction. """ from contextframe.extract import TextFileExtractor - + # Simple text file text_file = "sample.txt" - + # Native extraction print("Native extraction:") extractor = TextFileExtractor() native_result = extractor.extract(text_file) print(f"Content length: {len(native_result.content)}") print(f"Metadata: {native_result.metadata}") - + # Reducto extraction (would handle PDFs, complex layouts, etc.) api_key = os.getenv("REDUCTO_API_KEY") if api_key: @@ -397,7 +402,7 @@ def compare_with_native_extraction(): print(f"Additional elements: {list(reducto_result.keys())}") except Exception as e: print(f"Reducto extraction failed: {e}") - + print("\nKey differences:") print("- Reducto handles complex formats (PDF, DOCX, etc.)") print("- Reducto extracts tables, figures, equations separately") @@ -406,25 +411,25 @@ def compare_with_native_extraction(): def async_batch_processing( - file_paths: List[str], + file_paths: list[str], api_key: str, dataset_path: str, ): """ Process multiple documents asynchronously with Reducto. - + Args: file_paths: List of document paths api_key: Reducto API key dataset_path: ContextFrame dataset path """ headers = {"Authorization": f"Bearer {api_key}"} - + # Submit all documents document_ids = [] for file_path in file_paths: logger.info(f"Submitting {file_path}...") - + with open(file_path, "rb") as f: files = {"file": (Path(file_path).name, f)} response = requests.post( @@ -433,16 +438,16 @@ def async_batch_processing( headers=headers, data={"options": json.dumps({"output_format": "markdown"})}, ) - + if response.status_code == 200: doc_id = response.json()["document_id"] document_ids.append((file_path, doc_id)) else: logger.error(f"Failed to submit {file_path}: {response.text}") - + # Wait for all to complete logger.info(f"Waiting for {len(document_ids)} documents to process...") - + completed = [] while document_ids: remaining = [] @@ -451,7 +456,7 @@ def async_batch_processing( f"https://api.reducto.ai/v1/documents/{doc_id}/status", headers=headers, ) - + if status_response.status_code == 200: status = status_response.json() if status["status"] == "completed": @@ -460,18 +465,18 @@ def async_batch_processing( remaining.append((file_path, doc_id)) else: logger.error(f"Status check failed for {file_path}") - + document_ids = remaining if document_ids: time.sleep(2) - + # Process completed documents dataset = FrameDataset(dataset_path) for file_path, doc_id in completed: logger.info(f"Processing results for {file_path}...") # Retrieve and process results # ... (similar to create_frames_from_reducto) - + return dataset @@ -489,7 +494,7 @@ def async_batch_processing( print(f"Created {len(frames)} frames with Reducto") else: print("Set REDUCTO_API_KEY environment variable") - + # Example 2: Process scientific documents (commented out) # if api_key: # dataset = process_scientific_documents( @@ -497,10 +502,10 @@ def async_batch_processing( # dataset_path="./scientific_docs.lance", # api_key=api_key, # ) - + # Example 3: Compare with native extraction (commented out) # compare_with_native_extraction() - + # Example 4: Async batch processing (commented out) # if api_key: # files = ["doc1.pdf", "doc2.pdf", "doc3.pdf"] @@ -508,4 +513,4 @@ def async_batch_processing( # files, # api_key=api_key, # dataset_path="./batch_docs.lance", - # ) \ No newline at end of file + # ) diff --git a/examples/external_tools/unstructured_io_pipeline.py b/examples/external_tools/unstructured_io_pipeline.py index 120661c..2cd09e3 100644 --- a/examples/external_tools/unstructured_io_pipeline.py +++ b/examples/external_tools/unstructured_io_pipeline.py @@ -11,19 +11,18 @@ Requirements: # For API usage: pip install unstructured-client contextframe[extract,embed] - + # For local usage (heavy dependencies): pip install "unstructured[all-docs]" contextframe[extract,embed] """ import logging -from pathlib import Path -from typing import List, Dict, Any, Optional import os - from contextframe import FrameDataset, FrameRecord from contextframe.embed import embed_frames from contextframe.extract.chunking import ChunkingMixin +from pathlib import Path +from typing import Any, Dict, List, Optional logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -33,17 +32,17 @@ def extract_with_unstructured_api( file_path: str, api_key: str, strategy: str = "hi_res", - languages: Optional[List[str]] = None, -) -> Dict[str, Any]: + languages: list[str] | None = None, +) -> dict[str, Any]: """ Extract content using Unstructured.io API. - + Args: file_path: Path to document api_key: Unstructured API key strategy: Processing strategy - "hi_res", "fast", or "auto" languages: List of languages for OCR (e.g., ["eng", "spa"]) - + Returns: Dictionary with extracted content and metadata """ @@ -56,17 +55,17 @@ def extract_with_unstructured_api( "unstructured-client is required for API usage. " "Install with: pip install unstructured-client" ) - + # Initialize client client = UnstructuredClient( api_key_auth=api_key, server_url="https://api.unstructured.io", ) - + # Prepare file with open(file_path, "rb") as f: file_content = f.read() - + # Set up parameters req = shared.PartitionParameters( files=shared.Files( @@ -76,15 +75,15 @@ def extract_with_unstructured_api( strategy=strategy, languages=languages, ) - + try: # Process document logger.info(f"Processing {file_path} with Unstructured API...") resp = client.general.partition(req) - + # Extract elements elements = resp.elements - + # Group elements by type content_parts = [] metadata = { @@ -93,12 +92,14 @@ def extract_with_unstructured_api( "num_elements": len(elements), "element_types": {}, } - + # Process elements for element in elements: element_type = element.get("type", "unknown") - metadata["element_types"][element_type] = metadata["element_types"].get(element_type, 0) + 1 - + metadata["element_types"][element_type] = ( + metadata["element_types"].get(element_type, 0) + 1 + ) + # Format based on type text = element.get("text", "") if element_type == "Title": @@ -111,31 +112,29 @@ def extract_with_unstructured_api( content_parts.append(f"\n{table_html}\n") else: content_parts.append(text) - + return { "content": "\n\n".join(content_parts), "metadata": metadata, "elements": elements, # Keep raw elements for advanced processing } - + except SDKError as e: logger.error(f"API error: {e}") raise def extract_with_unstructured_local( - file_path: str, - strategy: str = "hi_res", - **kwargs -) -> Dict[str, Any]: + file_path: str, strategy: str = "hi_res", **kwargs +) -> dict[str, Any]: """ Extract content using local Unstructured library. - + Args: file_path: Path to document strategy: Processing strategy **kwargs: Additional arguments for partition - + Returns: Dictionary with extracted content and metadata """ @@ -146,16 +145,12 @@ def extract_with_unstructured_local( "unstructured package is required for local processing. " "Install with: pip install 'unstructured[all-docs]'" ) - + logger.info(f"Processing {file_path} with local Unstructured...") - + # Partition document - elements = partition( - filename=file_path, - strategy=strategy, - **kwargs - ) - + elements = partition(filename=file_path, strategy=strategy, **kwargs) + # Group elements content_parts = [] metadata = { @@ -164,12 +159,14 @@ def extract_with_unstructured_local( "num_elements": len(elements), "element_types": {}, } - + # Process elements for element in elements: element_type = element.category - metadata["element_types"][element_type] = metadata["element_types"].get(element_type, 0) + 1 - + metadata["element_types"][element_type] = ( + metadata["element_types"].get(element_type, 0) + 1 + ) + # Format based on type if hasattr(element, "text"): text = element.text @@ -179,13 +176,15 @@ def extract_with_unstructured_local( content_parts.append(f"## {text}") elif element_type == "Table": # Get table as HTML if available - if hasattr(element, "metadata") and hasattr(element.metadata, "text_as_html"): + if hasattr(element, "metadata") and hasattr( + element.metadata, "text_as_html" + ): content_parts.append(f"\n{element.metadata.text_as_html}\n") else: content_parts.append(text) else: content_parts.append(text) - + return { "content": "\n\n".join(content_parts), "metadata": metadata, @@ -195,15 +194,15 @@ def extract_with_unstructured_local( def create_frames_from_unstructured( file_path: str, - api_key: Optional[str] = None, + api_key: str | None = None, use_api: bool = True, - chunk_size: Optional[int] = None, + chunk_size: int | None = None, strategy: str = "hi_res", - **extract_kwargs -) -> List[FrameRecord]: + **extract_kwargs, +) -> list[FrameRecord]: """ Create FrameRecords from document using Unstructured. - + Args: file_path: Path to document api_key: API key (required if use_api=True) @@ -211,7 +210,7 @@ def create_frames_from_unstructured( chunk_size: Optional chunk size for splitting strategy: Extraction strategy **extract_kwargs: Additional extraction arguments - + Returns: List of FrameRecord objects """ @@ -221,20 +220,15 @@ def create_frames_from_unstructured( api_key = os.getenv("UNSTRUCTURED_API_KEY") if not api_key: raise ValueError("API key required for Unstructured API") - + extracted = extract_with_unstructured_api( - file_path, - api_key=api_key, - strategy=strategy, - **extract_kwargs + file_path, api_key=api_key, strategy=strategy, **extract_kwargs ) else: extracted = extract_with_unstructured_local( - file_path, - strategy=strategy, - **extract_kwargs + file_path, strategy=strategy, **extract_kwargs ) - + # Create base frame base_frame = { "uri": file_path, @@ -242,33 +236,35 @@ def create_frames_from_unstructured( "metadata": extracted["metadata"], "record_type": "document", } - + # Handle chunking if requested if chunk_size: # Detect content type for intelligent chunking splitter_type = "markdown" if extracted["content"].count("#") > 2 else "text" - + chunks = ChunkingMixin.chunk_text( extracted["content"], chunk_size=chunk_size, chunk_overlap=50, splitter_type=splitter_type, ) - + frames = [] for i, chunk in enumerate(chunks): frame_data = base_frame.copy() - frame_data.update({ - "uri": f"{file_path}#chunk-{i}", - "content": chunk, - "metadata": { - **frame_data["metadata"], - "chunk_index": i, - "total_chunks": len(chunks), - }, - }) + frame_data.update( + { + "uri": f"{file_path}#chunk-{i}", + "content": chunk, + "metadata": { + **frame_data["metadata"], + "chunk_index": i, + "total_chunks": len(chunks), + }, + } + ) frames.append(FrameRecord(**frame_data)) - + return frames else: base_frame["content"] = extracted["content"] @@ -278,15 +274,15 @@ def create_frames_from_unstructured( def process_folder_with_unstructured( folder_path: str, dataset_path: str, - api_key: Optional[str] = None, + api_key: str | None = None, use_api: bool = True, embed_model: str = "openai/text-embedding-3-small", - chunk_size: Optional[int] = 1000, - file_patterns: List[str] = None, + chunk_size: int | None = 1000, + file_patterns: list[str] = None, ): """ Process a folder of documents using Unstructured. - + Args: folder_path: Path to folder dataset_path: Path for ContextFrame dataset @@ -298,18 +294,18 @@ def process_folder_with_unstructured( """ if file_patterns is None: file_patterns = ["*.pdf", "*.docx", "*.pptx", "*.html", "*.md"] - + # Initialize dataset dataset = FrameDataset(dataset_path) - + # Find files all_files = [] folder = Path(folder_path) for pattern in file_patterns: all_files.extend(folder.glob(f"**/{pattern}")) - + logger.info(f"Found {len(all_files)} files to process") - + # Process each file all_frames = [] for file_path in all_files: @@ -324,17 +320,17 @@ def process_folder_with_unstructured( ) all_frames.extend(frames) logger.info(f"Created {len(frames)} frames from {file_path.name}") - + except Exception as e: logger.error(f"Failed to process {file_path}: {e}") - + # Embed frames if all_frames: logger.info(f"Embedding {len(all_frames)} frames...") embedded_frames = embed_frames(all_frames, model=embed_model) dataset.add(embedded_frames) logger.info(f"Stored {len(embedded_frames)} frames in {dataset_path}") - + return dataset @@ -344,13 +340,11 @@ def advanced_element_processing(): """ # This example shows how to process specific element types api_key = os.getenv("UNSTRUCTURED_API_KEY") - + extracted = extract_with_unstructured_api( - "document.pdf", - api_key=api_key, - strategy="hi_res" + "document.pdf", api_key=api_key, strategy="hi_res" ) - + # Group elements by type for specialized processing elements_by_type = {} for element in extracted["elements"]: @@ -358,22 +352,24 @@ def advanced_element_processing(): if elem_type not in elements_by_type: elements_by_type[elem_type] = [] elements_by_type[elem_type].append(element) - + # Extract specific content tables = elements_by_type.get("Table", []) images = elements_by_type.get("Image", []) formulas = elements_by_type.get("Formula", []) - + print(f"Found {len(tables)} tables, {len(images)} images, {len(formulas)} formulas") - + # Create specialized frames frames = [] - + # Table frames for i, table in enumerate(tables): frame = FrameRecord( uri=f"document.pdf#table-{i}", - content=table.get("metadata", {}).get("text_as_html", table.get("text", "")), + content=table.get("metadata", {}).get( + "text_as_html", table.get("text", "") + ), metadata={ "type": "table", "source": "unstructured.io", @@ -381,7 +377,7 @@ def advanced_element_processing(): record_type="document", ) frames.append(frame) - + return frames @@ -397,7 +393,7 @@ def advanced_element_processing(): strategy="hi_res", ) print(f"Created {len(frames)} frames") - + # Example 2: Process with local library (commented out) # frames = create_frames_from_unstructured( # "document.pdf", @@ -405,7 +401,7 @@ def advanced_element_processing(): # chunk_size=1000, # strategy="fast", # ) - + # Example 3: Process folder (commented out) # dataset = process_folder_with_unstructured( # folder_path="./documents", @@ -414,7 +410,7 @@ def advanced_element_processing(): # use_api=True, # chunk_size=1000, # ) - + # Example 4: Advanced element processing (commented out) # if api_key: - # frames = advanced_element_processing() \ No newline at end of file + # frames = advanced_element_processing() diff --git a/examples/semantic_chunking_demo.py b/examples/semantic_chunking_demo.py index c65e591..f043c22 100644 --- a/examples/semantic_chunking_demo.py +++ b/examples/semantic_chunking_demo.py @@ -8,12 +8,12 @@ pip install contextframe[extract] """ -from contextframe.extract.chunking import semantic_splitter, ChunkingMixin +from contextframe.extract.chunking import ChunkingMixin, semantic_splitter def compare_chunking_methods(): """Compare different chunking approaches.""" - + # Sample markdown text with structure markdown_text = """# Introduction to Machine Learning @@ -48,35 +48,35 @@ def compare_chunking_methods(): print("=" * 80) print("SEMANTIC CHUNKING DEMO") print("=" * 80) - + # 1. Character-based chunking print("\n1. CHARACTER-BASED CHUNKING (300 chars)") print("-" * 40) - + char_chunks = semantic_splitter( [markdown_text], chunk_size=300, splitter_type="text", # Plain text mode ) - + for i, (_, chunk) in enumerate(char_chunks): print(f"\nChunk {i + 1}:") print(chunk[:100] + "..." if len(chunk) > 100 else chunk) - + # 2. Semantic markdown chunking print("\n\n2. SEMANTIC MARKDOWN CHUNKING (300 chars)") print("-" * 40) - + md_chunks = semantic_splitter( [markdown_text], chunk_size=300, splitter_type="markdown", # Markdown-aware ) - + for i, (_, chunk) in enumerate(md_chunks): print(f"\nChunk {i + 1}:") print(chunk[:100] + "..." if len(chunk) > 100 else chunk) - + # Show the difference print("\n\n" + "=" * 80) print("KEY DIFFERENCES:") @@ -87,7 +87,7 @@ def compare_chunking_methods(): def token_based_chunking_example(): """Demonstrate token-based chunking for LLM compatibility.""" - + text = """ Natural language processing (NLP) is a field of artificial intelligence that focuses on the interaction between computers and humans through natural language. The ultimate objective of NLP is to enable computers to understand, interpret, and generate human language in a way that is both meaningful and useful. @@ -97,19 +97,19 @@ def token_based_chunking_example(): - Named entity recognition: Identifying people, places, and organizations - Text summarization: Creating concise summaries of longer documents """ - + print("\n\n" + "=" * 80) print("TOKEN-BASED CHUNKING (for LLMs)") print("=" * 80) - + # Token-based chunking with GPT tokenizer token_chunks = semantic_splitter( [text], chunk_size=50, # 50 tokens tokenizer_model="gpt-3.5-turbo", ) - - print(f"\nUsing GPT-3.5 tokenizer (50 tokens per chunk):") + + print("\nUsing GPT-3.5 tokenizer (50 tokens per chunk):") for i, (_, chunk) in enumerate(token_chunks): print(f"\nChunk {i + 1}:") print(chunk) @@ -117,7 +117,7 @@ def token_based_chunking_example(): def code_splitting_example(): """Demonstrate code-aware splitting.""" - + python_code = '''def process_data(input_file, output_file): """Process data from input file and save to output file.""" # Read the input data @@ -156,7 +156,7 @@ def transform_record(record): print("\n\n" + "=" * 80) print("CODE-AWARE SPLITTING") print("=" * 80) - + try: # This requires tree-sitter-python to be installed code_chunks = semantic_splitter( @@ -165,8 +165,8 @@ def transform_record(record): splitter_type="code", language="python", ) - - print(f"\nPython code split into semantic chunks:") + + print("\nPython code split into semantic chunks:") for i, (_, chunk) in enumerate(code_chunks): print(f"\nChunk {i + 1}:") print(chunk) @@ -174,7 +174,7 @@ def transform_record(record): except ImportError: print("\nCode splitting requires tree-sitter-python:") print("pip install tree-sitter-python") - + # Fall back to text splitting print("\nFalling back to text-based splitting:") text_chunks = semantic_splitter( @@ -182,7 +182,7 @@ def transform_record(record): chunk_size=200, splitter_type="text", ) - + for i, (_, chunk) in enumerate(text_chunks): print(f"\nChunk {i + 1}:") print(chunk[:100] + "..." if len(chunk) > 100 else chunk) @@ -190,22 +190,22 @@ def transform_record(record): def range_based_chunking(): """Demonstrate range-based chunk sizing.""" - + text = "This is a sentence. " * 50 # Repetitive text - + print("\n\n" + "=" * 80) print("RANGE-BASED CHUNKING") print("=" * 80) - + # Note: semantic-text-splitter supports range-based sizing # by creating the splitter with a tuple from semantic_text_splitter import TextSplitter - + # Chunks will be between 100-200 characters splitter = TextSplitter((100, 200)) chunks = splitter.chunks(text) - - print(f"\nChunks sized between 100-200 characters:") + + print("\nChunks sized between 100-200 characters:") for i, chunk in enumerate(chunks): print(f"\nChunk {i + 1} ({len(chunk)} chars):") print(chunk) @@ -217,7 +217,7 @@ def range_based_chunking(): token_based_chunking_example() code_splitting_example() range_based_chunking() - + print("\n\n" + "=" * 80) print("SUMMARY") print("=" * 80) @@ -233,4 +233,4 @@ def range_based_chunking(): - RAG applications - LLM context windows - Document processing pipelines -""") \ No newline at end of file +""") diff --git a/pyproject.toml b/pyproject.toml index b299c32..728b7d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,12 @@ encode = [ serve = [ "mcp>=0.9.0", "fastapi>=0.104.0", - "uvicorn>=0.24.0", + "uvicorn[standard]>=0.24.0", + "sse-starlette>=1.6.0", + "python-jose[cryptography]>=3.3.0", + "python-multipart>=0.0.6", + "httpx>=0.25.0", + "slowapi>=0.1.9", ] all = [ "contextframe[extract,embed,enhance,encode,serve,io]", diff --git a/tests/test_embed.py b/tests/test_embed.py index aa47e72..05cbfbb 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -1,7 +1,6 @@ """Tests for the embed module.""" import pytest - from contextframe.embed import ( BatchEmbedder, EmbeddingProvider, @@ -12,34 +11,34 @@ class MockProvider(EmbeddingProvider): """Mock embedding provider for testing.""" - + def __init__(self, dimension: int = 128): super().__init__(model="mock-model") self.dimension = dimension self.call_count = 0 - + def embed(self, texts, **kwargs): """Generate mock embeddings.""" texts = self.validate_texts(texts) self.call_count += 1 - + # Generate fake embeddings embeddings = [] for text in texts: # Simple hash-based fake embedding - embedding = [float(ord(c) % 10) / 10 for c in text[:self.dimension]] + embedding = [float(ord(c) % 10) / 10 for c in text[: self.dimension]] # Pad if needed if len(embedding) < self.dimension: embedding.extend([0.0] * (self.dimension - len(embedding))) embeddings.append(embedding) - + return EmbeddingResult( embeddings=embeddings, model=self.model, dimension=self.dimension, usage={"prompt_tokens": len(texts) * 10, "total_tokens": len(texts) * 10}, ) - + def get_model_info(self): return { "model": self.model, @@ -47,11 +46,11 @@ def get_model_info(self): "provider": "mock", "supports_batch": True, } - + @property def supports_batch(self): return True - + @property def max_batch_size(self): return 10 @@ -59,193 +58,183 @@ def max_batch_size(self): class TestEmbeddingResult: """Test EmbeddingResult dataclass.""" - + def test_single_embedding_normalization(self): """Test that single embeddings are normalized to 2D.""" result = EmbeddingResult( - embeddings=[[0.1, 0.2, 0.3]], - model="test", - dimension=3 + embeddings=[[0.1, 0.2, 0.3]], model="test", dimension=3 ) assert result.embeddings[0] == [0.1, 0.2, 0.3] assert len(result.embeddings) == 1 - + def test_dimension_validation(self): """Test dimension validation.""" with pytest.raises(ValueError, match="does not match"): EmbeddingResult( embeddings=[[0.1, 0.2]], model="test", - dimension=3 # Mismatch + dimension=3, # Mismatch ) - + def test_dimension_inference(self): """Test dimension is inferred from embeddings.""" result = EmbeddingResult( - embeddings=[[0.1, 0.2, 0.3, 0.4]], - model="test", - dimension=None + embeddings=[[0.1, 0.2, 0.3, 0.4]], model="test", dimension=None ) assert result.dimension == 4 class TestLiteLLMProvider: """Test LiteLLM provider.""" - + def test_initialization(self): """Test provider initialization.""" provider = LiteLLMProvider(model="text-embedding-ada-002") assert provider.model == "text-embedding-ada-002" assert provider.supports_batch is True - + def test_model_info(self): """Test getting model information.""" provider = LiteLLMProvider(model="text-embedding-ada-002") info = provider.get_model_info() - + assert info["model"] == "text-embedding-ada-002" assert info["dimension"] == 1536 # Known dimension assert info["provider"] == "openai" assert info["supports_batch"] is True - + def test_provider_detection(self): """Test provider detection from model names.""" provider = LiteLLMProvider(model="cohere/embed-english-v3.0") assert provider._detect_provider() == "cohere" - + provider = LiteLLMProvider(model="embed-english-v3.0") assert provider._detect_provider() == "cohere" - + provider = LiteLLMProvider(model="voyage-01") assert provider._detect_provider() == "voyage" - + def test_text_validation(self): """Test text validation.""" provider = LiteLLMProvider() - + # Valid texts assert provider.validate_texts("hello") == ["hello"] assert provider.validate_texts(["hello", "world"]) == ["hello", "world"] - + # Invalid texts with pytest.raises(ValueError, match="No texts"): provider.validate_texts([]) - + with pytest.raises(ValueError, match="must be strings"): provider.validate_texts([123]) - + with pytest.raises(ValueError, match="Empty texts"): provider.validate_texts(["", "hello"]) class TestBatchEmbedder: """Test batch embedder functionality.""" - + def test_basic_batch_embedding(self): """Test basic batch embedding.""" provider = MockProvider(dimension=10) embedder = BatchEmbedder(provider, batch_size=2) - + texts = ["text1", "text2", "text3", "text4", "text5"] result = embedder.embed_batch(texts) - + assert len(result.embeddings) == 5 assert result.dimension == 10 assert provider.call_count == 3 # 5 texts with batch size 2 = 3 calls - + def test_progress_callback(self): """Test progress callback.""" provider = MockProvider() progress_calls = [] - + def progress(completed, total): progress_calls.append((completed, total)) - + embedder = BatchEmbedder(provider, batch_size=2, progress_callback=progress) texts = ["text1", "text2", "text3"] embedder.embed_batch(texts) - + # Should have progress updates after each batch assert len(progress_calls) == 2 assert progress_calls[0] == (2, 3) # First batch assert progress_calls[1] == (3, 3) # Second batch - + def test_embed_documents(self): """Test embedding documents.""" provider = MockProvider(dimension=5) embedder = BatchEmbedder(provider) - + documents = [ {"id": 1, "content": "Hello world", "metadata": "test"}, {"id": 2, "content": "Goodbye world", "other": "data"}, {"id": 3, "title": "No content"}, # Missing content field ] - + result_docs = embedder.embed_documents(documents) - + assert len(result_docs) == 3 - + # First two should have embeddings assert result_docs[0]["embedding"] is not None assert len(result_docs[0]["embedding"]) == 5 assert result_docs[0]["embedding_model"] == "mock-model" assert result_docs[0]["embedding_dimension"] == 5 - + assert result_docs[1]["embedding"] is not None - + # Third should have error assert result_docs[2]["embedding"] is None assert result_docs[2]["embedding_error"] == "No text content" - + def test_empty_batch_error(self): """Test error on empty batch.""" provider = MockProvider() embedder = BatchEmbedder(provider) - + with pytest.raises(ValueError, match="No texts provided"): embedder.embed_batch([]) class TestIntegration: """Test integration with extraction results.""" - + def test_embed_extraction_results(self): """Test embedding extraction results.""" from contextframe.embed.integration import embed_extraction_results from contextframe.extract import ExtractionResult - + # Create extraction results results = [ ExtractionResult( content="Document 1 content", metadata={"title": "Doc 1"}, - chunks=["chunk1", "chunk2"] - ), - ExtractionResult( - content="Document 2 content", - metadata={"title": "Doc 2"} + chunks=["chunk1", "chunk2"], ), + ExtractionResult(content="Document 2 content", metadata={"title": "Doc 2"}), ] - + # Embed with mock provider provider = MockProvider(dimension=8) enhanced = embed_extraction_results( - results, - provider, - embed_content=True, - embed_chunks=True + results, provider, embed_content=True, embed_chunks=True ) - + # Check content embeddings assert "content_embedding" in enhanced[0].metadata assert len(enhanced[0].metadata["content_embedding"]) == 8 assert "embedding_model" in enhanced[0].metadata - + # Check chunk embeddings assert "chunk_embeddings" in enhanced[0].metadata assert len(enhanced[0].metadata["chunk_embeddings"]) == 2 assert enhanced[0].metadata["chunk_embeddings"][0]["index"] == 0 - + # Second doc has no chunks assert "content_embedding" in enhanced[1].metadata - assert "chunk_embeddings" not in enhanced[1].metadata \ No newline at end of file + assert "chunk_embeddings" not in enhanced[1].metadata diff --git a/tests/test_enhance.py b/tests/test_enhance.py index ded1023..cc71fe6 100644 --- a/tests/test_enhance.py +++ b/tests/test_enhance.py @@ -2,25 +2,25 @@ import json import pytest -from unittest.mock import Mock, patch, MagicMock - -from contextframe import FrameRecord, FrameDataset +from contextframe import FrameDataset, FrameRecord from contextframe.enhance import ( ContextEnhancer, EnhancementResult, EnhancementTools, + build_enhancement_prompt, get_prompt_template, list_available_prompts, - build_enhancement_prompt, ) +from unittest.mock import MagicMock, Mock, patch class TestContextEnhancer: """Test the ContextEnhancer class.""" - + @patch('contextframe.enhance.base.llm.call') def test_enhance_context(self, mock_llm_call): """Test enhancing context field with structured output.""" + # Create a mock function that simulates the decorator behavior def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): @@ -29,72 +29,81 @@ def wrapper(messages): mock_response = Mock() mock_response.context = "This document explains RAG architecture." return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer(provider="openai", model="gpt-4o-mini") result = enhancer.enhance_context( content="RAG combines LLMs with retrieval...", - purpose="understanding AI systems" + purpose="understanding AI systems", ) - + assert result == "This document explains RAG architecture." - + @patch('contextframe.enhance.base.llm.call') def test_enhance_tags(self, mock_llm_call): """Test extracting tags with structured output.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() mock_response.tags = ["RAG", "LLM", "retrieval", "embeddings"] return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() result = enhancer.enhance_tags( content="RAG architecture combines retrieval with LLM generation", tag_types="technologies", - max_tags=5 + max_tags=5, ) - + assert isinstance(result, list) assert len(result) == 4 assert "RAG" in result assert "LLM" in result - + @patch('contextframe.enhance.base.llm.call') def test_enhance_custom_metadata(self, mock_llm_call): """Test extracting custom metadata with structured output.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() mock_response.metadata = {"complexity": 3, "topics": ["RAG", "LLM"]} return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() result = enhancer.enhance_custom_metadata( content="RAG architecture guide", - schema_prompt="Extract complexity level and main topics" + schema_prompt="Extract complexity level and main topics", ) - + assert isinstance(result, dict) assert result["complexity"] == 3 assert "RAG" in result["topics"] - - @patch('contextframe.enhance.base.llm.call') + + @patch('contextframe.enhance.base.llm.call') def test_enhance_relationships(self, mock_llm_call): """Test finding relationships with structured output.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): @@ -107,29 +116,30 @@ def wrapper(messages): rel1.target_id = "123" mock_response.relationships = [rel1] return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() relationships = enhancer.enhance_relationships( source_content="Basic RAG introduction", source_title="Basic RAG", - candidates=[ - {"title": "Advanced RAG", "summary": "Advanced techniques"} - ], - max_relationships=5 + candidates=[{"title": "Advanced RAG", "summary": "Advanced techniques"}], + max_relationships=5, ) - + assert len(relationships) == 1 assert relationships[0]["type"] == "related" assert relationships[0]["title"] == "Advanced RAG" assert relationships[0]["description"] == "Builds on basic concepts" - + @patch('contextframe.enhance.base.llm.call') def test_enhance_field_generic(self, mock_llm_call): """Test generic field enhancement.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): @@ -142,34 +152,32 @@ def wrapper(messages): mock_response = Mock() mock_response.tags = ["tag1", "tag2"] return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() - + # Test context field result = enhancer.enhance_field( - content="Test content", - field_name="context", - prompt="Add context" + content="Test content", field_name="context", prompt="Add context" ) assert result == "Enhanced context" - + # Test tags field result = enhancer.enhance_field( - content="Test content", - field_name="tags", - prompt="Extract tags" + content="Test content", field_name="tags", prompt="Extract tags" ) assert result == ["tag1", "tag2"] - + @patch('contextframe.enhance.base.llm.call') def test_enhance_document(self, mock_llm_call): """Test enhancing a full document.""" call_count = 0 - + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): @@ -180,124 +188,135 @@ def wrapper(messages): mock_response.context = "A guide to RAG architecture" return mock_response else: # Second call for tags - mock_response = Mock() + mock_response = Mock() mock_response.tags = ["RAG", "LLM", "retrieval"] return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + frame = FrameRecord.create( - title="RAG Guide", - content="Content about RAG...", - uri="test.md" + title="RAG Guide", content="Content about RAG...", uri="test.md" ) - + enhancer = ContextEnhancer() enhanced = enhancer.enhance_document( frame, enhancements={ "context": "Summarize the document", - "tags": "Extract key topics" - } + "tags": "Extract key topics", + }, ) - + assert enhanced.metadata.get("context") == "A guide to RAG architecture" assert enhanced.metadata.get("tags") == ["RAG", "LLM", "retrieval"] - + def test_field_has_value(self): """Test checking if fields have values.""" enhancer = ContextEnhancer() - + frame = FrameRecord.create(title="Test", uri="test.md") assert not enhancer._field_has_value(frame, "context") - + frame.metadata["context"] = "Some context" assert enhancer._field_has_value(frame, "context") - + frame.metadata["tags"] = [] assert not enhancer._field_has_value(frame, "tags") - + frame.metadata["tags"] = ["tag1"] assert enhancer._field_has_value(frame, "tags") class TestEnhancementTools: """Test the MCP-compatible tools.""" - + @patch('contextframe.enhance.base.llm.call') def test_enhance_context_tool(self, mock_llm_call): """Test the enhance_context tool.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() mock_response.context = "Document about testing frameworks" return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() tools = EnhancementTools(enhancer) - + result = tools.enhance_context( content="pytest is a testing framework...", - purpose="understanding Python testing" + purpose="understanding Python testing", ) - + assert result == "Document about testing frameworks" - + @patch('contextframe.enhance.base.llm.call') def test_extract_metadata_tool(self, mock_llm_call): """Test the extract_metadata tool.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() - mock_response.metadata = {"language": "python", "framework": "pytest"} + mock_response.metadata = { + "language": "python", + "framework": "pytest", + } return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() tools = EnhancementTools(enhancer) - + result = tools.extract_metadata( - content="pytest tutorial...", - schema="Extract language and framework" + content="pytest tutorial...", schema="Extract language and framework" ) - + assert result["language"] == "python" assert result["framework"] == "pytest" - + @patch('contextframe.enhance.base.llm.call') def test_generate_tags_tool(self, mock_llm_call): """Test the generate_tags tool.""" + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() mock_response.tags = ["python", "testing", "pytest", "TDD"] return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + enhancer = ContextEnhancer() tools = EnhancementTools(enhancer) - + result = tools.generate_tags( content="Test-driven development with pytest", tag_types="technologies and methodologies", - max_tags=5 + max_tags=5, ) - + assert len(result) == 4 assert "python" in result assert "pytest" in result @@ -305,39 +324,39 @@ def wrapper(messages): class TestPromptTemplates: """Test prompt template functionality.""" - + def test_get_prompt_template(self): """Test retrieving prompt templates.""" template = get_prompt_template("context", "technical_summary") assert "technical problem" in template assert "{content}" in template - + template = get_prompt_template("tags", "technical_tags") assert "Programming languages" in template assert "{content}" in template - + def test_list_available_prompts(self): """Test listing available prompts.""" prompts = list_available_prompts() - + assert isinstance(prompts, dict) assert "context" in prompts assert "tags" in prompts assert "metadata" in prompts assert "relationships" in prompts - + assert "technical_summary" in prompts["context"] assert "technical_tags" in prompts["tags"] - + def test_build_enhancement_prompt(self): """Test building custom prompts.""" prompt = build_enhancement_prompt( task="Extract key information", fields=["summary", "technologies"], context="For a technical blog", - examples="summary: Brief overview\ntechnologies: Python, FastAPI" + examples="summary: Brief overview\ntechnologies: Python, FastAPI", ) - + assert "Extract key information" in prompt assert "- summary" in prompt assert "- technologies" in prompt @@ -348,82 +367,81 @@ def test_build_enhancement_prompt(self): class TestFrameDatasetEnhance: """Test FrameDataset.enhance() integration.""" - + @patch('contextframe.enhance.base.llm.call') def test_dataset_enhance_method(self, mock_llm_call): """Test the convenience enhance method on FrameDataset.""" import shutil from pathlib import Path - + # Clean up any existing test dataset test_path = Path("test_enhance.lance") if test_path.exists(): shutil.rmtree(test_path) - + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): mock_response = Mock() mock_response.context = "Test context" return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + # Create dataset with test data dataset = FrameDataset.create("test_enhance.lance", overwrite=True) frame = FrameRecord.create( - title="Test Document", - content="Test content", - uri="test.md" + title="Test Document", content="Test content", uri="test.md" ) dataset.add(frame) - + # Use the enhance method - results = dataset.enhance({ - "context": "Add a test context" - }) - + results = dataset.enhance({"context": "Add a test context"}) + # Check that enhancement results show success assert len(results) == 1 assert results[0].success assert results[0].field_name == "context" assert results[0].value == "Test context" - + # Verify enhancement by reading back from dataset # Since we're using Lance's update method, we need to re-read the record enhanced_record = dataset.get_by_uuid(frame.uuid) assert enhanced_record is not None assert enhanced_record.metadata.get("context") == "Test context" - + # Cleanup import shutil + shutil.rmtree("test_enhance.lance") class TestEnhancementIntegration: """Integration tests with mocked LLM.""" - + @patch('contextframe.enhance.base.llm.call') def test_full_enhancement_workflow(self, mock_llm_call): """Test complete enhancement workflow.""" import shutil from pathlib import Path - + # Clean up any existing test dataset test_path = Path("integration_test.lance") if test_path.exists(): shutil.rmtree(test_path) - + call_count = 0 - + def mock_decorator(provider, model, response_model, **kwargs): def decorator(func): def wrapper(messages): nonlocal call_count call_count += 1 - + if response_model.__name__ == "ContextResponse": mock_response = Mock() mock_response.context = "This document teaches RAG architecture" @@ -434,24 +452,27 @@ def wrapper(messages): return mock_response elif response_model.__name__ == "CustomMetadataResponse": mock_response = Mock() - mock_response.metadata = {"complexity": 4, "audience": "developers"} + mock_response.metadata = { + "complexity": 4, + "audience": "developers", + } return mock_response + return wrapper + return decorator - + mock_llm_call.side_effect = mock_decorator - + # Create dataset dataset = FrameDataset.create("integration_test.lance", overwrite=True) frames = [ FrameRecord.create( - title="RAG Basics", - content="Introduction to RAG...", - uri="doc1.md" + title="RAG Basics", content="Introduction to RAG...", uri="doc1.md" ), ] dataset.add_many(frames) - + # Enhance enhancer = ContextEnhancer() enhancer.enhance_dataset( @@ -459,24 +480,28 @@ def wrapper(messages): enhancements={ "context": "Explain what this teaches", "tags": "Extract topics", - "custom_metadata": "Extract complexity and audience as JSON" - } + "custom_metadata": "Extract complexity and audience as JSON", + }, ) - + # Verify by reading back the enhanced record enhanced_record = dataset.get_by_uuid(frames[0].uuid) assert enhanced_record is not None - - assert enhanced_record.metadata.get("context") == "This document teaches RAG architecture" - + + assert ( + enhanced_record.metadata.get("context") + == "This document teaches RAG architecture" + ) + tags = enhanced_record.metadata.get("tags", []) assert "RAG" in tags - + # Custom metadata should be a dict with string values custom_meta = enhanced_record.metadata.get("custom_metadata", {}) assert custom_meta["complexity"] == "4" assert custom_meta["audience"] == "developers" - + # Cleanup import shutil - shutil.rmtree("integration_test.lance") \ No newline at end of file + + shutil.rmtree("integration_test.lance") diff --git a/tests/test_extract.py b/tests/test_extract.py index 0435ac4..51002b0 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -2,13 +2,9 @@ import csv import json -import tempfile -from pathlib import Path -from unittest.mock import Mock, patch - import pytest +import tempfile import yaml - from contextframe.extract import ( BatchExtractor, CSVExtractor, @@ -21,36 +17,38 @@ ) from contextframe.extract.base import ExtractorRegistry from contextframe.extract.chunking import split_extraction_results +from pathlib import Path +from unittest.mock import Mock, patch class TestExtractionResult: """Test the ExtractionResult dataclass.""" - + def test_success_property(self): """Test the success property.""" # Successful extraction result = ExtractionResult(content="test content") assert result.success is True - + # Failed extraction with error result = ExtractionResult(content="", error="Failed to extract") assert result.success is False - + # Empty content without error result = ExtractionResult(content="") assert result.success is False - + def test_to_frame_record_kwargs(self): """Test conversion to FrameRecord kwargs.""" result = ExtractionResult( content="test content", metadata={"key": "value", "title": "Test Title"}, source="/path/to/file.txt", - format="text" + format="text", ) - + kwargs = result.to_frame_record_kwargs() - + assert kwargs["content"] == "test content" assert kwargs["title"] == "Test Title" assert kwargs.get("source_file") == "/path/to/file.txt" @@ -61,85 +59,85 @@ def test_to_frame_record_kwargs(self): class TestTextFileExtractor: """Test the TextFileExtractor.""" - + def test_can_extract(self, tmp_path): """Test file type detection.""" extractor = TextFileExtractor() - + # Should handle .txt files txt_file = tmp_path / "test.txt" txt_file.touch() assert extractor.can_extract(txt_file) is True - + # Should handle .log files log_file = tmp_path / "test.log" log_file.touch() assert extractor.can_extract(log_file) is True - + # Should not handle other files other_file = tmp_path / "test.pdf" other_file.touch() assert extractor.can_extract(other_file) is False - + def test_extract(self, tmp_path): """Test content extraction.""" extractor = TextFileExtractor() - + # Create test file test_file = tmp_path / "test.txt" test_content = "This is a test file.\nWith multiple lines." test_file.write_text(test_content) - + result = extractor.extract(test_file) - + assert result.success is True assert result.content == test_content assert result.metadata["filename"] == "test.txt" assert result.metadata["size"] == len(test_content) assert result.metadata["encoding"] == "utf-8" assert result.format == "text" - + def test_extract_nonexistent_file(self): """Test extraction of non-existent file.""" extractor = TextFileExtractor() - + result = extractor.extract("/nonexistent/file.txt") - + assert result.success is False assert "does not exist" in result.error class TestMarkdownExtractor: """Test the MarkdownExtractor.""" - + def test_can_extract(self, tmp_path): """Test file type detection.""" extractor = MarkdownExtractor() - + # Should handle various markdown extensions for ext in [".md", ".markdown", ".mdown", ".mkd"]: md_file = tmp_path / f"test{ext}" md_file.touch() assert extractor.can_extract(md_file) is True - + def test_extract_without_frontmatter(self, tmp_path): """Test extraction of markdown without frontmatter.""" extractor = MarkdownExtractor() - + test_file = tmp_path / "test.md" test_content = "# Title\n\nThis is markdown content." test_file.write_text(test_content) - + result = extractor.extract(test_file) - + assert result.success is True assert result.content == test_content assert result.format == "markdown" - + def test_extract_with_frontmatter(self, tmp_path): """Test extraction of markdown with YAML frontmatter.""" extractor = MarkdownExtractor() - + test_file = tmp_path / "test.md" frontmatter_content = """--- title: Test Document @@ -153,19 +151,19 @@ def test_extract_with_frontmatter(self, tmp_path): This is the markdown content.""" test_file.write_text(frontmatter_content) - + result = extractor.extract(test_file) - + assert result.success is True assert result.content.strip() == "# Title\n\nThis is the markdown content." assert result.metadata["title"] == "Test Document" assert result.metadata["author"] == "Test Author" assert result.metadata["tags"] == ["test", "document"] - + def test_extract_with_invalid_frontmatter(self, tmp_path): """Test extraction with invalid YAML frontmatter.""" extractor = MarkdownExtractor() - + test_file = tmp_path / "test.md" content = """--- invalid: yaml: content @@ -173,25 +171,25 @@ def test_extract_with_invalid_frontmatter(self, tmp_path): # Title""" test_file.write_text(content) - + result = extractor.extract(test_file) - + assert result.success is True assert len(result.warnings) > 0 assert "Failed to parse frontmatter" in result.warnings[0] - + def test_extract_from_string(self): """Test extraction from string content.""" extractor = MarkdownExtractor() - + content = """--- title: String Test --- Content from string""" - + result = extractor.extract_from_string(content) - + assert result.success is True assert result.content.strip() == "Content from string" assert result.metadata["title"] == "String Test" @@ -199,88 +197,88 @@ def test_extract_from_string(self): class TestJSONExtractor: """Test the JSONExtractor.""" - + def test_extract_json(self, tmp_path): """Test extraction of regular JSON file.""" extractor = JSONExtractor() - + test_file = tmp_path / "test.json" test_data = {"key": "value", "nested": {"field": "content"}} test_file.write_text(json.dumps(test_data, indent=2)) - + result = extractor.extract(test_file) - + assert result.success is True assert json.loads(result.content) == test_data assert result.metadata["json_data"] == test_data assert result.format == "json" - + def test_extract_jsonl(self, tmp_path): """Test extraction of JSON Lines file.""" extractor = JSONExtractor() - + test_file = tmp_path / "test.jsonl" lines = [ {"id": 1, "text": "First line"}, {"id": 2, "text": "Second line"}, ] test_file.write_text("\n".join(json.dumps(line) for line in lines)) - + result = extractor.extract(test_file) - + assert result.success is True assert result.metadata["json_data"] == lines - + def test_extract_text_fields(self, tmp_path): """Test extraction of specific text fields.""" extractor = JSONExtractor() - + test_file = tmp_path / "test.json" test_data = { "title": "Test Title", "description": "Test Description", "metadata": {"author": "Test Author"}, - "content": "Main content here" + "content": "Main content here", } test_file.write_text(json.dumps(test_data)) - + result = extractor.extract(test_file, extract_text_fields=["title", "content"]) - + assert result.success is True assert "Test Title" in result.content assert "Main content here" in result.content assert "Test Description" not in result.content - + def test_extract_invalid_json(self, tmp_path): """Test extraction of invalid JSON.""" extractor = JSONExtractor() - + test_file = tmp_path / "test.json" test_file.write_text("{invalid json}") - + result = extractor.extract(test_file) - + assert result.success is False assert "Invalid JSON" in result.error class TestYAMLExtractor: """Test the YAMLExtractor.""" - + def test_extract_yaml(self, tmp_path): """Test extraction of YAML file.""" extractor = YAMLExtractor() - + test_file = tmp_path / "test.yaml" test_data = { "key": "value", "list": ["item1", "item2"], - "nested": {"field": "content"} + "nested": {"field": "content"}, } test_file.write_text(yaml.dump(test_data)) - + result = extractor.extract(test_file) - + assert result.success is True assert result.metadata["yaml_data"] == test_data assert result.format == "yaml" @@ -288,81 +286,81 @@ def test_extract_yaml(self, tmp_path): class TestCSVExtractor: """Test the CSVExtractor.""" - + def test_extract_csv(self, tmp_path): """Test extraction of CSV file.""" extractor = CSVExtractor() - + test_file = tmp_path / "test.csv" rows = [ ["Name", "Age", "City"], ["Alice", "30", "New York"], ["Bob", "25", "London"], ] - + with open(test_file, "w", newline="") as f: writer = csv.writer(f) writer.writerows(rows) - + result = extractor.extract(test_file) - + assert result.success is True assert "Alice, 30, New York" in result.content assert "Bob, 25, London" in result.content assert result.metadata["headers"] == ["Name", "Age", "City"] assert result.metadata["row_count"] == 3 assert result.metadata["csv_data"] == rows - + def test_extract_specific_columns(self, tmp_path): """Test extraction of specific columns.""" extractor = CSVExtractor() - + test_file = tmp_path / "test.csv" rows = [ ["Name", "Age", "City", "Email"], ["Alice", "30", "New York", "alice@example.com"], ["Bob", "25", "London", "bob@example.com"], ] - + with open(test_file, "w", newline="") as f: writer = csv.writer(f) writer.writerows(rows) - + # Extract by column names result = extractor.extract(test_file, text_columns=["Name", "City"]) - + assert result.success is True assert "Alice, New York" in result.content assert "30" not in result.content assert "alice@example.com" not in result.content - + # Extract by column indices result = extractor.extract(test_file, text_columns=[0, 2]) - + assert result.success is True assert "Alice, New York" in result.content class TestExtractorRegistry: """Test the ExtractorRegistry.""" - + def test_registry_operations(self): """Test registry registration and lookup.""" registry = ExtractorRegistry() - + # Create a mock extractor mock_extractor = Mock(spec=TextExtractor) mock_extractor.can_extract.return_value = True mock_extractor.format_name = "test" mock_extractor.supported_extensions = [".test"] - + # Register extractor registry.register(mock_extractor) - + # Find extractor found = registry.find_extractor("test.test") assert found == mock_extractor - + # Get supported formats formats = registry.get_supported_formats() assert "test" in formats @@ -371,7 +369,7 @@ def test_registry_operations(self): class TestBatchExtractor: """Test the BatchExtractor.""" - + def test_extract_files(self, tmp_path): """Test batch extraction of multiple files.""" # Create test files @@ -380,33 +378,33 @@ def test_extract_files(self, tmp_path): file = tmp_path / f"test{i}.txt" file.write_text(f"Content {i}") files.append(file) - + batch = BatchExtractor() results = batch.extract_files(files) - + assert len(results) == 3 for i, result in enumerate(results): assert result.success is True assert result.content == f"Content {i}" - + def test_extract_directory(self, tmp_path): """Test extraction of all files in directory.""" # Create test files (tmp_path / "file1.txt").write_text("Text 1") (tmp_path / "file2.md").write_text("# Markdown") (tmp_path / "data.json").write_text('{"key": "value"}') - + batch = BatchExtractor() results = batch.extract_directory(tmp_path) - + assert len(results) == 3 - + # Check that different extractors were used formats = {r.format for r in results if r.success} assert "text" in formats assert "markdown" in formats assert "json" in formats - + def test_progress_callback(self, tmp_path): """Test progress callback functionality.""" # Create test files @@ -415,20 +413,20 @@ def test_progress_callback(self, tmp_path): file = tmp_path / f"test{i}.txt" file.write_text(f"Content {i}") files.append(file) - + # Track progress progress_calls = [] - + def progress_callback(current, total, file_path): progress_calls.append((current, total, file_path)) - + batch = BatchExtractor(progress_callback=progress_callback) batch.extract_files(files) - + assert len(progress_calls) == 3 assert all(call[1] == 3 for call in progress_calls) # Total is 3 assert progress_calls[-1][0] == 3 # Last call shows 3/3 - + @pytest.mark.asyncio async def test_extract_files_async(self, tmp_path): """Test async batch extraction.""" @@ -438,10 +436,10 @@ async def test_extract_files_async(self, tmp_path): file = tmp_path / f"test{i}.txt" file.write_text(f"Content {i}") files.append(file) - + batch = BatchExtractor() results = await batch.extract_files_async(files) - + assert len(results) == 3 for result in results: assert result.success is True @@ -449,7 +447,7 @@ async def test_extract_files_async(self, tmp_path): class TestChunking: """Test chunking functionality.""" - + @patch("contextframe.extract.chunking.semantic_splitter") def test_split_extraction_results(self, mock_splitter): """Test splitting of extraction results.""" @@ -459,34 +457,30 @@ def test_split_extraction_results(self, mock_splitter): (0, "Chunk 2 from doc 1"), (1, "Chunk 1 from doc 2"), ] - + # Create test results results = [ ExtractionResult( - content="Long content 1", - metadata={"doc": 1}, - source="file1.txt" + content="Long content 1", metadata={"doc": 1}, source="file1.txt" ), ExtractionResult( - content="Long content 2", - metadata={"doc": 2}, - source="file2.txt" + content="Long content 2", metadata={"doc": 2}, source="file2.txt" ), ] - + chunked = split_extraction_results(results, chunk_size=100) - + # Should have 3 chunks total assert len(chunked) == 3 - + # Check metadata preservation and chunk info chunk1 = chunked[0] assert chunk1.metadata["doc"] == 1 assert chunk1.metadata["chunk_index"] == 0 assert chunk1.metadata["chunk_count"] == 2 assert chunk1.content == "Chunk 1 from doc 1" - + chunk3 = chunked[2] assert chunk3.metadata["doc"] == 2 assert chunk3.metadata["chunk_index"] == 0 - assert chunk3.metadata["chunk_count"] == 1 \ No newline at end of file + assert chunk3.metadata["chunk_count"] == 1 diff --git a/tests/test_extract_integration.py b/tests/test_extract_integration.py index b8a82f1..ef185eb 100644 --- a/tests/test_extract_integration.py +++ b/tests/test_extract_integration.py @@ -1,11 +1,8 @@ """Integration tests for extraction module with ContextFrame data model.""" import json -import tempfile -from pathlib import Path - import pytest - +import tempfile from contextframe import FrameDataset, FrameRecord from contextframe.extract import ( BatchExtractor, @@ -14,11 +11,12 @@ TextFileExtractor, ) from contextframe.extract.chunking import split_extraction_results +from pathlib import Path class TestExtractionToFrameRecord: """Test conversion from ExtractionResult to FrameRecord.""" - + def test_extraction_result_to_frame_record(self): """Test that ExtractionResult converts to valid FrameRecord.""" # Create an extraction result @@ -27,28 +25,28 @@ def test_extraction_result_to_frame_record(self): metadata={ "title": "Test Document", "author": "Test Author", - "custom_field": "custom_value" + "custom_field": "custom_value", }, source="/path/to/file.txt", - format="text" + format="text", ) - + # Convert to FrameRecord kwargs kwargs = result.to_frame_record_kwargs() - + # Create FrameRecord directly from kwargs - frame = FrameRecord.create( - record_type="document", - **kwargs - ) - + frame = FrameRecord.create(record_type="document", **kwargs) + assert frame.content == "Test content" assert frame.title == "Test Document" assert frame.metadata.get("author") == "Test Author" assert frame.metadata.get("source_type") == "text" assert frame.metadata.get("source_file") == "/path/to/file.txt" - assert frame.metadata.get("custom_metadata", {}).get("custom_field") == "custom_value" - + assert ( + frame.metadata.get("custom_metadata", {}).get("custom_field") + == "custom_value" + ) + def test_markdown_extraction_to_frame_record(self, tmp_path): """Test Markdown extraction creates valid FrameRecord with frontmatter.""" # Create test markdown file @@ -65,18 +63,15 @@ def test_markdown_extraction_to_frame_record(self, tmp_path): This is a test document for integration testing.""" md_file.write_text(md_content) - + # Extract extractor = MarkdownExtractor() result = extractor.extract(md_file) - + # Convert to FrameRecord kwargs = result.to_frame_record_kwargs() - frame = FrameRecord.create( - record_type="document", - **kwargs - ) - + frame = FrameRecord.create(record_type="document", **kwargs) + # Verify assert frame.content.strip().startswith("# Test Document") assert frame.title == "Integration Test" @@ -91,41 +86,41 @@ def test_markdown_extraction_to_frame_record(self, tmp_path): class TestExtractionToFrameDataset: """Test extraction workflow with FrameDataset.""" - + def test_single_file_extraction_to_dataset(self, tmp_path): """Test extracting a single file and adding to dataset.""" # Create test file test_file = tmp_path / "test.txt" test_file.write_text("This is test content for dataset integration.") - + # Create dataset dataset_path = tmp_path / "test.lance" dataset = FrameDataset.create(dataset_path) - + # Extract file extractor = TextFileExtractor() result = extractor.extract(test_file) - + # Convert and add to dataset kwargs = result.to_frame_record_kwargs() - frame = FrameRecord.create( - record_type="document", - **kwargs - ) - + frame = FrameRecord.create(record_type="document", **kwargs) + dataset.add(frame) - + # Verify the frame was added by checking dataset stats assert len(dataset._dataset) == 1 - + # Verify we can query the dataset results = dataset.scanner().to_table() assert len(results) == 1 - + # Check the content directly from the table - assert results["text_content"][0].as_py() == "This is test content for dataset integration." + assert ( + results["text_content"][0].as_py() + == "This is test content for dataset integration." + ) assert results["source_type"][0].as_py() == "text" - + def test_batch_extraction_to_dataset(self, tmp_path): """Test batch extraction workflow.""" # Create test files @@ -134,7 +129,7 @@ def test_batch_extraction_to_dataset(self, tmp_path): file = tmp_path / f"doc{i}.txt" file.write_text(f"Document {i} content") files.append(file) - + # Create markdown file md_file = tmp_path / "readme.md" md_file.write_text("""--- @@ -145,63 +140,60 @@ def test_batch_extraction_to_dataset(self, tmp_path): This is the readme file.""") files.append(md_file) - + # Create dataset dataset_path = tmp_path / "batch_test.lance" dataset = FrameDataset.create(dataset_path) - + # Batch extract batch_extractor = BatchExtractor() results = batch_extractor.extract_files(files) - + # Convert all to FrameRecords frames = [] for result in results: if result.success: kwargs = result.to_frame_record_kwargs() - frame = FrameRecord.create( - record_type="document", - **kwargs - ) + frame = FrameRecord.create(record_type="document", **kwargs) frames.append(frame) - + # Add to dataset dataset.add_many(frames) - + # Verify assert len(dataset._dataset) == 4 - + # Query and verify content types results = dataset.scanner().to_table() source_types = [st.as_py() for st in results["source_type"]] - + # Check text files text_count = sum(1 for st in source_types if st == "text") assert text_count == 3 - + # Check markdown file md_count = sum(1 for st in source_types if st == "markdown") assert md_count == 1 - + # Check title of markdown file titles = [t.as_py() for t in results["title"]] assert "Readme" in titles - + def test_extraction_with_chunking_to_dataset(self, tmp_path): """Test extraction with chunking creates valid dataset entries.""" # Create a longer document long_file = tmp_path / "long_doc.txt" long_content = " ".join([f"Sentence {i}." for i in range(100)]) long_file.write_text(long_content) - + # Create dataset dataset_path = tmp_path / "chunked_test.lance" dataset = FrameDataset.create(dataset_path) - + # Extract extractor = TextFileExtractor() result = extractor.extract(long_file) - + # Chunk the results (mocking since we don't have semantic-text-splitter installed) # In real usage, this would use semantic_splitter def mock_splitter(texts, chunk_size=100, chunk_overlap=20): @@ -210,17 +202,15 @@ def mock_splitter(texts, chunk_size=100, chunk_overlap=20): # Simple character-based chunking for testing overlap = chunk_overlap if chunk_overlap is not None else 20 for i in range(0, len(text), chunk_size - overlap): - chunk = text[i:i + chunk_size] + chunk = text[i : i + chunk_size] if chunk: chunks.append((idx, chunk)) return chunks - + chunked_results = split_extraction_results( - [result], - chunk_size=100, - splitter_fn=mock_splitter + [result], chunk_size=100, splitter_fn=mock_splitter ) - + # Convert chunks to FrameRecords frames = [] for chunk_result in chunked_results: @@ -229,23 +219,20 @@ def mock_splitter(texts, chunk_size=100, chunk_overlap=20): chunk_idx = kwargs.get('custom_metadata', {}).get('chunk_index', 0) # Override title with chunk info kwargs['title'] = f"Chunk {chunk_idx}" - frame = FrameRecord.create( - record_type="document", - **kwargs - ) + frame = FrameRecord.create(record_type="document", **kwargs) frames.append(frame) - + # Add to dataset dataset.add_many(frames) - + # Verify num_chunks = len(dataset._dataset) assert num_chunks > 1 # Should have multiple chunks - + # Check chunk metadata directly from table results = dataset.scanner().to_table() custom_metadata_col = results["custom_metadata"] - + # Verify chunk metadata for i in range(len(results)): custom_meta_list = custom_metadata_col[i].as_py() @@ -259,32 +246,29 @@ def mock_splitter(texts, chunk_size=100, chunk_overlap=20): class TestExtractionMetadataSchema: """Test that extraction metadata follows ContextFrame schema.""" - + def test_extraction_metadata_types(self, tmp_path): """Test that metadata types are compatible with Lance schema.""" # Create test file test_file = tmp_path / "test.txt" test_file.write_text("Test content") - + # Extract extractor = TextFileExtractor() result = extractor.extract(test_file) - + # Check metadata types from extraction result assert isinstance(result.metadata["filename"], str) assert isinstance(result.metadata["size"], int) assert isinstance(result.metadata["encoding"], str) - + # Convert to FrameRecord to ensure compatibility kwargs = result.to_frame_record_kwargs() - frame = FrameRecord.create( - record_type="document", - **kwargs - ) - + frame = FrameRecord.create(record_type="document", **kwargs) + # Should not raise any schema validation errors assert frame.metadata is not None - + def test_custom_metadata_preserved(self, tmp_path): """Test that custom metadata is preserved through the pipeline.""" # Create JSON file with custom fields @@ -292,84 +276,82 @@ def test_custom_metadata_preserved(self, tmp_path): json_data = { "title": "Test Data", "custom_field": "custom_value", - "nested": { - "field": "value" - }, - "tags": ["tag1", "tag2"] + "nested": {"field": "value"}, + "tags": ["tag1", "tag2"], } json_file.write_text(json.dumps(json_data)) - + # Extract from contextframe.extract import JSONExtractor + extractor = JSONExtractor() result = extractor.extract(json_file) - + # The JSON data should be in metadata assert result.metadata["json_data"] == json_data - + # Create FrameRecord with additional custom metadata kwargs = result.to_frame_record_kwargs() # Add custom fields directly to kwargs (they'll be moved to custom_metadata) kwargs["processing_version"] = "1.0" kwargs["department"] = "Engineering" - - frame = FrameRecord.create( - record_type="document", - **kwargs - ) - + + frame = FrameRecord.create(record_type="document", **kwargs) + # Verify all metadata is preserved custom_meta = frame.metadata.get("custom_metadata", {}) # json_data was converted to JSON string, so parse it back assert json.loads(custom_meta.get("json_data")) == json_data # Check if the extra fields were added at the top level or custom_metadata - assert frame.metadata.get("processing_version") == "1.0" or custom_meta.get("processing_version") == "1.0" - assert frame.metadata.get("department") == "Engineering" or custom_meta.get("department") == "Engineering" + assert ( + frame.metadata.get("processing_version") == "1.0" + or custom_meta.get("processing_version") == "1.0" + ) + assert ( + frame.metadata.get("department") == "Engineering" + or custom_meta.get("department") == "Engineering" + ) class TestExtractionErrorHandling: """Test error handling in extraction to dataset pipeline.""" - + def test_failed_extraction_handling(self, tmp_path): """Test handling of failed extractions in batch processing.""" # Create mix of valid and invalid files valid_file = tmp_path / "valid.txt" valid_file.write_text("Valid content") - + # Non-existent file invalid_file = tmp_path / "nonexistent.txt" - + # Create dataset dataset_path = tmp_path / "error_test.lance" dataset = FrameDataset.create(dataset_path) - + # Batch extract with error handling batch_extractor = BatchExtractor() results = batch_extractor.extract_files( - [valid_file, invalid_file], - skip_errors=True + [valid_file, invalid_file], skip_errors=True ) - + # Process results frames = [] errors = [] - + for result in results: if result.success: kwargs = result.to_frame_record_kwargs() - frame = FrameRecord.create( - record_type="document", - **kwargs - ) + frame = FrameRecord.create(record_type="document", **kwargs) frames.append(frame) else: errors.append(result) - + # Add successful extractions to dataset if frames: dataset.add_many(frames) - + # Verify assert len(dataset._dataset) == 1 # Only valid file assert len(errors) == 1 # One failed extraction - assert "does not exist" in errors[0].error \ No newline at end of file + assert "does not exist" in errors[0].error diff --git a/tests/test_io.py b/tests/test_io.py index 50d5f49..bb01ac1 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -373,13 +373,13 @@ def test_export_with_mermaid_diagram(temp_dataset_with_frameset): assert "```mermaid" in content assert "graph TD" in content assert 'FS["Test FrameSet
FrameSet"]' in content - + # Check that relationships are shown assert 'F1["Document 1"]' in content assert 'F2["Document 2"]' in content assert "FS -->|contains| F1" in content assert "FS -->|contains| F2" in content - + # Check legend assert "**Relationship Types:**" in content assert "- `contains`: Direct inclusion in the frameset" in content diff --git a/tests/test_litellm_provider.py b/tests/test_litellm_provider.py index c07f540..8c0550c 100644 --- a/tests/test_litellm_provider.py +++ b/tests/test_litellm_provider.py @@ -2,14 +2,13 @@ import os import pytest -from unittest.mock import Mock, patch - from contextframe.embed.litellm_provider import LiteLLMProvider +from unittest.mock import Mock, patch class TestLiteLLMProviderEnhanced: """Test enhanced LiteLLM provider features.""" - + def test_provider_detection_comprehensive(self): """Test provider detection for all supported patterns.""" test_cases = [ @@ -27,7 +26,6 @@ def test_provider_detection_comprehensive(self): ("ollama/all-minilm", "ollama"), ("mistral/mistral-embed", "mistral"), ("jina/jina-embeddings-v2-base-en", "jina"), - # Implicit detection from model names ("voyage-01", "voyage"), ("voyage-large-2", "voyage"), @@ -40,38 +38,39 @@ def test_provider_detection_comprehensive(self): ("titan-embed-text-v1", "bedrock"), ("text-embedding-ada-002", "openai"), ("text-embedding-3-large", "openai"), - # Default cases ("unknown-model", "openai"), ("custom-model", "openai"), ] - + for model, expected_provider in test_cases: provider = LiteLLMProvider(model=model) - assert provider._detect_provider() == expected_provider, f"Failed for model: {model}" - + assert provider._detect_provider() == expected_provider, ( + f"Failed for model: {model}" + ) + def test_model_dimensions_lookup(self): """Test model dimension lookup for known models.""" # Test with provider prefix provider = LiteLLMProvider("openai/text-embedding-ada-002") info = provider.get_model_info(skip_dimension_check=True) assert info["dimension"] == 1536 - + # Test without provider prefix provider = LiteLLMProvider("text-embedding-ada-002") info = provider.get_model_info(skip_dimension_check=True) assert info["dimension"] == 1536 - + # Test Cohere model provider = LiteLLMProvider("cohere/embed-english-v3.0") info = provider.get_model_info(skip_dimension_check=True) assert info["dimension"] == 1024 - + # Test unknown model provider = LiteLLMProvider("unknown/custom-model") info = provider.get_model_info(skip_dimension_check=True) assert info["dimension"] is None - + def test_api_key_environment_mapping(self): """Test API key environment variable mapping.""" test_cases = [ @@ -89,39 +88,41 @@ def test_api_key_environment_mapping(self): ("ai21", "test-key", "AI21_API_KEY"), ("nlp_cloud", "test-key", "NLP_CLOUD_API_KEY"), ] - + for provider_name, api_key, env_var in test_cases: # Clear environment if env_var in os.environ: del os.environ[env_var] - + # Create provider with explicit provider provider = LiteLLMProvider(f"{provider_name}/test-model", api_key=api_key) provider._set_api_key() - + assert os.environ.get(env_var) == api_key - + # Clean up if env_var in os.environ: del os.environ[env_var] - + def test_bedrock_credentials_parsing(self): """Test AWS Bedrock credential parsing.""" # Clear AWS environment variables for var in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]: if var in os.environ: del os.environ[var] - - provider = LiteLLMProvider("bedrock/amazon.titan-embed", api_key="access123:secret456") + + provider = LiteLLMProvider( + "bedrock/amazon.titan-embed", api_key="access123:secret456" + ) provider._set_api_key() - + assert os.environ["AWS_ACCESS_KEY_ID"] == "access123" assert os.environ["AWS_SECRET_ACCESS_KEY"] == "secret456" - + # Clean up del os.environ["AWS_ACCESS_KEY_ID"] del os.environ["AWS_SECRET_ACCESS_KEY"] - + def test_max_batch_size_by_provider(self): """Test max batch size limits by provider.""" test_cases = [ @@ -139,65 +140,64 @@ def test_max_batch_size_by_provider(self): ("ollama/model", 1), ("unknown/model", 100), # Default ] - + for model, expected_batch_size in test_cases: provider = LiteLLMProvider(model=model) assert provider.max_batch_size == expected_batch_size - + def test_custom_model_support(self): """Test support for custom models not in MODEL_DIMENSIONS.""" # Test ModernBERT example provider = LiteLLMProvider("huggingface/answerdotai/ModernBERT-base") assert provider.model == "huggingface/answerdotai/ModernBERT-base" assert provider._detect_provider() == "huggingface" - + # Test ColBERT example provider = LiteLLMProvider( - "huggingface/colbert-ir/colbertv2.0", - api_base="http://localhost:8000/v1" + "huggingface/colbert-ir/colbertv2.0", api_base="http://localhost:8000/v1" ) assert provider.model == "huggingface/colbert-ir/colbertv2.0" assert provider.api_base == "http://localhost:8000/v1" - + # Test completely custom model provider = LiteLLMProvider( "custom/my-special-model", api_base="http://my-server/v1", - custom_llm_provider="openai" + custom_llm_provider="openai", ) assert provider.model == "custom/my-special-model" assert provider.custom_llm_provider == "openai" - + @patch('litellm.embedding') def test_embed_with_litellm_mock(self, mock_embedding): """Test embed method with mocked litellm.""" # Mock the litellm module mock_litellm = Mock() mock_litellm.embedding = mock_embedding - + # Mock response mock_response = Mock() mock_response.data = [ {"embedding": [0.1, 0.2, 0.3]}, - {"embedding": [0.4, 0.5, 0.6]} + {"embedding": [0.4, 0.5, 0.6]}, ] mock_response.model = "text-embedding-ada-002" mock_response.usage = Mock(prompt_tokens=10, total_tokens=10) mock_embedding.return_value = mock_response - + # Create provider and patch litellm provider = LiteLLMProvider("text-embedding-ada-002", api_key="test-key") provider._litellm = mock_litellm - + # Test embedding result = provider.embed(["Hello", "World"]) - + # Verify call mock_embedding.assert_called_once() call_args = mock_embedding.call_args[1] assert call_args["model"] == "text-embedding-ada-002" assert call_args["input"] == ["Hello", "World"] - + # Verify result assert len(result.embeddings) == 2 assert result.embeddings[0] == [0.1, 0.2, 0.3] @@ -205,32 +205,32 @@ def test_embed_with_litellm_mock(self, mock_embedding): assert result.model == "text-embedding-ada-002" assert result.dimension == 3 assert result.usage["prompt_tokens"] == 10 - + @patch('litellm.embedding') def test_dynamic_dimension_detection(self, mock_embedding): """Test automatic dimension detection for unknown models.""" # Mock the litellm module mock_litellm = Mock() mock_litellm.embedding = mock_embedding - + # Mock response for test embedding mock_response = Mock() mock_response.data = [{"embedding": [0.1] * 768}] # 768 dimensions mock_response.model = "custom/unknown-model" mock_embedding.return_value = mock_response - + # Create provider with unknown model provider = LiteLLMProvider("custom/unknown-model") provider._litellm = mock_litellm - + # Get model info (should trigger test embedding) info = provider.get_model_info() - + # Should have made a test call mock_embedding.assert_called_once() assert info["dimension"] == 768 assert info["model"] == "custom/unknown-model" - + def test_initialization_parameters(self): """Test all initialization parameters are stored correctly.""" provider = LiteLLMProvider( @@ -243,9 +243,9 @@ def test_initialization_parameters(self): organization="test-org", custom_llm_provider="custom", input_type="search_document", - encoding_format="base64" + encoding_format="base64", ) - + assert provider.model == "openai/text-embedding-3-large" assert provider.api_key == "test-key" assert provider.api_base == "https://api.example.com" @@ -256,50 +256,50 @@ def test_initialization_parameters(self): assert provider.custom_llm_provider == "custom" assert provider.input_type == "search_document" assert provider.encoding_format == "base64" - + def test_error_handling(self): """Test error handling in embed method.""" provider = LiteLLMProvider("test-model") - + # Mock litellm to raise an exception mock_litellm = Mock() mock_litellm.embedding.side_effect = Exception("API Error") provider._litellm = mock_litellm - + with pytest.raises(RuntimeError) as exc_info: provider.embed("Test text") - + assert "Failed to generate embeddings with test-model" in str(exc_info.value) assert "API Error" in str(exc_info.value) class TestLiteLLMProviderIntegration: """Integration tests requiring actual LiteLLM library.""" - + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY"), - reason="Requires OPENAI_API_KEY for integration test" + reason="Requires OPENAI_API_KEY for integration test", ) def test_real_openai_embedding(self): """Test with real OpenAI API (requires API key).""" provider = LiteLLMProvider("text-embedding-ada-002") result = provider.embed("Hello, world!") - + assert len(result.embeddings) == 1 assert len(result.embeddings[0]) == 1536 assert result.model == "text-embedding-ada-002" assert result.dimension == 1536 assert result.usage is not None - + def test_import_error_handling(self): """Test handling when litellm is not installed.""" provider = LiteLLMProvider("test-model") - + # Force import error provider._litellm = None with patch('builtins.__import__', side_effect=ImportError): with pytest.raises(ImportError) as exc_info: _ = provider.litellm - + assert "LiteLLM is required" in str(exc_info.value) - assert "pip install 'contextframe[extract]'" in str(exc_info.value) \ No newline at end of file + assert "pip install 'contextframe[extract]'" in str(exc_info.value) diff --git a/tests/test_templates.py b/tests/test_templates.py index 46b7190..8a2395f 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,65 +1,68 @@ """Tests for Context Templates module.""" import pytest -from pathlib import Path -import tempfile import shutil - +import tempfile +from contextframe import FrameDataset from contextframe.templates import ( + BusinessTemplate, ContextTemplate, - TemplateResult, - SoftwareProjectTemplate, ResearchTemplate, - BusinessTemplate, + SoftwareProjectTemplate, + TemplateResult, get_template, list_templates, ) -from contextframe.templates.base import FileMapping, CollectionDefinition, EnrichmentSuggestion +from contextframe.templates.base import ( + CollectionDefinition, + EnrichmentSuggestion, + FileMapping, +) from contextframe.templates.registry import TemplateRegistry, find_template_for_path -from contextframe import FrameDataset +from pathlib import Path class TestTemplateRegistry: """Test the template registry functionality.""" - + def test_builtin_templates_registered(self): """Test that built-in templates are automatically registered.""" templates = list_templates() template_names = [t["name"] for t in templates] - + assert "software_project" in template_names - assert "research" in template_names + assert "research" in template_names assert "business" in template_names assert len(templates) >= 3 - + def test_get_template(self): """Test retrieving templates by name.""" software_template = get_template("software_project") assert isinstance(software_template, SoftwareProjectTemplate) assert software_template.name == "software_project" - + research_template = get_template("research") assert isinstance(research_template, ResearchTemplate) - + business_template = get_template("business") assert isinstance(business_template, BusinessTemplate) - + def test_get_nonexistent_template(self): """Test error handling for non-existent template.""" with pytest.raises(KeyError) as exc_info: get_template("nonexistent") assert "not found" in str(exc_info.value) - + def test_template_info(self): """Test template metadata.""" templates = list_templates() - + for template_info in templates: assert "name" in template_info assert "description" in template_info assert "class" in template_info assert isinstance(template_info["description"], str) - + def test_find_template_for_path(self, tmp_path): """Test automatic template detection.""" # Create software project structure @@ -68,30 +71,30 @@ def test_find_template_for_path(self, tmp_path): (software_dir / "src").mkdir() (software_dir / "tests").mkdir() (software_dir / "setup.py").touch() - + assert find_template_for_path(str(software_dir)) == "software_project" - + # Create research structure research_dir = tmp_path / "research_project" research_dir.mkdir() (research_dir / "papers").mkdir() (research_dir / "data").mkdir() (research_dir / "references.bib").touch() - + assert find_template_for_path(str(research_dir)) == "research" - + # Create business structure business_dir = tmp_path / "business_docs" business_dir.mkdir() (business_dir / "meetings").mkdir() (business_dir / "decisions").mkdir() - + assert find_template_for_path(str(business_dir)) == "business" class TestSoftwareProjectTemplate: """Test the software project template.""" - + @pytest.fixture def software_project(self, tmp_path): """Create a sample software project structure.""" @@ -99,31 +102,35 @@ def software_project(self, tmp_path): (tmp_path / "src" / "myapp").mkdir(parents=True) (tmp_path / "tests").mkdir() (tmp_path / "docs").mkdir() - + # Create files (tmp_path / "README.md").write_text("# My Project\n\nA sample project.") - (tmp_path / "setup.py").write_text("from setuptools import setup\nsetup(name='myapp')") + (tmp_path / "setup.py").write_text( + "from setuptools import setup\nsetup(name='myapp')" + ) (tmp_path / "requirements.txt").write_text("pytest\nnumpy") - + # Source files (tmp_path / "src" / "myapp" / "__init__.py").write_text("__version__ = '0.1.0'") (tmp_path / "src" / "myapp" / "core.py").write_text("def main():\n pass") (tmp_path / "src" / "myapp" / "utils.py").write_text("def helper():\n pass") - + # Test files (tmp_path / "tests" / "test_core.py").write_text("def test_main():\n pass") - (tmp_path / "tests" / "test_utils.py").write_text("def test_helper():\n pass") - + (tmp_path / "tests" / "test_utils.py").write_text( + "def test_helper():\n pass" + ) + # Documentation (tmp_path / "docs" / "guide.md").write_text("# User Guide") - + return tmp_path - + def test_scan_software_project(self, software_project): """Test scanning a software project.""" template = SoftwareProjectTemplate() mappings = template.scan(software_project) - + # Check we found all expected files paths = [str(m.path.name) for m in mappings] assert "README.md" in paths @@ -132,25 +139,25 @@ def test_scan_software_project(self, software_project): assert "core.py" in paths assert "test_core.py" in paths assert "guide.md" in paths - + # Check categorization readme = next(m for m in mappings if m.path.name == "README.md") assert "readme" in readme.tags assert "overview" in readme.tags - + core_file = next(m for m in mappings if m.path.name == "core.py") assert "source" in core_file.tags assert "python" in core_file.tags - + test_file = next(m for m in mappings if m.path.name == "test_core.py") assert "test" in test_file.tags - + def test_define_collections(self, software_project): """Test collection definition for software projects.""" template = SoftwareProjectTemplate() mappings = template.scan(software_project) collections = template.define_collections(mappings) - + # Check collections created coll_names = [c.name for c in collections] assert "project" in coll_names @@ -158,28 +165,28 @@ def test_define_collections(self, software_project): assert "tests" in coll_names # Check we have at least the base collections assert len(collections) >= 3 - + def test_discover_relationships(self, software_project, tmp_path): """Test relationship discovery.""" template = SoftwareProjectTemplate() mappings = template.scan(software_project) - + # Create a mock dataset dataset = FrameDataset.create(tmp_path / "test.lance") - + relationships = template.discover_relationships(mappings, dataset) - + # Should find test->source relationships test_rels = [r for r in relationships if r["type"] == "tests"] # With test_core.py and core.py, we should find a relationship assert isinstance(relationships, list) - + def test_suggest_enrichments(self, software_project): """Test enrichment suggestions.""" template = SoftwareProjectTemplate() mappings = template.scan(software_project) suggestions = template.suggest_enrichments(mappings) - + # Should have suggestions for different file types patterns = [s.file_pattern for s in suggestions] assert any("*.py" in p for p in patterns) @@ -189,7 +196,7 @@ def test_suggest_enrichments(self, software_project): class TestResearchTemplate: """Test the research template.""" - + @pytest.fixture def research_project(self, tmp_path): """Create a sample research project structure.""" @@ -197,21 +204,21 @@ def research_project(self, tmp_path): (tmp_path / "papers").mkdir() (tmp_path / "data").mkdir() (tmp_path / "notebooks").mkdir() - + # Create files (tmp_path / "papers" / "paper1.pdf").write_bytes(b"PDF content") (tmp_path / "papers" / "draft_2024.md").write_text("# Research Paper") (tmp_path / "data" / "results.csv").write_text("id,value\n1,100") (tmp_path / "notebooks" / "analysis.ipynb").write_text('{"cells": []}') (tmp_path / "references.bib").write_text("@article{...}") - + return tmp_path - + def test_scan_research_project(self, research_project): """Test scanning a research project.""" template = ResearchTemplate() mappings = template.scan(research_project) - + # Check we found expected files paths = [str(m.path.name) for m in mappings] assert "paper1.pdf" in paths @@ -219,24 +226,24 @@ def test_scan_research_project(self, research_project): assert "results.csv" in paths assert "analysis.ipynb" in paths assert "references.bib" in paths - + # Check categorization paper = next(m for m in mappings if m.path.name == "paper1.pdf") assert "paper" in paper.tags assert "research" in paper.tags - + notebook = next(m for m in mappings if m.path.name == "analysis.ipynb") assert "notebook" in notebook.tags - + bib = next(m for m in mappings if m.path.name == "references.bib") assert "bibliography" in bib.tags - + def test_research_collections(self, research_project): """Test collection definition for research projects.""" template = ResearchTemplate() mappings = template.scan(research_project) collections = template.define_collections(mappings) - + coll_names = [c.name for c in collections] assert "papers" in coll_names assert "data" in coll_names @@ -247,7 +254,7 @@ def test_research_collections(self, research_project): class TestBusinessTemplate: """Test the business template.""" - + @pytest.fixture def business_project(self, tmp_path): """Create a sample business project structure.""" @@ -255,46 +262,48 @@ def business_project(self, tmp_path): (tmp_path / "meetings" / "weekly").mkdir(parents=True) (tmp_path / "decisions").mkdir() (tmp_path / "reports").mkdir() - + # Create files - (tmp_path / "meetings" / "weekly" / "2024-01-15-standup.md").write_text("# Standup") + (tmp_path / "meetings" / "weekly" / "2024-01-15-standup.md").write_text( + "# Standup" + ) (tmp_path / "decisions" / "ADR-001-architecture.md").write_text("# Decision") (tmp_path / "reports" / "Q1-2024-summary.md").write_text("# Report") (tmp_path / "project-plan.md").write_text("# Project Plan") - + return tmp_path - + def test_scan_business_project(self, business_project): """Test scanning a business project.""" template = BusinessTemplate() mappings = template.scan(business_project) - + # Check we found expected files paths = [str(m.path.name) for m in mappings] assert "2024-01-15-standup.md" in paths assert "ADR-001-architecture.md" in paths assert "Q1-2024-summary.md" in paths - + # Check categorization meeting = next(m for m in mappings if "standup" in m.path.name) assert "meeting" in meeting.tags assert "standup" in meeting.custom_metadata.get("meeting_type", "") - + decision = next(m for m in mappings if "ADR" in m.path.name) assert "decision" in decision.tags - + def test_business_date_extraction(self, business_project): """Test date extraction from filenames.""" template = BusinessTemplate() mappings = template.scan(business_project) - + meeting = next(m for m in mappings if "standup" in m.path.name) assert meeting.custom_metadata.get("meeting_date") == "2024-01-15" class TestTemplateApplication: """Test end-to-end template application.""" - + def test_apply_template_dry_run(self, tmp_path): """Test dry run of template application.""" # Create simple structure @@ -302,35 +311,35 @@ def test_apply_template_dry_run(self, tmp_path): source_dir.mkdir() (source_dir / "README.md").write_text("# Test") (source_dir / "main.py").write_text("print('hello')") - + # Create dataset dataset_path = tmp_path / "test.lance" dataset = FrameDataset.create(dataset_path) - + # Apply template in dry run mode template = SoftwareProjectTemplate() result = template.apply(source_dir, dataset, dry_run=True) - + assert result.frames_created == 0 assert result.collections_created == 0 assert len(result.warnings) > 0 assert "DRY RUN" in result.warnings[0] - + def test_template_result_tracking(self): """Test TemplateResult tracking.""" result = TemplateResult() - + # Test initial state assert result.frames_created == 0 assert result.collections_created == 0 assert result.relationships_created == 0 assert len(result.errors) == 0 - + # Test tracking result.frames_created += 5 result.collections_created += 2 result.errors.append("Test error") - + assert result.frames_created == 5 assert result.collections_created == 2 assert len(result.errors) == 1 @@ -338,34 +347,34 @@ def test_template_result_tracking(self): class TestCustomTemplate: """Test creating custom templates.""" - + def test_custom_template_implementation(self): """Test implementing a custom template.""" - + class CustomTemplate(ContextTemplate): def __init__(self): super().__init__("custom", "A custom template") - + def scan(self, source_path): return [] - + def define_collections(self, file_mappings): return [] - + def discover_relationships(self, file_mappings, dataset): return [] - + def suggest_enrichments(self, file_mappings): return [] - + # Create and test custom template template = CustomTemplate() assert template.name == "custom" assert template.description == "A custom template" - + # Test it can be registered registry = TemplateRegistry() registry.register("my_custom", CustomTemplate) - + retrieved = registry.get("my_custom") - assert isinstance(retrieved, CustomTemplate) \ No newline at end of file + assert isinstance(retrieved, CustomTemplate)