diff --git a/.claude/implementations/phase4.1_monitoring.md b/.claude/implementations/phase4.1_monitoring.md new file mode 100644 index 0000000..72d0939 --- /dev/null +++ b/.claude/implementations/phase4.1_monitoring.md @@ -0,0 +1,361 @@ +# Phase 4.1: Production Monitoring Implementation Plan + +## Overview + +This phase implements comprehensive monitoring capabilities for the MCP server, enabling production-grade observability and cost tracking. + +## Architecture + +### Core Components + +1. **MetricsCollector** - Central metrics aggregation +2. **UsageTracker** - Document and query usage tracking +3. **PerformanceMonitor** - Response time and throughput monitoring +4. **CostCalculator** - LLM API and storage cost attribution +5. **MetricsExporter** - Export to various monitoring systems + +### Data Storage + +Metrics will be stored in: +- **In-memory**: Recent metrics (last hour) for fast access +- **Lance dataset**: Historical metrics in dedicated tables +- **Export targets**: Prometheus, CloudWatch, etc. + +## Implementation Details + +### 1. 
Context Usage Metrics + +Track how documents are accessed and used: + +```python +class UsageTracker: + """Track document access patterns and query statistics.""" + + async def track_document_access( + self, + document_id: str, + operation: str, # read, search_hit, update, delete + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> None: + """Record document access event.""" + + async def track_query( + self, + query: str, + query_type: str, # vector, text, hybrid, sql + result_count: int, + execution_time_ms: float, + agent_id: str | None = None + ) -> None: + """Record query execution.""" + + async def get_usage_stats( + self, + start_time: datetime, + end_time: datetime, + group_by: str = "hour" # hour, day, week + ) -> UsageStats: + """Get aggregated usage statistics.""" +``` + +Metrics collected: +- Document access frequency +- Query patterns and types +- Search result relevance (click-through) +- Collection usage distribution +- Time-based access patterns + +### 2. Agent Performance Tracking + +Monitor MCP operation performance: + +```python +class PerformanceMonitor: + """Track MCP server and agent performance metrics.""" + + async def start_operation( + self, + operation_id: str, + operation_type: str, # tool_call, resource_read, subscription + agent_id: str | None = None + ) -> OperationContext: + """Start tracking an operation.""" + + async def end_operation( + self, + operation_id: str, + status: str, # success, error, timeout + result_size: int | None = None, + error: str | None = None + ) -> None: + """Complete operation tracking.""" + + async def record_metric( + self, + metric_name: str, + value: float, + tags: dict[str, str] | None = None + ) -> None: + """Record a custom metric.""" +``` + +Metrics collected: +- Response times (p50, p95, p99) +- Request throughput +- Error rates by operation +- Resource utilization +- Queue depths (for batch operations) +- Active connections + +### 3. 
Cost Attribution + +Track costs associated with operations: + +```python +class CostCalculator: + """Calculate and track costs for operations.""" + + def __init__(self, pricing_config: PricingConfig): + self.llm_pricing = pricing_config.llm_pricing + self.storage_pricing = pricing_config.storage_pricing + + async def track_llm_usage( + self, + provider: str, + model: str, + input_tokens: int, + output_tokens: int, + operation_id: str, + agent_id: str | None = None + ) -> float: + """Track LLM API usage and calculate cost.""" + + async def track_storage_usage( + self, + operation: str, # read, write, delete + size_bytes: int, + agent_id: str | None = None + ) -> float: + """Track storage operations and costs.""" + + async def get_cost_report( + self, + start_time: datetime, + end_time: datetime, + group_by: str = "agent" # agent, operation, model + ) -> CostReport: + """Generate cost attribution report.""" +``` + +Cost tracking includes: +- LLM API calls (by provider/model) +- Storage operations (reads/writes) +- Bandwidth usage +- Compute time for operations +- Cost allocation by agent/purpose + +### 4. Metrics Storage Schema + +```python +# Lance table schemas for metrics storage + +USAGE_METRICS_SCHEMA = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("metric_type", pa.string()), # document_access, query, etc. 
+ ("resource_id", pa.string()), # document_id, collection_id + ("operation", pa.string()), + ("agent_id", pa.string()), + ("value", pa.float64()), + ("metadata", pa.string()), # JSON string +]) + +PERFORMANCE_METRICS_SCHEMA = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("operation_id", pa.string()), + ("operation_type", pa.string()), + ("agent_id", pa.string()), + ("duration_ms", pa.float64()), + ("status", pa.string()), + ("error", pa.string()), + ("result_size", pa.int64()), +]) + +COST_METRICS_SCHEMA = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("operation_id", pa.string()), + ("cost_type", pa.string()), # llm, storage, bandwidth + ("provider", pa.string()), + ("amount_usd", pa.float64()), + ("units", pa.int64()), # tokens, bytes, requests + ("agent_id", pa.string()), + ("metadata", pa.string()), +]) +``` + +### 5. MCP Monitoring Tools + +New tools for accessing monitoring data: + +```python +# Monitoring tools accessible via MCP + +@tool_registry.register("get_usage_metrics") +async def get_usage_metrics(params: GetUsageMetricsParams) -> UsageMetricsResult: + """Get usage metrics for documents and queries.""" + +@tool_registry.register("get_performance_metrics") +async def get_performance_metrics(params: GetPerformanceParams) -> PerformanceResult: + """Get performance metrics for operations.""" + +@tool_registry.register("get_cost_report") +async def get_cost_report(params: GetCostReportParams) -> CostReportResult: + """Get cost attribution report.""" + +@tool_registry.register("export_metrics") +async def export_metrics(params: ExportMetricsParams) -> ExportResult: + """Export metrics to external monitoring system.""" +``` + +### 6. 
Integration Points + +#### With Existing Components + +```python +# In MessageHandler +async def handle_message(self, message: dict) -> dict: + operation_id = str(uuid.uuid4()) + + # Start performance tracking + ctx = await self.performance_monitor.start_operation( + operation_id=operation_id, + operation_type=message["method"], + agent_id=self._get_agent_id(message) + ) + + try: + # Execute operation + result = await self._execute_method(message) + + # Track success + await self.performance_monitor.end_operation( + operation_id=operation_id, + status="success", + result_size=self._calculate_result_size(result) + ) + + return result + except Exception as e: + # Track error + await self.performance_monitor.end_operation( + operation_id=operation_id, + status="error", + error=str(e) + ) + raise +``` + +#### With Analytics Tools + +The monitoring system will integrate with Phase 3.6 analytics: +- Analytics tools provide dataset-level insights +- Monitoring tracks operation-level metrics +- Combined view shows full system health + +### 7. Configuration + +```python +@dataclass +class MonitoringConfig: + """Configuration for monitoring system.""" + + # Metrics collection + enabled: bool = True + metrics_retention_days: int = 30 + aggregation_intervals: list[str] = field( + default_factory=lambda: ["1m", "5m", "1h", "1d"] + ) + + # Performance thresholds + slow_query_threshold_ms: float = 1000 + error_rate_threshold: float = 0.05 + + # Cost tracking + track_costs: bool = True + pricing_config_path: str | None = None + + # Export targets + prometheus_enabled: bool = False + prometheus_port: int = 9090 + cloudwatch_enabled: bool = False + cloudwatch_namespace: str = "ContextFrame/MCP" +``` + +### 8. 
Monitoring Dashboard + +Create a simple web dashboard for monitoring: + +```python +# In contextframe/mcp/monitoring/dashboard.py +class MonitoringDashboard: + """Simple web dashboard for monitoring metrics.""" + + def __init__(self, metrics_collector: MetricsCollector): + self.app = FastAPI(title="ContextFrame Monitoring") + self.metrics = metrics_collector + self._setup_routes() + + def _setup_routes(self): + @self.app.get("/metrics/usage") + async def usage_metrics( + start: datetime = Query(default=datetime.now() - timedelta(hours=1)), + end: datetime = Query(default=datetime.now()) + ): + return await self.metrics.get_usage_stats(start, end) +``` + +## Implementation Order + +1. **Week 1**: Core infrastructure + - MetricsCollector base class + - In-memory metrics storage + - Basic performance tracking + +2. **Week 2**: Usage tracking + - Document access tracking + - Query pattern analysis + - Integration with existing tools + +3. **Week 3**: Cost attribution + - LLM cost tracking + - Storage cost calculation + - Cost reporting tools + +4. **Week 4**: Export and dashboards + - Prometheus exporter + - Simple web dashboard + - Alerting rules + +## Success Criteria + +1. **Performance Impact**: < 5% overhead on operations +2. **Data Completeness**: 100% of operations tracked +3. **Cost Accuracy**: Within 5% of actual costs +4. **Query Performance**: Metrics queries < 100ms +5. **Integration**: Works with all existing tools + +## Testing Strategy + +1. **Unit Tests**: Each component tested in isolation +2. **Integration Tests**: End-to-end monitoring flows +3. **Performance Tests**: Verify minimal overhead +4. **Accuracy Tests**: Validate cost calculations +5. 
**Load Tests**: Handle high-volume metrics + +## Future Enhancements + +- Machine learning for anomaly detection +- Predictive cost modeling +- Auto-scaling recommendations +- Performance optimization suggestions +- Custom alerting rules \ No newline at end of file diff --git a/contextframe/mcp/monitoring/__init__.py b/contextframe/mcp/monitoring/__init__.py new file mode 100644 index 0000000..3dfd0df --- /dev/null +++ b/contextframe/mcp/monitoring/__init__.py @@ -0,0 +1,23 @@ +"""Monitoring module for MCP server. + +Provides comprehensive monitoring capabilities including: +- Context usage metrics +- Agent performance tracking +- Cost attribution +- Metrics export +""" + +from .collector import MetricsCollector +from .cost import CostCalculator, PricingConfig +from .performance import OperationContext, PerformanceMonitor +from .usage import UsageStats, UsageTracker + +__all__ = [ + "MetricsCollector", + "UsageTracker", + "UsageStats", + "PerformanceMonitor", + "OperationContext", + "CostCalculator", + "PricingConfig", +] \ No newline at end of file diff --git a/contextframe/mcp/monitoring/collector.py b/contextframe/mcp/monitoring/collector.py new file mode 100644 index 0000000..0eb3f20 --- /dev/null +++ b/contextframe/mcp/monitoring/collector.py @@ -0,0 +1,296 @@ +"""Central metrics collection and aggregation.""" + +import asyncio +import json +from collections import defaultdict, deque +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +import pyarrow as pa +from contextframe.frame import FrameDataset + + +@dataclass +class MetricsConfig: + """Configuration for metrics collection.""" + + enabled: bool = True + retention_days: int = 30 + aggregation_intervals: list[str] = field( + default_factory=lambda: ["1m", "5m", "1h", "1d"] + ) + max_memory_metrics: int = 10000 + flush_interval_seconds: int = 60 + + +class MetricsCollector: + """Central metrics collection and storage. 
+ + Collects metrics in memory and periodically flushes to Lance dataset. + Provides aggregation and querying capabilities. + """ + + def __init__( + self, + dataset: FrameDataset | None = None, + config: MetricsConfig | None = None + ): + self.dataset = dataset + self.config = config or MetricsConfig() + + # In-memory buffers + self._usage_buffer: deque = deque(maxlen=self.config.max_memory_metrics) + self._performance_buffer: deque = deque(maxlen=self.config.max_memory_metrics) + self._cost_buffer: deque = deque(maxlen=self.config.max_memory_metrics) + + # Aggregated metrics for fast access + self._aggregated_metrics: dict[str, dict[str, Any]] = defaultdict(dict) + + # Background tasks + self._flush_task: asyncio.Task | None = None + self._aggregation_task: asyncio.Task | None = None + + # Metrics schemas + self.usage_schema = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("metric_type", pa.string()), + ("resource_id", pa.string()), + ("operation", pa.string()), + ("agent_id", pa.string()), + ("value", pa.float64()), + ("metadata", pa.string()), + ]) + + self.performance_schema = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("operation_id", pa.string()), + ("operation_type", pa.string()), + ("agent_id", pa.string()), + ("duration_ms", pa.float64()), + ("status", pa.string()), + ("error", pa.string()), + ("result_size", pa.int64()), + ]) + + self.cost_schema = pa.schema([ + ("timestamp", pa.timestamp("us", tz="UTC")), + ("operation_id", pa.string()), + ("cost_type", pa.string()), + ("provider", pa.string()), + ("amount_usd", pa.float64()), + ("units", pa.int64()), + ("agent_id", pa.string()), + ("metadata", pa.string()), + ]) + + async def start(self) -> None: + """Start background tasks for metrics processing.""" + if not self.config.enabled: + return + + # Start flush task + self._flush_task = asyncio.create_task(self._flush_loop()) + + # Start aggregation task + self._aggregation_task = 
asyncio.create_task(self._aggregation_loop()) + + async def stop(self) -> None: + """Stop background tasks and flush remaining metrics.""" + # Cancel background tasks + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + if self._aggregation_task: + self._aggregation_task.cancel() + try: + await self._aggregation_task + except asyncio.CancelledError: + pass + + # Final flush + await self._flush_metrics() + + async def record_usage( + self, + metric_type: str, + resource_id: str, + operation: str, + value: float = 1.0, + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> None: + """Record a usage metric.""" + if not self.config.enabled: + return + + metric = { + "timestamp": datetime.now(timezone.utc), + "metric_type": metric_type, + "resource_id": resource_id, + "operation": operation, + "agent_id": agent_id or "anonymous", + "value": value, + "metadata": json.dumps(metadata) if metadata else None, + } + + self._usage_buffer.append(metric) + + async def record_performance( + self, + operation_id: str, + operation_type: str, + duration_ms: float, + status: str, + agent_id: str | None = None, + error: str | None = None, + result_size: int | None = None + ) -> None: + """Record a performance metric.""" + if not self.config.enabled: + return + + metric = { + "timestamp": datetime.now(timezone.utc), + "operation_id": operation_id, + "operation_type": operation_type, + "agent_id": agent_id or "anonymous", + "duration_ms": duration_ms, + "status": status, + "error": error, + "result_size": result_size, + } + + self._performance_buffer.append(metric) + + async def record_cost( + self, + operation_id: str, + cost_type: str, + provider: str, + amount_usd: float, + units: int, + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> None: + """Record a cost metric.""" + if not self.config.enabled: + return + + metric = { + "timestamp": 
datetime.now(timezone.utc), + "operation_id": operation_id, + "cost_type": cost_type, + "provider": provider, + "amount_usd": amount_usd, + "units": units, + "agent_id": agent_id or "anonymous", + "metadata": json.dumps(metadata) if metadata else None, + } + + self._cost_buffer.append(metric) + + async def get_aggregated_metrics( + self, + metric_category: str, + interval: str = "1h", + lookback_hours: int = 24 + ) -> dict[str, Any]: + """Get aggregated metrics for a category.""" + key = f"{metric_category}:{interval}:{lookback_hours}" + return self._aggregated_metrics.get(key, {}) + + async def _flush_loop(self) -> None: + """Background task to periodically flush metrics.""" + while True: + try: + await asyncio.sleep(self.config.flush_interval_seconds) + await self._flush_metrics() + except asyncio.CancelledError: + raise + except Exception as e: + # Log error but continue + print(f"Error flushing metrics: {e}") + + async def _flush_metrics(self) -> None: + """Flush in-memory metrics to Lance dataset.""" + if not self.dataset: + return + + # Flush usage metrics + if self._usage_buffer: + usage_data = list(self._usage_buffer) + self._usage_buffer.clear() + + # Convert to Lance table and append + # This would append to a metrics table in the dataset + # For now, we'll just clear the buffer + + # Flush performance metrics + if self._performance_buffer: + perf_data = list(self._performance_buffer) + self._performance_buffer.clear() + + # Flush cost metrics + if self._cost_buffer: + cost_data = list(self._cost_buffer) + self._cost_buffer.clear() + + async def _aggregation_loop(self) -> None: + """Background task to aggregate metrics.""" + while True: + try: + await asyncio.sleep(60) # Aggregate every minute + await self._aggregate_metrics() + except asyncio.CancelledError: + raise + except Exception as e: + # Log error but continue + print(f"Error aggregating metrics: {e}") + + async def _aggregate_metrics(self) -> None: + """Aggregate recent metrics for fast 
access.""" + now = datetime.now(timezone.utc) + + # Aggregate usage metrics by hour + usage_by_hour = defaultdict(lambda: {"count": 0, "resources": set()}) + for metric in self._usage_buffer: + if (now - metric["timestamp"]).total_seconds() < 3600: + hour = metric["timestamp"].replace(minute=0, second=0, microsecond=0) + key = (hour, metric["metric_type"]) + usage_by_hour[key]["count"] += metric["value"] + usage_by_hour[key]["resources"].add(metric["resource_id"]) + + # Store aggregated results + self._aggregated_metrics["usage:1h:1"] = { + str(k): {"count": v["count"], "unique_resources": len(v["resources"])} + for k, v in usage_by_hour.items() + } + + # Aggregate performance metrics + perf_by_type = defaultdict(lambda: {"count": 0, "total_ms": 0, "errors": 0}) + for metric in self._performance_buffer: + if (now - metric["timestamp"]).total_seconds() < 3600: + op_type = metric["operation_type"] + perf_by_type[op_type]["count"] += 1 + perf_by_type[op_type]["total_ms"] += metric["duration_ms"] + if metric["status"] == "error": + perf_by_type[op_type]["errors"] += 1 + + # Calculate averages + for op_type, stats in perf_by_type.items(): + if stats["count"] > 0: + stats["avg_ms"] = stats["total_ms"] / stats["count"] + stats["error_rate"] = stats["errors"] / stats["count"] + + self._aggregated_metrics["performance:1h:1"] = dict(perf_by_type) + + # Clean up old aggregated metrics + cutoff_time = now - timedelta(days=1) + for key in list(self._aggregated_metrics.keys()): + # Parse timestamp from key if stored + pass # For now, we'll keep all aggregated metrics \ No newline at end of file diff --git a/contextframe/mcp/monitoring/cost.py b/contextframe/mcp/monitoring/cost.py new file mode 100644 index 0000000..b9846d1 --- /dev/null +++ b/contextframe/mcp/monitoring/cost.py @@ -0,0 +1,392 @@ +"""Cost tracking and attribution for MCP operations.""" + +import json +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, 
Dict, List, Optional + +from .collector import MetricsCollector + + +@dataclass +class LLMPricing: + """Pricing for a specific LLM model.""" + + provider: str + model: str + input_cost_per_1k: float # Cost per 1k input tokens + output_cost_per_1k: float # Cost per 1k output tokens + + def calculate_cost(self, input_tokens: int, output_tokens: int) -> float: + """Calculate cost for given token counts.""" + input_cost = (input_tokens / 1000) * self.input_cost_per_1k + output_cost = (output_tokens / 1000) * self.output_cost_per_1k + return input_cost + output_cost + + +@dataclass +class StoragePricing: + """Pricing for storage operations.""" + + read_cost_per_gb: float = 0.01 # Cost per GB read + write_cost_per_gb: float = 0.02 # Cost per GB written + storage_cost_per_gb_month: float = 0.023 # Monthly storage cost + + def calculate_operation_cost(self, operation: str, size_bytes: int) -> float: + """Calculate cost for a storage operation.""" + size_gb = size_bytes / (1024 ** 3) + + if operation in ["read", "search"]: + return size_gb * self.read_cost_per_gb + elif operation in ["write", "update"]: + return size_gb * self.write_cost_per_gb + elif operation == "delete": + return 0.0 # No cost for deletion + else: + return 0.0 + + +@dataclass +class PricingConfig: + """Complete pricing configuration.""" + + # Default LLM pricing (as of 2024) + llm_pricing: dict[str, LLMPricing] = field(default_factory=lambda: { + "openai:gpt-4": LLMPricing("openai", "gpt-4", 0.03, 0.06), + "openai:gpt-3.5-turbo": LLMPricing("openai", "gpt-3.5-turbo", 0.0005, 0.0015), + "anthropic:claude-3-opus": LLMPricing("anthropic", "claude-3-opus", 0.015, 0.075), + "anthropic:claude-3-sonnet": LLMPricing("anthropic", "claude-3-sonnet", 0.003, 0.015), + "cohere:command": LLMPricing("cohere", "command", 0.0015, 0.002), + }) + + storage_pricing: StoragePricing = field(default_factory=StoragePricing) + + # Bandwidth pricing + bandwidth_cost_per_gb: float = 0.09 # Egress bandwidth cost + + @classmethod + 
def from_file(cls, path: str) -> "PricingConfig": + """Load pricing config from JSON file.""" + with open(path, 'r') as f: + data = json.load(f) + + config = cls() + + # Load LLM pricing + if "llm_pricing" in data: + config.llm_pricing = {} + for key, pricing in data["llm_pricing"].items(): + config.llm_pricing[key] = LLMPricing(**pricing) + + # Load storage pricing + if "storage_pricing" in data: + config.storage_pricing = StoragePricing(**data["storage_pricing"]) + + # Load bandwidth pricing + if "bandwidth_cost_per_gb" in data: + config.bandwidth_cost_per_gb = data["bandwidth_cost_per_gb"] + + return config + + +@dataclass +class CostSummary: + """Summary of costs for a period.""" + + period_start: datetime + period_end: datetime + total_cost: float = 0.0 + llm_cost: float = 0.0 + storage_cost: float = 0.0 + bandwidth_cost: float = 0.0 + costs_by_provider: dict[str, float] = field(default_factory=dict) + costs_by_operation: dict[str, float] = field(default_factory=dict) + costs_by_agent: dict[str, float] = field(default_factory=dict) + top_expensive_operations: list[dict[str, Any]] = field(default_factory=list) + + +@dataclass +class CostReport: + """Detailed cost report.""" + + summary: CostSummary + daily_breakdown: list[CostSummary] = field(default_factory=list) + recommendations: list[str] = field(default_factory=list) + projected_monthly_cost: float = 0.0 + + +class CostCalculator: + """Calculate and track costs for operations.""" + + def __init__( + self, + metrics_collector: MetricsCollector, + pricing_config: PricingConfig | None = None + ): + self.metrics = metrics_collector + self.pricing = pricing_config or PricingConfig() + + # Cost tracking by operation + self._operation_costs: dict[str, float] = {} + + # Aggregated costs + self._daily_costs: dict[str, CostSummary] = {} + + # Token usage tracking for projections + self._token_usage: dict[str, dict[str, int]] = {} + + async def track_llm_usage( + self, + provider: str, + model: str, + input_tokens: 
int, + output_tokens: int, + operation_id: str, + agent_id: str | None = None, + purpose: str | None = None + ) -> float: + """Track LLM API usage and calculate cost. + + Args: + provider: LLM provider (openai, anthropic, etc.) + model: Model name + input_tokens: Number of input tokens + output_tokens: Number of output tokens + operation_id: Associated operation ID + agent_id: Optional agent identifier + purpose: Purpose of the LLM call (enhancement, extraction, etc.) + + Returns: + Calculated cost in USD + """ + # Look up pricing + pricing_key = f"{provider}:{model}" + llm_pricing = self.pricing.llm_pricing.get(pricing_key) + + if not llm_pricing: + # Use a default pricing if model not found + llm_pricing = LLMPricing(provider, model, 0.01, 0.02) + + # Calculate cost + cost = llm_pricing.calculate_cost(input_tokens, output_tokens) + + # Track operation cost + self._operation_costs[operation_id] = self._operation_costs.get(operation_id, 0.0) + cost + + # Track token usage + if provider not in self._token_usage: + self._token_usage[provider] = {} + if model not in self._token_usage[provider]: + self._token_usage[provider][model] = {"input": 0, "output": 0} + + self._token_usage[provider][model]["input"] += input_tokens + self._token_usage[provider][model]["output"] += output_tokens + + # Record metric + await self.metrics.record_cost( + operation_id=operation_id, + cost_type="llm", + provider=provider, + amount_usd=cost, + units=input_tokens + output_tokens, + agent_id=agent_id, + metadata={ + "model": model, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "purpose": purpose + } + ) + + return cost + + async def track_storage_usage( + self, + operation: str, + size_bytes: int, + agent_id: str | None = None, + operation_id: str | None = None + ) -> float: + """Track storage operations and costs. 
+ + Args: + operation: Type of operation (read, write, delete) + size_bytes: Size of data in bytes + agent_id: Optional agent identifier + operation_id: Optional operation ID + + Returns: + Calculated cost in USD + """ + # Calculate cost + cost = self.pricing.storage_pricing.calculate_operation_cost(operation, size_bytes) + + # Track operation cost + if operation_id: + self._operation_costs[operation_id] = self._operation_costs.get(operation_id, 0.0) + cost + + # Record metric + await self.metrics.record_cost( + operation_id=operation_id or "storage", + cost_type="storage", + provider="lance", + amount_usd=cost, + units=size_bytes, + agent_id=agent_id, + metadata={ + "operation": operation, + "size_bytes": size_bytes + } + ) + + return cost + + async def track_bandwidth_usage( + self, + size_bytes: int, + direction: str = "egress", + agent_id: str | None = None, + operation_id: str | None = None + ) -> float: + """Track bandwidth usage and costs. + + Args: + size_bytes: Size of data transferred + direction: Transfer direction (egress, ingress) + agent_id: Optional agent identifier + operation_id: Optional operation ID + + Returns: + Calculated cost in USD + """ + # Only charge for egress + if direction != "egress": + return 0.0 + + # Calculate cost + size_gb = size_bytes / (1024 ** 3) + cost = size_gb * self.pricing.bandwidth_cost_per_gb + + # Track operation cost + if operation_id: + self._operation_costs[operation_id] = self._operation_costs.get(operation_id, 0.0) + cost + + # Record metric + await self.metrics.record_cost( + operation_id=operation_id or "bandwidth", + cost_type="bandwidth", + provider="network", + amount_usd=cost, + units=size_bytes, + agent_id=agent_id, + metadata={ + "direction": direction + } + ) + + return cost + + async def get_cost_report( + self, + start_time: datetime, + end_time: datetime, + group_by: str = "agent" + ) -> CostReport: + """Generate cost attribution report. 
+ + Args: + start_time: Start of reporting period + end_time: End of reporting period + group_by: How to group costs (agent, operation, provider) + + Returns: + Detailed cost report + """ + # Create summary + summary = CostSummary( + period_start=start_time, + period_end=end_time + ) + + # Get aggregated metrics from collector + cost_metrics = await self.metrics.get_aggregated_metrics( + "cost", + interval="1h", + lookback_hours=int((end_time - start_time).total_seconds() / 3600) + ) + + # Calculate totals from operation costs + for op_id, cost in self._operation_costs.items(): + summary.total_cost += cost + + # Generate daily breakdown + daily_breakdown = [] + current_date = start_time.date() + end_date = end_time.date() + + while current_date <= end_date: + day_start = datetime.combine(current_date, datetime.min.time(), timezone.utc) + day_end = day_start + timedelta(days=1) + + day_summary = CostSummary( + period_start=day_start, + period_end=day_end + ) + + # Add to daily breakdown + daily_breakdown.append(day_summary) + current_date += timedelta(days=1) + + # Generate recommendations + recommendations = self._generate_recommendations(summary) + + # Calculate projected monthly cost + days_in_period = (end_time - start_time).days or 1 + daily_average = summary.total_cost / days_in_period + projected_monthly = daily_average * 30 + + return CostReport( + summary=summary, + daily_breakdown=daily_breakdown, + recommendations=recommendations, + projected_monthly_cost=projected_monthly + ) + + def _generate_recommendations(self, summary: CostSummary) -> list[str]: + """Generate cost optimization recommendations.""" + recommendations = [] + + # Check if LLM costs are high + if summary.llm_cost > summary.total_cost * 0.7: + recommendations.append( + "LLM costs represent over 70% of total costs. " + "Consider using cheaper models for non-critical operations." 
+ ) + + # Check for expensive providers + if summary.costs_by_provider: + most_expensive = max(summary.costs_by_provider.items(), key=lambda x: x[1]) + if most_expensive[1] > summary.total_cost * 0.5: + recommendations.append( + f"{most_expensive[0]} accounts for over 50% of costs. " + f"Consider diversifying providers or negotiating rates." + ) + + # Check token usage patterns + total_tokens = sum( + usage["input"] + usage["output"] + for provider_models in self._token_usage.values() + for usage in provider_models.values() + ) + + if total_tokens > 1_000_000: + recommendations.append( + "High token usage detected. Consider implementing caching " + "for frequently requested enhancements." + ) + + return recommendations + + def get_operation_cost(self, operation_id: str) -> float: + """Get total cost for a specific operation.""" + return self._operation_costs.get(operation_id, 0.0) \ No newline at end of file diff --git a/contextframe/mcp/monitoring/integration.py b/contextframe/mcp/monitoring/integration.py new file mode 100644 index 0000000..038d336 --- /dev/null +++ b/contextframe/mcp/monitoring/integration.py @@ -0,0 +1,238 @@ +"""Integration of monitoring with MCP server components.""" + +import time +import uuid +from typing import Any, Dict, Optional + +from contextframe.mcp.handlers import MessageHandler as BaseMessageHandler +from contextframe.mcp.tools import ToolRegistry as BaseToolRegistry + +from .collector import MetricsCollector, MetricsConfig +from .cost import CostCalculator, PricingConfig +from .performance import PerformanceMonitor +from .tools import init_monitoring_tools +from .usage import UsageTracker + + +class MonitoringSystem: + """Central monitoring system for MCP server.""" + + def __init__( + self, + dataset: Any, + metrics_config: MetricsConfig | None = None, + pricing_config: PricingConfig | None = None + ): + # Initialize components + self.collector = MetricsCollector(dataset, metrics_config) + self.usage_tracker = 
UsageTracker(self.collector) + self.performance_monitor = PerformanceMonitor(self.collector) + self.cost_calculator = CostCalculator(self.collector, pricing_config) + + # Initialize monitoring tools + init_monitoring_tools( + self.collector, + self.usage_tracker, + self.performance_monitor, + self.cost_calculator + ) + + async def start(self) -> None: + """Start monitoring system.""" + await self.collector.start() + await self.performance_monitor.start() + + async def stop(self) -> None: + """Stop monitoring system.""" + await self.performance_monitor.stop() + await self.collector.stop() + + +class MonitoredMessageHandler(BaseMessageHandler): + """Message handler with integrated monitoring.""" + + def __init__(self, server: Any, monitoring: MonitoringSystem | None = None): + super().__init__(server) + self.monitoring = monitoring + + def _get_agent_id(self, message: dict[str, Any]) -> str | None: + """Extract agent ID from message metadata.""" + # Check for agent ID in various places + if "agent_id" in message: + return message["agent_id"] + + # Check in params + params = message.get("params", {}) + if isinstance(params, dict): + if "agent_id" in params: + return params["agent_id"] + # Check metadata + metadata = params.get("metadata", {}) + if isinstance(metadata, dict) and "agent_id" in metadata: + return metadata["agent_id"] + + return None + + async def handle(self, message: dict[str, Any]) -> dict[str, Any] | None: + """Handle message with monitoring.""" + if not self.monitoring: + # No monitoring, use base implementation + return await super().handle(message) + + # Generate operation ID + operation_id = str(uuid.uuid4()) + method = message.get("method", "unknown") + agent_id = self._get_agent_id(message) + + # Start performance tracking + context = await self.monitoring.performance_monitor.start_operation( + operation_id=operation_id, + operation_type=method, + agent_id=agent_id, + metadata={ + "request_id": message.get("id"), + "params": 
message.get("params", {})
            }
        )

        try:
            # Execute base handler
            result = await super().handle(message)

            # Track success (result_size is a rough proxy: length of the
            # stringified response, not serialized bytes)
            result_size = len(str(result)) if result else 0
            await self.monitoring.performance_monitor.end_operation(
                operation_id=operation_id,
                status="success",
                result_size=result_size
            )

            # Track specific operations
            # NOTE: `method`, `agent_id` and `operation_id` are bound earlier
            # in this method (above the visible portion of this chunk).
            if method == "tools/call" and result and "result" in result:
                await self._track_tool_call(message, result, agent_id)
            elif method == "resources/read" and result and "result" in result:
                await self._track_resource_read(message, result, agent_id)

            return result

        except Exception as e:
            # Track error, then re-raise so the caller still sees the failure
            await self.monitoring.performance_monitor.end_operation(
                operation_id=operation_id,
                status="error",
                error=str(e)
            )
            raise

    async def _track_tool_call(
        self,
        request: dict[str, Any],
        response: dict[str, Any],
        agent_id: str | None
    ) -> None:
        """Track tool call metrics.

        Records the tool invocation as a query, and additionally records
        per-document access events for the document-centric tools.

        Args:
            request: The original JSON-RPC request message.
            response: The JSON-RPC response (expected to carry "result").
            agent_id: Optional agent identifier for attribution.
        """
        params = request.get("params", {})
        tool_name = params.get("name", "unknown")

        # Track tool usage as a "query" event
        await self.monitoring.usage_tracker.track_query(
            query=tool_name,
            query_type="tool_call",
            result_count=1,
            execution_time_ms=0,  # Already tracked in performance
            agent_id=agent_id,
            success=True,
            metadata={"tool": tool_name}
        )

        # Track document access for document-related tools
        if tool_name in ["get_document", "search_documents", "update_document"]:
            result = response.get("result", {})

            # Single-document read: one "read" access event
            if tool_name == "get_document" and "document" in result:
                doc_id = result["document"].get("uuid")
                if doc_id:
                    await self.monitoring.usage_tracker.track_document_access(
                        document_id=doc_id,
                        operation="read",
                        agent_id=agent_id
                    )

            # Search: one "search_hit" event per returned document
            elif tool_name == "search_documents" and "documents" in result:
                for doc in result["documents"]:
                    doc_id = doc.get("uuid")
                    if doc_id:
                        await self.monitoring.usage_tracker.track_document_access(
                            document_id=doc_id,
                            operation="search_hit",
                            agent_id=agent_id
                        )

    async def _track_resource_read(
        self,
        request: dict[str, Any],
        response: dict[str, Any],
        agent_id: str | None
    ) -> None:
        """Track resource read metrics.

        Args:
            request: The original JSON-RPC request message.
            response: The JSON-RPC response (unused; kept for symmetry with
                _track_tool_call).
            agent_id: Optional agent identifier for attribution.
        """
        params = request.get("params", {})
        uri = params.get("uri", "unknown")

        # Track resource access as a "query" event keyed by URI
        await self.monitoring.usage_tracker.track_query(
            query=uri,
            query_type="resource_read",
            result_count=1,
            execution_time_ms=0,
            agent_id=agent_id,
            success=True,
            metadata={"uri": uri}
        )


class MonitoredToolRegistry(BaseToolRegistry):
    """Tool registry with cost tracking.

    Wraps the base registry so that calls to LLM-backed enhancement tools
    also record an (estimated) LLM cost entry before delegating.
    """

    def __init__(self, dataset: Any, transport: Any, monitoring: MonitoringSystem | None = None):
        super().__init__(dataset, transport)
        # When None, cost tracking is disabled and calls pass straight through.
        self.monitoring = monitoring

    async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call tool with cost tracking for LLM operations."""
        # Check if this is an enhancement tool that uses LLMs
        llm_tools = [
            "enhance_context",
            "extract_metadata",
            "generate_tags",
            "improve_title",
            "enhance_for_purpose",
            "batch_enhance"
        ]

        if name in llm_tools and self.monitoring:
            # Track LLM usage (simplified - in reality would need actual token counts)
            # NOTE(review): assumes `uuid` is imported at this module's top —
            # confirm against the file header (outside this chunk).
            operation_id = arguments.get("operation_id", str(uuid.uuid4()))

            # Estimate tokens based on content size
            content_size = 0
            if "content" in arguments:
                content_size = len(arguments["content"])
            elif "document_id" in arguments:
                # Would need to fetch document to get size
                content_size = 1000  # Estimate

            # Rough token estimation (1 token ≈ 4 characters)
            estimated_tokens = content_size // 4

            # Track cost (assuming GPT-3.5 by default)
            await self.monitoring.cost_calculator.track_llm_usage(
                provider="openai",
                model="gpt-3.5-turbo",
                input_tokens=estimated_tokens,
                output_tokens=estimated_tokens // 2,  # Rough estimate
                operation_id=operation_id,
                purpose=name
            )

        # Call base implementation
        return await super().call_tool(name, arguments)
\ No newline at end of file
diff --git a/contextframe/mcp/monitoring/performance.py b/contextframe/mcp/monitoring/performance.py
new file mode 100644
index 0000000..9214cb3
--- /dev/null
+++ b/contextframe/mcp/monitoring/performance.py
@@ -0,0 +1,371 @@
"""Performance monitoring for MCP operations."""

import asyncio
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator, Dict, List, Optional

from .collector import MetricsCollector


@dataclass
class OperationMetrics:
    """Metrics for a single operation."""

    # Operation type label, e.g. "tool_call" or "resource_read"
    operation_type: str
    # Total completed operations of this type
    count: int = 0
    # Sum of all durations; divide by count for the average
    total_duration_ms: float = 0.0
    # NOTE: stays +inf until the first sample is recorded
    min_duration_ms: float = float('inf')
    max_duration_ms: float = 0.0
    error_count: int = 0
    timeout_count: int = 0

    @property
    def avg_duration_ms(self) -> float:
        """Average duration in milliseconds (0.0 when no samples)."""
        return self.total_duration_ms / self.count if self.count > 0 else 0.0

    @property
    def error_rate(self) -> float:
        """Error rate as a percentage (0.0 when no samples)."""
        return (self.error_count / self.count * 100) if self.count > 0 else 0.0

    @property
    def success_rate(self) -> float:
        """Success rate as a percentage (complement of error_rate)."""
        return 100.0 - self.error_rate


@dataclass
class PerformanceSnapshot:
    """Point-in-time performance snapshot (taken once per minute)."""

    timestamp: datetime
    # Throughput over the interval since the previous snapshot
    operations_per_second: float = 0.0
    # Cumulative average across all recorded operations
    avg_response_time_ms: float = 0.0
    error_rate: float = 0.0
    active_operations: int = 0
    # NOTE(review): queue_depth / memory / cpu are declared but never
    # populated by the monitor loop in this file — confirm intended source.
    queue_depth: int = 0
    memory_usage_mb: float = 0.0
    cpu_usage_percent: float = 0.0


@dataclass
class OperationContext:
    """Context for tracking a single in-flight operation."""

    operation_id: str
    operation_type: str
    # time.time() at start; used by duration_ms()
    start_time: float
    agent_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def duration_ms(self) -> float:
        """Get current duration in milliseconds."""
        return (time.time() - self.start_time) * 1000


class PerformanceMonitor:
"""Track MCP server and agent performance metrics.""" + + def __init__(self, metrics_collector: MetricsCollector): + self.metrics = metrics_collector + + # Active operations tracking + self._active_operations: dict[str, OperationContext] = {} + + # Operation metrics by type + self._operation_metrics: dict[str, OperationMetrics] = {} + + # Performance snapshots + self._snapshots: list[PerformanceSnapshot] = [] + self._max_snapshots = 1440 # 24 hours at 1 per minute + + # Response time percentiles tracking + self._response_times: dict[str, list[float]] = {} + self._max_response_samples = 1000 + + # Background monitoring + self._monitor_task: asyncio.Task | None = None + + async def start(self) -> None: + """Start performance monitoring.""" + self._monitor_task = asyncio.create_task(self._monitor_loop()) + + async def stop(self) -> None: + """Stop performance monitoring.""" + if self._monitor_task: + self._monitor_task.cancel() + try: + await self._monitor_task + except asyncio.CancelledError: + pass + + async def start_operation( + self, + operation_id: str, + operation_type: str, + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> OperationContext: + """Start tracking an operation. + + Args: + operation_id: Unique operation identifier + operation_type: Type of operation (tool_call, resource_read, etc.) 
+ agent_id: Optional agent identifier + metadata: Additional operation context + + Returns: + Operation context for tracking + """ + context = OperationContext( + operation_id=operation_id, + operation_type=operation_type, + start_time=time.time(), + agent_id=agent_id, + metadata=metadata or {} + ) + + self._active_operations[operation_id] = context + + # Initialize metrics for this operation type + if operation_type not in self._operation_metrics: + self._operation_metrics[operation_type] = OperationMetrics( + operation_type=operation_type + ) + + return context + + async def end_operation( + self, + operation_id: str, + status: str, + result_size: int | None = None, + error: str | None = None + ) -> None: + """Complete operation tracking. + + Args: + operation_id: Operation identifier + status: Final status (success, error, timeout) + result_size: Optional size of the result + error: Error message if failed + """ + context = self._active_operations.pop(operation_id, None) + if not context: + return + + duration_ms = context.duration_ms() + + # Update operation metrics + metrics = self._operation_metrics[context.operation_type] + metrics.count += 1 + metrics.total_duration_ms += duration_ms + metrics.min_duration_ms = min(metrics.min_duration_ms, duration_ms) + metrics.max_duration_ms = max(metrics.max_duration_ms, duration_ms) + + if status == "error": + metrics.error_count += 1 + elif status == "timeout": + metrics.timeout_count += 1 + + # Track response times for percentile calculation + if context.operation_type not in self._response_times: + self._response_times[context.operation_type] = [] + + response_times = self._response_times[context.operation_type] + response_times.append(duration_ms) + + # Keep only recent samples + if len(response_times) > self._max_response_samples: + response_times.pop(0) + + # Record metric + await self.metrics.record_performance( + operation_id=operation_id, + operation_type=context.operation_type, + duration_ms=duration_ms, + 
status=status, + agent_id=context.agent_id, + error=error, + result_size=result_size + ) + + @asynccontextmanager + async def track_operation( + self, + operation_type: str, + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> AsyncIterator[OperationContext]: + """Context manager for operation tracking. + + Usage: + async with monitor.track_operation("tool_call") as ctx: + # Perform operation + result = await some_operation() + """ + import uuid + + operation_id = str(uuid.uuid4()) + context = await self.start_operation( + operation_id=operation_id, + operation_type=operation_type, + agent_id=agent_id, + metadata=metadata + ) + + try: + yield context + await self.end_operation(operation_id, "success") + except asyncio.TimeoutError: + await self.end_operation(operation_id, "timeout") + raise + except Exception as e: + await self.end_operation(operation_id, "error", error=str(e)) + raise + + async def record_metric( + self, + metric_name: str, + value: float, + tags: dict[str, str] | None = None + ) -> None: + """Record a custom metric. + + Args: + metric_name: Name of the metric + value: Metric value + tags: Optional tags for categorization + """ + await self.metrics.record_usage( + metric_type="custom", + resource_id=metric_name, + operation="record", + value=value, + metadata=tags + ) + + def get_operation_metrics( + self, + operation_type: str | None = None + ) -> dict[str, OperationMetrics]: + """Get operation metrics. + + Args: + operation_type: Optional filter by operation type + + Returns: + Dictionary of operation metrics + """ + if operation_type: + return { + operation_type: self._operation_metrics.get( + operation_type, + OperationMetrics(operation_type=operation_type) + ) + } + return self._operation_metrics.copy() + + def get_response_percentiles( + self, + operation_type: str, + percentiles: list[float] = [0.5, 0.95, 0.99] + ) -> dict[float, float]: + """Get response time percentiles. 
+ + Args: + operation_type: Operation type to analyze + percentiles: List of percentiles to calculate (0-1) + + Returns: + Dictionary mapping percentile to response time in ms + """ + response_times = self._response_times.get(operation_type, []) + if not response_times: + return {p: 0.0 for p in percentiles} + + sorted_times = sorted(response_times) + result = {} + + for p in percentiles: + index = int(len(sorted_times) * p) + index = min(index, len(sorted_times) - 1) + result[p] = sorted_times[index] + + return result + + def get_current_snapshot(self) -> PerformanceSnapshot: + """Get current performance snapshot.""" + if self._snapshots: + return self._snapshots[-1] + + return PerformanceSnapshot(timestamp=datetime.now(timezone.utc)) + + def get_performance_history( + self, + minutes: int = 60 + ) -> list[PerformanceSnapshot]: + """Get performance history. + + Args: + minutes: How many minutes of history to return + + Returns: + List of performance snapshots + """ + if not self._snapshots: + return [] + + cutoff = datetime.now(timezone.utc) - timedelta(minutes=minutes) + return [s for s in self._snapshots if s.timestamp >= cutoff] + + async def _monitor_loop(self) -> None: + """Background monitoring loop.""" + while True: + try: + await asyncio.sleep(60) # Snapshot every minute + await self._take_snapshot() + except asyncio.CancelledError: + raise + except Exception as e: + # Log error but continue + print(f"Error in performance monitor: {e}") + + async def _take_snapshot(self) -> None: + """Take a performance snapshot.""" + snapshot = PerformanceSnapshot(timestamp=datetime.now(timezone.utc)) + + # Calculate operations per second + total_ops = sum(m.count for m in self._operation_metrics.values()) + if self._snapshots: + prev_snapshot = self._snapshots[-1] + time_diff = (snapshot.timestamp - prev_snapshot.timestamp).total_seconds() + if time_diff > 0: + prev_total = sum( + m.count for m in self._operation_metrics.values() + ) + snapshot.operations_per_second = 
(total_ops - prev_total) / time_diff + + # Calculate average response time + total_duration = sum(m.total_duration_ms for m in self._operation_metrics.values()) + if total_ops > 0: + snapshot.avg_response_time_ms = total_duration / total_ops + + # Calculate error rate + total_errors = sum(m.error_count for m in self._operation_metrics.values()) + if total_ops > 0: + snapshot.error_rate = (total_errors / total_ops) * 100 + + # Active operations + snapshot.active_operations = len(self._active_operations) + + # Add snapshot + self._snapshots.append(snapshot) + + # Trim old snapshots + if len(self._snapshots) > self._max_snapshots: + self._snapshots.pop(0) \ No newline at end of file diff --git a/contextframe/mcp/monitoring/tools.py b/contextframe/mcp/monitoring/tools.py new file mode 100644 index 0000000..b16fb93 --- /dev/null +++ b/contextframe/mcp/monitoring/tools.py @@ -0,0 +1,510 @@ +"""MCP tools for accessing monitoring data.""" + +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +from contextframe.mcp.errors import InvalidParams +from contextframe.mcp.tools import tool_registry + +from .collector import MetricsCollector +from .cost import CostCalculator +from .performance import PerformanceMonitor +from .usage import UsageTracker + + +# Global instances (initialized by server) +metrics_collector: MetricsCollector | None = None +usage_tracker: UsageTracker | None = None +performance_monitor: PerformanceMonitor | None = None +cost_calculator: CostCalculator | None = None + + +def init_monitoring_tools( + collector: MetricsCollector, + usage: UsageTracker, + performance: PerformanceMonitor, + cost: CostCalculator +) -> None: + """Initialize monitoring tools with required components.""" + global metrics_collector, usage_tracker, performance_monitor, cost_calculator + metrics_collector = collector + usage_tracker = usage + performance_monitor = performance + cost_calculator = cost + + +def _ensure_initialized() -> 
None:
    """Ensure monitoring components are initialized.

    Raises:
        RuntimeError: If init_monitoring_tools() has not been called.
    """
    if not all([metrics_collector, usage_tracker, performance_monitor, cost_calculator]):
        raise RuntimeError("Monitoring tools not initialized")


@tool_registry.register(
    name="get_usage_metrics",
    description="Get usage metrics for documents and queries",
    input_schema={
        "type": "object",
        "properties": {
            "start_time": {
                "type": "string",
                "format": "date-time",
                "description": "Start time (ISO format). Defaults to 1 hour ago"
            },
            "end_time": {
                "type": "string",
                "format": "date-time",
                "description": "End time (ISO format). Defaults to now"
            },
            "group_by": {
                "type": "string",
                "enum": ["hour", "day", "week"],
                "description": "Aggregation interval",
                "default": "hour"
            },
            "include_details": {
                "type": "boolean",
                "description": "Include detailed breakdowns",
                "default": False
            }
        }
    }
)
async def get_usage_metrics(params: dict[str, Any]) -> dict[str, Any]:
    """Get usage metrics for documents and queries.

    Returns metrics including:
    - Total queries and document accesses
    - Unique documents and agents
    - Query distribution by type
    - Top accessed documents
    - Access patterns over time
    """
    _ensure_initialized()

    # Parse parameters (window defaults to the last hour)
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(hours=1)

    # The "Z" -> "+00:00" replacement keeps fromisoformat() compatible with
    # trailing-Z timestamps on Python versions older than 3.11.
    if "start_time" in params:
        start_time = datetime.fromisoformat(params["start_time"].replace("Z", "+00:00"))
    if "end_time" in params:
        end_time = datetime.fromisoformat(params["end_time"].replace("Z", "+00:00"))

    group_by = params.get("group_by", "hour")
    include_details = params.get("include_details", False)

    # Get usage stats
    stats = await usage_tracker.get_usage_stats(start_time, end_time, group_by)

    result = {
        "period": {
            "start": stats.period_start.isoformat(),
            "end": stats.period_end.isoformat()
        },
        "summary": {
            "total_queries": stats.total_queries,
            "total_document_accesses": stats.total_document_accesses,
            "unique_documents": stats.unique_documents_accessed,
            "unique_agents": stats.unique_agents
        },
        "queries_by_type": stats.queries_by_type,
        "access_patterns": stats.access_patterns
    }

    if include_details:
        # Top-10 views built from the already-sorted stats lists
        result["top_documents"] = [
            {
                "document_id": doc.document_id,
                "access_count": doc.access_count,
                "search_appearances": doc.search_appearances,
                "last_accessed": doc.last_accessed.isoformat() if doc.last_accessed else None,
                "access_by_operation": doc.access_by_operation
            }
            for doc in stats.top_documents[:10]
        ]

        result["top_queries"] = [
            {
                "query": q.query,
                "type": q.query_type,
                "count": q.count,
                "avg_results": q.total_results / q.count if q.count > 0 else 0,
                "avg_execution_time_ms": q.avg_execution_time_ms,
                "success_rate": q.success_rate
            }
            for q in stats.top_queries[:10]
        ]

    return result


@tool_registry.register(
    name="get_performance_metrics",
    description="Get performance metrics for MCP operations",
    input_schema={
        "type": "object",
        "properties": {
            "operation_type": {
                "type": "string",
                "description": "Filter by operation type (e.g., tool_call, resource_read)"
            },
            "minutes": {
                "type": "integer",
                "description": "How many minutes of history to include",
                "default": 60
            },
            "include_percentiles": {
                "type": "boolean",
                "description": "Include response time percentiles",
                "default": True
            }
        }
    }
)
async def get_performance_metrics(params: dict[str, Any]) -> dict[str, Any]:
    """Get performance metrics for operations.

    Returns metrics including:
    - Operation counts and durations
    - Error rates and success rates
    - Response time percentiles
    - Current performance snapshot
    - Historical trends
    """
    _ensure_initialized()

    operation_type = params.get("operation_type")
    minutes = params.get("minutes", 60)
    include_percentiles = params.get("include_percentiles", True)

    # Get operation metrics
    metrics = performance_monitor.get_operation_metrics(operation_type)

    # Get current snapshot
    current = performance_monitor.get_current_snapshot()

    # Get performance history
    history = performance_monitor.get_performance_history(minutes)

    result = {
        "current_snapshot": {
            "timestamp": current.timestamp.isoformat(),
            "operations_per_second": current.operations_per_second,
            "avg_response_time_ms": current.avg_response_time_ms,
            "error_rate": current.error_rate,
            "active_operations": current.active_operations
        },
        "operations": {}
    }

    # Add operation-specific metrics
    for op_type, op_metrics in metrics.items():
        op_data = {
            "count": op_metrics.count,
            "avg_duration_ms": op_metrics.avg_duration_ms,
            # NOTE(review): min_duration_ms is +inf until the first sample;
            # infinity is not valid strict JSON — confirm serializer handles it.
            "min_duration_ms": op_metrics.min_duration_ms,
            "max_duration_ms": op_metrics.max_duration_ms,
            "error_rate": op_metrics.error_rate,
            "success_rate": op_metrics.success_rate
        }

        if include_percentiles and op_metrics.count > 0:
            percentiles = performance_monitor.get_response_percentiles(
                op_type,
                [0.5, 0.75, 0.90, 0.95, 0.99]
            )
            op_data["percentiles"] = {
                f"p{int(p*100)}": value
                for p, value in percentiles.items()
            }

        result["operations"][op_type] = op_data

    # Add historical trend
    if history:
        result["history"] = [
            {
                "timestamp": snap.timestamp.isoformat(),
                "ops_per_second": snap.operations_per_second,
                "avg_response_ms": snap.avg_response_time_ms,
                "error_rate": snap.error_rate
            }
            for snap in history[-20:]  # Last 20 snapshots
        ]

    return result


@tool_registry.register(
    name="get_cost_report",
    description="Get cost attribution report for MCP operations",
    input_schema={
        "type": "object",
        "properties": {
            "start_time": {
                "type": "string",
                "format": "date-time",
                "description": "Start time (ISO format). Defaults to 24 hours ago"
            },
            "end_time": {
                "type": "string",
                "format": "date-time",
                "description": "End time (ISO format). Defaults to now"
            },
            "group_by": {
                "type": "string",
                "enum": ["agent", "operation", "provider"],
                "description": "How to group costs",
                "default": "agent"
            },
            "include_projections": {
                "type": "boolean",
                "description": "Include monthly cost projections",
                "default": True
            }
        }
    }
)
async def get_cost_report(params: dict[str, Any]) -> dict[str, Any]:
    """Get cost attribution report.

    Returns:
    - Total costs broken down by type
    - Costs grouped by agent/operation/provider
    - Daily cost breakdown
    - Optimization recommendations
    - Monthly projections
    """
    _ensure_initialized()

    # Parse parameters (window defaults to the last 24 hours)
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=1)

    if "start_time" in params:
        start_time = datetime.fromisoformat(params["start_time"].replace("Z", "+00:00"))
    if "end_time" in params:
        end_time = datetime.fromisoformat(params["end_time"].replace("Z", "+00:00"))

    group_by = params.get("group_by", "agent")
    include_projections = params.get("include_projections", True)

    # Get cost report
    report = await cost_calculator.get_cost_report(start_time, end_time, group_by)

    result = {
        "period": {
            "start": report.summary.period_start.isoformat(),
            "end": report.summary.period_end.isoformat()
        },
        "total_cost": round(report.summary.total_cost, 4),
        "breakdown": {
            "llm": round(report.summary.llm_cost, 4),
            "storage": round(report.summary.storage_cost, 4),
            "bandwidth": round(report.summary.bandwidth_cost, 4)
        },
        # Dynamic key (e.g. "costs_by_agent") mirroring the summary attribute
        # of the same name; group_by is schema-constrained to valid values.
        "costs_by_" + group_by: {
            k: round(v, 4)
            for k, v in getattr(report.summary, f"costs_by_{group_by}").items()
        }
    }

    # Add daily breakdown
    if report.daily_breakdown:
        result["daily_breakdown"] = [
            {
                "date": day.period_start.date().isoformat(),
                "total": round(day.total_cost, 4),
                "llm": round(day.llm_cost, 4),
                "storage": round(day.storage_cost, 4),
                "bandwidth": round(day.bandwidth_cost, 4)
            }
            for day in report.daily_breakdown[:7]  # Last 7 days
        ]

    # Add recommendations
    if report.recommendations:
        result["recommendations"] = report.recommendations

    # Add projections
    if include_projections:
        result["projections"] = {
            "monthly_cost": round(report.projected_monthly_cost, 2),
            "annual_cost": round(report.projected_monthly_cost * 12, 2)
        }

    return result


@tool_registry.register(
    name="get_monitoring_status",
    description="Get overall monitoring system status",
    input_schema={
        "type": "object",
        "properties": {}
    }
)
async def get_monitoring_status(params: dict[str, Any]) -> dict[str, Any]:
    """Get overall monitoring system status.

    Returns:
    - Monitoring system health
    - Configuration status
    - Buffer sizes and memory usage
    - Collection statistics
    """
    _ensure_initialized()

    # Get buffer sizes
    # NOTE(review): reaches into private attributes of the collector and
    # trackers; consider exposing read-only accessors on those classes.
    usage_buffer_size = len(metrics_collector._usage_buffer)
    performance_buffer_size = len(metrics_collector._performance_buffer)
    cost_buffer_size = len(metrics_collector._cost_buffer)

    # Get collection stats
    total_metrics = usage_buffer_size + performance_buffer_size + cost_buffer_size

    # Get active operations
    active_operations = len(performance_monitor._active_operations)

    return {
        "status": "healthy" if metrics_collector.config.enabled else "disabled",
        "configuration": {
            "enabled": metrics_collector.config.enabled,
            "retention_days": metrics_collector.config.retention_days,
            "flush_interval_seconds": metrics_collector.config.flush_interval_seconds,
            "max_memory_metrics": metrics_collector.config.max_memory_metrics
        },
        "buffers": {
            "usage": usage_buffer_size,
            "performance": performance_buffer_size,
            "cost": cost_buffer_size,
            "total": total_metrics
        },
        "activity": {
            "active_operations": active_operations,
            "tracked_queries": len(usage_tracker._query_cache),
            "tracked_documents": len(usage_tracker._document_cache),
            "tracked_agents": len(usage_tracker._agent_activity)
        }
    }


@tool_registry.register(
    name="export_metrics",
    description="Export metrics to various formats",
    input_schema={
        "type": "object",
        "properties": {
            "format": {
                "type": "string",
                "enum": ["prometheus", "json", "csv"],
                "description": "Export format",
                "default": "json"
            },
            "metric_types": {
                "type": "array",
                "items": {
                    "type": "string",
                    "enum": ["usage", "performance", "cost", "all"]
                },
                "description": "Which metrics to export",
"default": ["all"] + }, + "include_raw": { + "type": "boolean", + "description": "Include raw metric data", + "default": False + } + } + } +) +async def export_metrics(params: dict[str, Any]) -> dict[str, Any]: + """Export metrics to external monitoring systems. + + Supports formats: + - Prometheus text format + - JSON for custom processing + - CSV for analysis + """ + _ensure_initialized() + + format_type = params.get("format", "json") + metric_types = params.get("metric_types", ["all"]) + include_raw = params.get("include_raw", False) + + # Determine which metrics to include + include_all = "all" in metric_types + include_usage = include_all or "usage" in metric_types + include_performance = include_all or "performance" in metric_types + include_cost = include_all or "cost" in metric_types + + if format_type == "prometheus": + # Generate Prometheus text format + lines = [] + + if include_usage: + # Usage metrics + usage_stats = await usage_tracker.get_usage_stats( + datetime.now(timezone.utc) - timedelta(hours=1), + datetime.now(timezone.utc) + ) + + lines.extend([ + "# HELP contextframe_queries_total Total number of queries", + "# TYPE contextframe_queries_total counter", + f"contextframe_queries_total {usage_stats.total_queries}", + "", + "# HELP contextframe_document_accesses_total Total document accesses", + "# TYPE contextframe_document_accesses_total counter", + f"contextframe_document_accesses_total {usage_stats.total_document_accesses}", + "" + ]) + + if include_performance: + # Performance metrics + metrics = performance_monitor.get_operation_metrics() + + for op_type, op_metrics in metrics.items(): + safe_op_type = op_type.replace("/", "_").replace("-", "_") + lines.extend([ + f"# HELP contextframe_operation_{safe_op_type}_total Total {op_type} operations", + f"# TYPE contextframe_operation_{safe_op_type}_total counter", + f"contextframe_operation_{safe_op_type}_total {op_metrics.count}", + "", + f"# HELP 
contextframe_operation_{safe_op_type}_duration_ms {op_type} duration", + f"# TYPE contextframe_operation_{safe_op_type}_duration_ms histogram", + f"contextframe_operation_{safe_op_type}_duration_ms_sum {op_metrics.total_duration_ms}", + f"contextframe_operation_{safe_op_type}_duration_ms_count {op_metrics.count}", + "" + ]) + + return { + "format": "prometheus", + "content": "\n".join(lines), + "content_type": "text/plain" + } + + elif format_type == "json": + # Generate JSON format + result = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "metrics": {} + } + + if include_usage: + result["metrics"]["usage"] = await get_usage_metrics({}) + + if include_performance: + result["metrics"]["performance"] = await get_performance_metrics({}) + + if include_cost: + result["metrics"]["cost"] = await get_cost_report({}) + + return { + "format": "json", + "content": result, + "content_type": "application/json" + } + + else: + raise InvalidParams(f"Unsupported export format: {format_type}") \ No newline at end of file diff --git a/contextframe/mcp/monitoring/usage.py b/contextframe/mcp/monitoring/usage.py new file mode 100644 index 0000000..c8cf4c1 --- /dev/null +++ b/contextframe/mcp/monitoring/usage.py @@ -0,0 +1,321 @@ +"""Usage tracking for documents and queries.""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +from .collector import MetricsCollector + + +@dataclass +class QueryStats: + """Statistics for a single query.""" + + query: str + query_type: str + count: int = 0 + total_results: int = 0 + avg_execution_time_ms: float = 0.0 + success_rate: float = 1.0 + + +@dataclass +class DocumentStats: + """Statistics for document usage.""" + + document_id: str + access_count: int = 0 + search_appearances: int = 0 + last_accessed: datetime | None = None + access_by_operation: dict[str, int] = field(default_factory=dict) + + +@dataclass +class UsageStats: + """Aggregated 
usage statistics.""" + + period_start: datetime + period_end: datetime + total_queries: int = 0 + total_document_accesses: int = 0 + unique_documents_accessed: int = 0 + unique_agents: int = 0 + queries_by_type: dict[str, int] = field(default_factory=dict) + top_documents: list[DocumentStats] = field(default_factory=list) + top_queries: list[QueryStats] = field(default_factory=list) + access_patterns: dict[str, Any] = field(default_factory=dict) + + +class UsageTracker: + """Track document access patterns and query statistics.""" + + def __init__(self, metrics_collector: MetricsCollector): + self.metrics = metrics_collector + + # Local caches for fast lookups + self._query_cache: dict[str, QueryStats] = {} + self._document_cache: dict[str, DocumentStats] = {} + self._agent_activity: dict[str, datetime] = {} + + async def track_document_access( + self, + document_id: str, + operation: str, + agent_id: str | None = None, + metadata: dict[str, Any] | None = None + ) -> None: + """Record document access event. 
+ + Args: + document_id: ID of the accessed document + operation: Type of operation (read, search_hit, update, delete) + agent_id: Optional agent identifier + metadata: Additional context about the access + """ + # Update local cache + if document_id not in self._document_cache: + self._document_cache[document_id] = DocumentStats(document_id=document_id) + + doc_stats = self._document_cache[document_id] + doc_stats.access_count += 1 + doc_stats.last_accessed = datetime.now(timezone.utc) + + if operation not in doc_stats.access_by_operation: + doc_stats.access_by_operation[operation] = 0 + doc_stats.access_by_operation[operation] += 1 + + if operation == "search_hit": + doc_stats.search_appearances += 1 + + # Track agent activity + if agent_id: + self._agent_activity[agent_id] = datetime.now(timezone.utc) + + # Record metric + await self.metrics.record_usage( + metric_type="document_access", + resource_id=document_id, + operation=operation, + value=1.0, + agent_id=agent_id, + metadata=metadata + ) + + async def track_query( + self, + query: str, + query_type: str, + result_count: int, + execution_time_ms: float, + agent_id: str | None = None, + success: bool = True, + metadata: dict[str, Any] | None = None + ) -> None: + """Record query execution. 
+ + Args: + query: The query string + query_type: Type of query (vector, text, hybrid, sql) + result_count: Number of results returned + execution_time_ms: Query execution time in milliseconds + agent_id: Optional agent identifier + success: Whether the query succeeded + metadata: Additional query context + """ + # Update query cache + query_key = f"{query_type}:{query[:100]}" # Truncate long queries + + if query_key not in self._query_cache: + self._query_cache[query_key] = QueryStats( + query=query[:100], + query_type=query_type + ) + + q_stats = self._query_cache[query_key] + q_stats.count += 1 + q_stats.total_results += result_count + + # Update average execution time + prev_total_time = q_stats.avg_execution_time_ms * (q_stats.count - 1) + q_stats.avg_execution_time_ms = (prev_total_time + execution_time_ms) / q_stats.count + + # Update success rate + if not success: + prev_successes = q_stats.success_rate * (q_stats.count - 1) + q_stats.success_rate = prev_successes / q_stats.count + + # Track agent activity + if agent_id: + self._agent_activity[agent_id] = datetime.now(timezone.utc) + + # Record metric + await self.metrics.record_usage( + metric_type="query", + resource_id=query_type, + operation="execute", + value=float(result_count), + agent_id=agent_id, + metadata={ + "query": query[:100], + "execution_time_ms": execution_time_ms, + "success": success, + **(metadata or {}) + } + ) + + async def get_usage_stats( + self, + start_time: datetime, + end_time: datetime, + group_by: str = "hour" + ) -> UsageStats: + """Get aggregated usage statistics. 

        Args:
            start_time: Start of the period
            end_time: End of the period
            group_by: Aggregation interval (hour, day, week)

        Returns:
            Aggregated usage statistics
        """
        stats = UsageStats(
            period_start=start_time,
            period_end=end_time
        )

        # Get metrics from collector
        # NOTE(review): this result is fetched but never merged into `stats`;
        # the aggregation below uses only the in-process caches — confirm
        # whether the collector data was meant to be folded in.
        usage_metrics = await self.metrics.get_aggregated_metrics(
            "usage",
            interval="1h",
            lookback_hours=int((end_time - start_time).total_seconds() / 3600)
        )

        # Aggregate from local caches
        # Count unique documents accessed
        accessed_docs = set()
        for doc_id, doc_stats in self._document_cache.items():
            if doc_stats.last_accessed and start_time <= doc_stats.last_accessed <= end_time:
                accessed_docs.add(doc_id)
                # NOTE(review): access_count is a lifetime counter, so this
                # over-counts accesses for the requested window.
                stats.total_document_accesses += doc_stats.access_count

        stats.unique_documents_accessed = len(accessed_docs)

        # Get top documents (by lifetime access count, across all time)
        sorted_docs = sorted(
            self._document_cache.values(),
            key=lambda d: d.access_count,
            reverse=True
        )
        stats.top_documents = sorted_docs[:10]

        # Count queries by type (lifetime counts, not window-filtered)
        for query_key, q_stats in self._query_cache.items():
            stats.total_queries += q_stats.count
            if q_stats.query_type not in stats.queries_by_type:
                stats.queries_by_type[q_stats.query_type] = 0
            stats.queries_by_type[q_stats.query_type] += q_stats.count

        # Get top queries
        sorted_queries = sorted(
            self._query_cache.values(),
            key=lambda q: q.count,
            reverse=True
        )
        stats.top_queries = sorted_queries[:10]

        # Count unique agents active within the window
        active_agents = set()
        for agent_id, last_active in self._agent_activity.items():
            if start_time <= last_active <= end_time:
                active_agents.add(agent_id)
        stats.unique_agents = len(active_agents)

        # Access patterns by time ("week" currently yields no patterns)
        if group_by == "hour":
            stats.access_patterns = self._calculate_hourly_patterns(start_time, end_time)
        elif group_by == "day":
            stats.access_patterns = self._calculate_daily_patterns(start_time, end_time)

        return stats

    def _calculate_hourly_patterns(
        self,
        start_time: datetime,
        end_time: datetime
    ) -> dict[str, Any]:
        """Calculate hourly access patterns."""
        patterns = {}

        # This would analyze the metrics to find patterns
        # For now, return a simple zero-filled placeholder structure
        current = start_time.replace(minute=0, second=0, microsecond=0)
        while current < end_time:
            hour_key = current.strftime("%Y-%m-%d %H:00")
            patterns[hour_key] = {
                "queries": 0,
                "document_accesses": 0
            }
            current += timedelta(hours=1)

        return patterns

    def _calculate_daily_patterns(
        self,
        start_time: datetime,
        end_time: datetime
    ) -> dict[str, Any]:
        """Calculate daily access patterns."""
        patterns = {}

        # Zero-filled placeholder, same caveat as the hourly variant
        current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
        while current < end_time:
            day_key = current.strftime("%Y-%m-%d")
            patterns[day_key] = {
                "queries": 0,
                "document_accesses": 0,
                "peak_hour": None
            }
            current += timedelta(days=1)

        return patterns

    async def get_document_usage(
        self,
        document_id: str,
        lookback_days: int = 30
    ) -> DocumentStats | None:
        """Get usage statistics for a specific document.

        Args:
            document_id: Document to get stats for
            lookback_days: How far back to look

        Returns:
            Document usage statistics or None if not found
        """
        # NOTE(review): lookback_days is currently ignored — the cache only
        # holds lifetime stats.
        return self._document_cache.get(document_id)

    async def get_query_performance(
        self,
        query_type: str | None = None,
        limit: int = 20
    ) -> list[QueryStats]:
        """Get query performance statistics.

        Args:
            query_type: Optional filter by query type
            limit: Maximum number of results

        Returns:
            List of query statistics
        """
        queries = list(self._query_cache.values())

        if query_type:
            queries = [q for q in queries if q.query_type == query_type]

        # Sort by execution count
        queries.sort(key=lambda q: q.count, reverse=True)

        return queries[:limit]
\ No newline at end of file
diff --git a/contextframe/mcp/security/__init__.py b/contextframe/mcp/security/__init__.py
new file mode 100644
index 0000000..4d2a699
--- /dev/null
+++ b/contextframe/mcp/security/__init__.py
@@ -0,0 +1,56 @@
"""Security components for MCP server.

Provides authentication, authorization, rate limiting, and audit logging
for the MCP server to ensure secure access to ContextFrame datasets.
"""

from .audit import AuditLogger, AuditEvent
from .auth import (
    AuthenticationError,
    AuthProvider,
    APIKeyAuth,
    SecurityContext,
)
from .authorization import (
    AuthorizationError,
    Permission,
    Role,
    AccessControl,
)
from .integration import SecurityMiddleware, SecuredMessageHandler
from .jwt import JWTHandler, JWTConfig
from .oauth import OAuth2Provider, OAuth2Config
from .rate_limiting import (
    RateLimiter,
    RateLimitExceeded,
    RateLimitConfig,
)

__all__ = [
    # Authentication
    "AuthenticationError",
    "AuthProvider",
    "APIKeyAuth",
    "SecurityContext",
    # Authorization
    "AuthorizationError",
    "Permission",
    "Role",
    "AccessControl",
    # OAuth 2.1
    "OAuth2Provider",
    "OAuth2Config",
    # JWT
    "JWTHandler",
    "JWTConfig",
    # Rate Limiting
    "RateLimiter",
    "RateLimitExceeded",
    "RateLimitConfig",
    # Audit
    "AuditLogger",
    "AuditEvent",
    # Integration
    "SecurityMiddleware",
    "SecuredMessageHandler",
]
\ No newline at end of file
diff --git a/contextframe/mcp/security/audit.py b/contextframe/mcp/security/audit.py
new file mode 100644
index 0000000..e1df5bf
--- /dev/null
+++ b/contextframe/mcp/security/audit.py
@@ -0,0 +1,488 @@
+"""Audit logging for MCP server security events."""
+
+import asyncio
+import json
+import logging
+from dataclasses import dataclass, field, asdict
+from datetime import datetime, timedelta, timezone
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+from collections import deque
+
+from contextframe.frame import FrameDataset, FrameRecord
+
+
+class AuditEventType(str, Enum):
+    """Types of audit events."""
+
+    # Authentication events
+    AUTH_SUCCESS = "auth.success"
+    AUTH_FAILURE = "auth.failure"
+    AUTH_TOKEN_CREATED = "auth.token_created"
+    AUTH_TOKEN_REVOKED = "auth.token_revoked"
+
+    # Authorization events
+    AUTHZ_GRANTED = "authz.granted"
+    AUTHZ_DENIED = "authz.denied"
+
+    # Rate limiting events
+    RATE_LIMIT_EXCEEDED = "rate_limit.exceeded"
+    RATE_LIMIT_RESET = "rate_limit.reset"
+
+    # Resource access events
+    RESOURCE_READ = "resource.read"
+    RESOURCE_WRITE = "resource.write"
+    RESOURCE_DELETE = "resource.delete"
+
+    # Tool execution events
+    TOOL_EXECUTED = "tool.executed"
+    TOOL_FAILED = "tool.failed"
+
+    # Security configuration events
+    SECURITY_CONFIG_CHANGED = "security.config_changed"
+    ROLE_CREATED = "security.role_created"
+    ROLE_MODIFIED = "security.role_modified"
+    ROLE_DELETED = "security.role_deleted"
+    POLICY_CREATED = "security.policy_created"
+    POLICY_MODIFIED = "security.policy_modified"
+    POLICY_DELETED = "security.policy_deleted"
+
+    # System events
+    SYSTEM_START = "system.start"
+    SYSTEM_STOP = "system.stop"
+    SYSTEM_ERROR = "system.error"
+
+
+@dataclass
+class AuditEvent:
+    """Audit event record."""
+
+    # Event metadata
+    event_id: str
+    timestamp: datetime
+    event_type: AuditEventType
+
+    # Principal information
+    principal_id: Optional[str] = None
+    principal_type: Optional[str] = None
+    principal_name: Optional[str] = None
+    auth_method: Optional[str] = None
+
+    # Request context
+    operation: Optional[str] = None
+    resource_type: Optional[str] = None
+    resource_id: Optional[str] = None
+    request_id: Optional[str] = None
+    session_id: Optional[str] = None
+
+    # Network context
+    client_ip: Optional[str] = None
+    user_agent: Optional[str] = None
+
+    # Event details
+    success: bool = True
+    error_code: Optional[int] = None
+    error_message: Optional[str] = None
+    details: Dict[str, Any] = field(default_factory=dict)
+
+    # Computed fields
+    severity: str = field(init=False)
+
+    def __post_init__(self):
+        # Compute severity based on event type and success
+        if not self.success:
+            if self.event_type in [
+                AuditEventType.AUTH_FAILURE,
+                AuditEventType.AUTHZ_DENIED,
+                AuditEventType.RATE_LIMIT_EXCEEDED
+            ]:
+                self.severity = "warning"
+            else:
+                self.severity = "error"
+        else:
+            if self.event_type in [
+                AuditEventType.SECURITY_CONFIG_CHANGED,
+                AuditEventType.ROLE_CREATED,
+                AuditEventType.ROLE_MODIFIED,
+                AuditEventType.ROLE_DELETED,
+                AuditEventType.POLICY_CREATED,
+                AuditEventType.POLICY_MODIFIED,
+                AuditEventType.POLICY_DELETED,
+            ]:
+                self.severity = "warning"
+            else:
+                self.severity = "info"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        data = asdict(self)
+        # Convert datetime to ISO format
+        data["timestamp"] = self.timestamp.isoformat()
+        return data
+
+    def to_json(self) -> str:
+        """Convert to JSON string."""
+        return json.dumps(self.to_dict())
+
+
+@dataclass
+class AuditConfig:
+    """Audit logging configuration."""
+
+    # Storage settings
+    storage_backend: str = "memory"  # "memory", "file", "dataset"
+    file_path: Optional[str] = None
+    dataset_path: Optional[str] = None
+
+    # Retention settings
+    max_events_memory: int = 10000
+    retention_days: int = 90
+
+    # Filtering
+    enabled_event_types: Optional[List[AuditEventType]] = None
+    disabled_event_types: Optional[List[AuditEventType]] = None
+
+    # Performance
+    buffer_size: int = 1000
+    flush_interval: int = 60  # seconds
+
+    # Security
+    include_request_details: bool = True
+    include_response_details: bool = False
+    redact_sensitive_data: bool = True
+
+    def should_log_event(self, event_type: AuditEventType) -> bool:
+        """Check if event type should be logged."""
+        if self.disabled_event_types and event_type in self.disabled_event_types:
+            return False
+
+        if self.enabled_event_types:
+            return event_type in self.enabled_event_types
+
+        return True
+
+
+class AuditLogger:
+    """Audit logger for security events."""
+
+    def __init__(self, config: AuditConfig):
+        self.config = config
+        self._buffer: deque = deque(maxlen=config.buffer_size)
+        self._memory_store: deque = deque(maxlen=config.max_events_memory)
+        self._logger = logging.getLogger(__name__)
+
+        # Storage backend
+        self._file_handle = None
+        self._dataset: Optional[FrameDataset] = None
+
+        # Background tasks
+        self._flush_task = None
+        self._cleanup_task = None
+
+        # Event counter
+        self._event_counter = 0
+
+    async def start(self):
+        """Start the audit logger."""
+        # Initialize storage backend
+        if self.config.storage_backend == "file" and self.config.file_path:
+            Path(self.config.file_path).parent.mkdir(parents=True, exist_ok=True)
+            self._file_handle = open(self.config.file_path, "a")
+        elif self.config.storage_backend == "dataset" and self.config.dataset_path:
+            try:
+                self._dataset = FrameDataset.open(self.config.dataset_path)
+            except Exception:
+                # Create new dataset for audit logs
+                self._dataset = FrameDataset.create(self.config.dataset_path)
+
+        # Start background tasks
+        self._flush_task = asyncio.create_task(self._flush_loop())
+        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+
+    async def stop(self):
+        """Stop the audit logger."""
+        # Stop background tasks
+        if self._flush_task:
+            self._flush_task.cancel()
+            try:
+                await self._flush_task
+            except asyncio.CancelledError:
+                pass
+
+        if self._cleanup_task:
+            self._cleanup_task.cancel()
+            try:
+                await self._cleanup_task
+            except asyncio.CancelledError:
+                pass
+
+        # Flush remaining events
+        await self._flush_buffer()
+
+        # Close storage
+        if self._file_handle:
+            self._file_handle.close()
+
+    async def _flush_loop(self):
+        """Periodically flush buffered events."""
+        while True:
+            try:
+                await asyncio.sleep(self.config.flush_interval)
+                await self._flush_buffer()
+            except asyncio.CancelledError:
+                break
+
+    async def _cleanup_loop(self):
+        """Periodically clean up old events."""
+        while True:
+            try:
+                await asyncio.sleep(86400)  # Daily cleanup
+                await self._cleanup_old_events()
+            except asyncio.CancelledError:
+                break
+
+    async def _flush_buffer(self):
+        """Flush buffered events to storage."""
+        if not self._buffer:
+            return
+
+        events = []
+        while self._buffer:
+            events.append(self._buffer.popleft())
+
+        # Store based on backend
+        if self.config.storage_backend == "memory":
+            self._memory_store.extend(events)
+
+        elif self.config.storage_backend == "file" and self._file_handle:
+            for event in events:
+                self._file_handle.write(event.to_json() + "\n")
+            self._file_handle.flush()
+
+        elif self.config.storage_backend == "dataset" and self._dataset:
+            # Convert events to FrameRecords
+            records = []
+            for event in events:
+                record = FrameRecord(
+                    uuid=event.event_id,
+                    type="audit_event",
+                    title=f"{event.event_type}: {event.operation or 'Unknown'}",
+                    content=event.to_json(),
+                    metadata={
+                        "event_type": event.event_type,
+                        "principal_id": event.principal_id,
+                        "success": event.success,
+                        "severity": event.severity,
+                        "timestamp": event.timestamp.isoformat(),
+                    },
+                    created_at=event.timestamp,
+                    updated_at=event.timestamp,
+                )
+                records.append(record)
+
+            # Batch insert
+            self._dataset.add_records(records)
+
+    async def _cleanup_old_events(self):
+        """Clean up events older than retention period."""
+        if self.config.retention_days <= 0:
+            return
+
+        cutoff = datetime.now(timezone.utc) - timedelta(days=self.config.retention_days)
+
+        if self.config.storage_backend == "memory":
+            # Filter memory store
+            self._memory_store = deque(
+                (e for e in self._memory_store if e.timestamp > cutoff),
+                maxlen=self.config.max_events_memory
+            )
+
+        elif self.config.storage_backend == "dataset" and self._dataset:
+            # Delete old records from dataset
+            self._dataset.delete_where(f"created_at < '{cutoff.isoformat()}'")
+
+    def _generate_event_id(self) -> str:
+        """Generate unique event ID."""
+        import uuid
+        return str(uuid.uuid4())
+
+    def _redact_sensitive_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Redact sensitive data from event details."""
+        if not self.config.redact_sensitive_data:
+            return data
+
+        sensitive_keys = {
+            "password", "secret", "token", "api_key", "private_key",
+            "client_secret", "access_token", "refresh_token"
+        }
+
+        redacted = {}
+        for key, value in data.items():
+            if any(s in key.lower() for s in sensitive_keys):
+                redacted[key] = "[REDACTED]"
+            elif isinstance(value, dict):
+                redacted[key] = self._redact_sensitive_data(value)
+            else:
+                redacted[key] = value
+
+        return redacted
+
+    async def log_event(
+        self,
+        event_type: AuditEventType,
+        success: bool = True,
+        principal_id: Optional[str] = None,
+        principal_type: Optional[str] = None,
+        principal_name: Optional[str] = None,
+        auth_method: Optional[str] = None,
+        operation: Optional[str] = None,
+        resource_type: Optional[str] = None,
+        resource_id: Optional[str] = None,
+        request_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        client_ip: Optional[str] = None,
+        user_agent: Optional[str] = None,
+        error_code: Optional[int] = None,
+        error_message: Optional[str] = None,
+        details: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Log an audit event.
+
+        Args:
+            event_type: Type of event
+            success: Whether operation succeeded
+            principal_id: ID of principal performing action
+            principal_type: Type of principal (user, agent, service)
+            principal_name: Human-readable name
+            auth_method: Authentication method used
+            operation: Operation being performed
+            resource_type: Type of resource accessed
+            resource_id: ID of resource accessed
+            request_id: Request correlation ID
+            session_id: Session ID
+            client_ip: Client IP address
+            user_agent: User agent string
+            error_code: Error code if failed
+            error_message: Error message if failed
+            details: Additional event details
+        """
+        # Check if event should be logged
+        if not self.config.should_log_event(event_type):
+            return
+
+        # Create event
+        event = AuditEvent(
+            event_id=self._generate_event_id(),
+            timestamp=datetime.now(timezone.utc),
+            event_type=event_type,
+            success=success,
+            principal_id=principal_id,
+            principal_type=principal_type,
+            principal_name=principal_name,
+            auth_method=auth_method,
+            operation=operation,
+            resource_type=resource_type,
+            resource_id=resource_id,
+            request_id=request_id,
+            session_id=session_id,
+            client_ip=client_ip,
+            user_agent=user_agent,
+            error_code=error_code,
+            error_message=error_message,
+            details=self._redact_sensitive_data(details or {}),
+        )
+
+        # Add to buffer
+        self._buffer.append(event)
+
+        # Log to standard logger as well
+        log_msg = (
+            f"Audit: {event.event_type} - "
+            f"Principal: {principal_id or 'anonymous'} - "
+            f"Operation: {operation or 'unknown'} - "
+            f"Success: {success}"
+        )
+
+        if event.severity == "error":
+            self._logger.error(log_msg)
+        elif event.severity == "warning":
+            self._logger.warning(log_msg)
+        else:
+            self._logger.info(log_msg)
+
+        # Increment counter
+        self._event_counter += 1
+
+        # Flush if buffer full (keep a task reference so it is not GC'd mid-flight)
+        if len(self._buffer) >= self.config.buffer_size:
+            self._last_flush_task = asyncio.create_task(self._flush_buffer())
+
+    async def search_events(
+        self,
+        event_types: Optional[List[AuditEventType]] = None,
+        principal_id: Optional[str] = None,
+        resource_type: Optional[str] = None,
+        resource_id: Optional[str] = None,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+        success: Optional[bool] = None,
+        limit: int = 100
+    ) -> List[AuditEvent]:
+        """Search audit events.
+
+        Args:
+            event_types: Filter by event types
+            principal_id: Filter by principal
+            resource_type: Filter by resource type
+            resource_id: Filter by resource ID
+            start_time: Start time filter
+            end_time: End time filter
+            success: Filter by success/failure
+            limit: Maximum events to return
+
+        Returns:
+            List of matching audit events
+        """
+        # For memory backend, search in-memory store
+        if self.config.storage_backend == "memory":
+            results = []
+
+            for event in reversed(self._memory_store):
+                # Apply filters
+                if event_types and event.event_type not in event_types:
+                    continue
+                if principal_id and event.principal_id != principal_id:
+                    continue
+                if resource_type and event.resource_type != resource_type:
+                    continue
+                if resource_id and event.resource_id != resource_id:
+                    continue
+                if start_time and event.timestamp < start_time:
+                    continue
+                if end_time and event.timestamp > end_time:
+                    continue
+                if success is not None and event.success != success:
+                    continue
+
+                results.append(event)
+                if len(results) >= limit:
+                    break
+
+            return results
+
+        # For other backends, would implement appropriate search
+        return []
+
+    def get_statistics(self) -> Dict[str, Any]:
+        """Get audit logging statistics."""
+        stats = {
+            "total_events": self._event_counter,
+            "buffer_size": len(self._buffer),
+            "storage_backend": self.config.storage_backend,
+        }
+
+        if self.config.storage_backend == "memory":
+            stats["memory_events"] = len(self._memory_store)
+
+        return stats
\ No newline at end of file
diff --git a/contextframe/mcp/security/auth.py b/contextframe/mcp/security/auth.py
new file mode 100644
index 0000000..a83fb46
--- /dev/null
+++ b/contextframe/mcp/security/auth.py
@@ -0,0 +1,189 @@
+"""Authentication components for MCP server."""
+
+import hashlib
+import secrets
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from typing import Any, Dict, Optional, Set
+
+from contextframe.mcp.errors import MCPError
+
+
+class AuthenticationError(MCPError):
+    """Authentication failed error."""
+
+    def __init__(self, message: str = "Authentication failed"):
+        super().__init__(code=-32001, message=message)
+
+
+@dataclass
+class SecurityContext:
+    """Security context for authenticated requests."""
+
+    # Authentication info
+    authenticated: bool = False
+    auth_method: Optional[str] = None  # "api_key", "oauth", "jwt"
+
+    # Principal identity
+    principal_id: Optional[str] = None  # User/agent ID
+    principal_type: Optional[str] = None  # "user", "agent", "service"
+    principal_name: Optional[str] = None
+
+    # Permissions and roles (None defaults are replaced in __post_init__)
+    permissions: Optional[Set[str]] = None
+    roles: Optional[Set[str]] = None
+
+    # Session info
+    session_id: Optional[str] = None
+    expires_at: Optional[datetime] = None
+
+    # Additional claims/attributes
+    attributes: Optional[Dict[str, Any]] = None
+
+    def __post_init__(self):
+        if self.permissions is None:
+            self.permissions = set()
+        if self.roles is None:
+            self.roles = set()
+        if self.attributes is None:
+            self.attributes = {}
+
+    def has_permission(self, permission: str) -> bool:
+        """Check if context has a specific permission."""
+        return permission in self.permissions
+
+    def has_role(self, role: str) -> bool:
+        """Check if context has a specific role."""
+        return role in self.roles
+
+    def is_expired(self) -> bool:
+        """Check if the security context has expired."""
+        if not self.expires_at:
+            return False
+        return datetime.now(timezone.utc) > self.expires_at
+
+
+class AuthProvider(ABC):
+    """Base class for authentication providers."""
+
+    @abstractmethod
+    async def authenticate(self, credentials: Dict[str, Any]) -> SecurityContext:
+        """Authenticate using the provided credentials.
+
+        Args:
+            credentials: Provider-specific credentials
+
+        Returns:
+            SecurityContext if authentication successful
+
+        Raises:
+            AuthenticationError: If authentication fails
+        """
+        pass
+
+    @abstractmethod
+    def get_auth_method(self) -> str:
+        """Get the authentication method name."""
+        pass
+
+
+class APIKeyAuth(AuthProvider):
+    """API key authentication provider."""
+
+    def __init__(self, api_keys: Optional[Dict[str, Dict[str, Any]]] = None):
+        """Initialize API key auth provider.
+
+        Args:
+            api_keys: Mapping of API key -> metadata
+                Metadata should include:
+                - principal_id: ID of the principal
+                - principal_name: Name of the principal
+                - permissions: Set of permissions
+                - roles: Set of roles
+                - expires_at: Optional expiration datetime
+        """
+        self.api_keys = api_keys or {}
+        # Hash API keys for secure storage
+        self._hashed_keys = {
+            self._hash_key(key): metadata
+            for key, metadata in self.api_keys.items()
+        }
+
+    def _hash_key(self, api_key: str) -> str:
+        """Hash an API key for secure comparison."""
+        return hashlib.sha256(api_key.encode()).hexdigest()
+
+    async def authenticate(self, credentials: Dict[str, Any]) -> SecurityContext:
+        """Authenticate using API key."""
+        api_key = credentials.get("api_key")
+        if not api_key:
+            raise AuthenticationError("Missing API key")
+
+        # Hash the provided key
+        hashed_key = self._hash_key(api_key)
+
+        # Look up key metadata
+        metadata = self._hashed_keys.get(hashed_key)
+        if not metadata:
+            raise AuthenticationError("Invalid API key")
+
+        # Check expiration
+        expires_at = metadata.get("expires_at")
+        if expires_at and datetime.now(timezone.utc) > expires_at:
+            raise AuthenticationError("API key expired")
+
+        # Build security context
+        return SecurityContext(
+            authenticated=True,
+            auth_method="api_key",
+            principal_id=metadata.get("principal_id"),
+            principal_type=metadata.get("principal_type", "agent"),
+            principal_name=metadata.get("principal_name"),
+            permissions=set(metadata.get("permissions", [])),
+            roles=set(metadata.get("roles", [])),
+            expires_at=expires_at,
+            attributes=metadata.get("attributes", {})
+        )
+
+    def get_auth_method(self) -> str:
+        """Get authentication method name."""
+        return "api_key"
+
+    @staticmethod
+    def generate_api_key() -> str:
+        """Generate a secure API key."""
+        return secrets.token_urlsafe(32)
+
+
+class MultiAuthProvider(AuthProvider):
+    """Combines multiple authentication providers."""
+
+    def __init__(self, providers: list[AuthProvider]):
+        """Initialize with multiple providers.
+
+        Args:
+            providers: List of auth providers to try in order
+        """
+        self.providers = providers
+
+    async def authenticate(self, credentials: Dict[str, Any]) -> SecurityContext:
+        """Try each provider in order until one succeeds."""
+        errors = []
+
+        for provider in self.providers:
+            try:
+                return await provider.authenticate(credentials)
+            except AuthenticationError as e:
+                errors.append(f"{provider.get_auth_method()}: {str(e)}")
+                continue
+
+        # All providers failed
+        raise AuthenticationError(
+            f"All authentication methods failed: {'; '.join(errors)}"
+        )
+
+    def get_auth_method(self) -> str:
+        """Get authentication method name."""
+        methods = [p.get_auth_method() for p in self.providers]
+        return f"multi[{','.join(methods)}]"
\ No newline at end of file
diff --git a/contextframe/mcp/security/authorization.py b/contextframe/mcp/security/authorization.py
new file mode 100644
index 0000000..8a5f30e
--- /dev/null
+++ b/contextframe/mcp/security/authorization.py
@@ -0,0 +1,381 @@
+"""Authorization and access control for MCP server."""
+
+import fnmatch
+import re
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Set
+
+from contextframe.mcp.errors import MCPError
+
+from .auth import SecurityContext
+
+
+class AuthorizationError(MCPError):
+    """Authorization failed error."""
+
+    def __init__(self, message: str = "Authorization failed"):
+        super().__init__(code=-32002, message=message)
+
+class Permission(str, Enum): + """Standard MCP permissions.""" + + # Document permissions + DOCUMENTS_READ = "documents.read" + DOCUMENTS_WRITE = "documents.write" + DOCUMENTS_DELETE = "documents.delete" + DOCUMENTS_ADMIN = "documents.admin" + + # Collection permissions + COLLECTIONS_READ = "collections.read" + COLLECTIONS_WRITE = "collections.write" + COLLECTIONS_DELETE = "collections.delete" + COLLECTIONS_ADMIN = "collections.admin" + + # Tool permissions + TOOLS_EXECUTE = "tools.execute" + TOOLS_ADMIN = "tools.admin" + + # System permissions + SYSTEM_READ = "system.read" + SYSTEM_ADMIN = "system.admin" + + # Monitoring permissions + MONITORING_READ = "monitoring.read" + MONITORING_EXPORT = "monitoring.export" + MONITORING_ADMIN = "monitoring.admin" + + # Special permissions + ALL = "*" # Superuser permission + + +@dataclass +class Role: + """Role definition with permissions.""" + + name: str + description: str + permissions: Set[str] = field(default_factory=set) + + def has_permission(self, permission: str) -> bool: + """Check if role has a specific permission.""" + # Check for superuser + if Permission.ALL in self.permissions: + return True + + # Check exact match + if permission in self.permissions: + return True + + # Check wildcard patterns + for perm in self.permissions: + if fnmatch.fnmatch(permission, perm): + return True + + return False + + +# Standard roles +STANDARD_ROLES = { + "viewer": Role( + name="viewer", + description="Read-only access to documents and collections", + permissions={ + Permission.DOCUMENTS_READ, + Permission.COLLECTIONS_READ, + } + ), + + "editor": Role( + name="editor", + description="Read and write access to documents and collections", + permissions={ + Permission.DOCUMENTS_READ, + Permission.DOCUMENTS_WRITE, + Permission.COLLECTIONS_READ, + Permission.COLLECTIONS_WRITE, + Permission.TOOLS_EXECUTE, + } + ), + + "admin": Role( + name="admin", + description="Full access to all resources", + permissions={ + "documents.*", + 
"collections.*", + "tools.*", + "system.*", + "monitoring.*", + } + ), + + "monitor": Role( + name="monitor", + description="Access to monitoring and metrics", + permissions={ + Permission.MONITORING_READ, + Permission.MONITORING_EXPORT, + Permission.SYSTEM_READ, + } + ), + + "service": Role( + name="service", + description="Service account with limited permissions", + permissions={ + Permission.DOCUMENTS_READ, + Permission.COLLECTIONS_READ, + Permission.TOOLS_EXECUTE, + } + ), +} + + +@dataclass +class ResourcePolicy: + """Policy for resource-level access control.""" + + resource_type: str # "document", "collection", "tool", etc. + resource_id: Optional[str] = None # Specific resource ID or pattern + permissions: Set[str] = field(default_factory=set) + conditions: Dict[str, Any] = field(default_factory=dict) + + def matches_resource(self, resource_type: str, resource_id: str) -> bool: + """Check if policy applies to a resource.""" + # Check resource type + if self.resource_type != resource_type and self.resource_type != "*": + return False + + # Check resource ID + if self.resource_id: + if "*" in self.resource_id or "?" 
in self.resource_id: + # Wildcard pattern + return fnmatch.fnmatch(resource_id, self.resource_id) + else: + # Exact match + return resource_id == self.resource_id + + return True + + def evaluate_conditions(self, context: Dict[str, Any]) -> bool: + """Evaluate policy conditions.""" + for key, expected in self.conditions.items(): + actual = context.get(key) + + # Handle different condition types + if isinstance(expected, dict): + # Complex condition (e.g., {"$in": ["value1", "value2"]}) + operator = list(expected.keys())[0] + value = expected[operator] + + if operator == "$in" and actual not in value: + return False + elif operator == "$eq" and actual != value: + return False + elif operator == "$ne" and actual == value: + return False + elif operator == "$regex": + if not re.match(value, str(actual)): + return False + else: + # Simple equality + if actual != expected: + return False + + return True + + +class AccessControl: + """Access control manager for authorization decisions.""" + + def __init__( + self, + roles: Optional[Dict[str, Role]] = None, + policies: Optional[List[ResourcePolicy]] = None, + default_allow: bool = False + ): + """Initialize access control. + + Args: + roles: Role definitions (defaults to STANDARD_ROLES) + policies: Resource-level policies + default_allow: Default authorization decision + """ + self.roles = roles or STANDARD_ROLES.copy() + self.policies = policies or [] + self.default_allow = default_allow + + def authorize( + self, + context: SecurityContext, + permission: str, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + request_context: Optional[Dict[str, Any]] = None + ) -> bool: + """Make authorization decision. 
+ + Args: + context: Security context from authentication + permission: Required permission + resource_type: Type of resource being accessed + resource_id: ID of specific resource + request_context: Additional context for conditions + + Returns: + True if authorized, False otherwise + """ + # Unauthenticated users have no permissions + if not context.authenticated: + return False + + # Check direct permissions + if context.has_permission(permission) or context.has_permission(Permission.ALL): + return True + + # Check role-based permissions + for role_name in context.roles: + role = self.roles.get(role_name) + if role and role.has_permission(permission): + return True + + # Check resource-level policies if resource specified + if resource_type and resource_id: + for policy in self.policies: + if not policy.matches_resource(resource_type, resource_id): + continue + + # Check if policy grants the permission + if permission not in policy.permissions and Permission.ALL not in policy.permissions: + continue + + # Evaluate conditions + eval_context = { + "principal_id": context.principal_id, + "principal_type": context.principal_type, + "auth_method": context.auth_method, + **(request_context or {}), + **(context.attributes or {}) + } + + if policy.evaluate_conditions(eval_context): + return True + + # Default decision + return self.default_allow + + def require_permission( + self, + context: SecurityContext, + permission: str, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + request_context: Optional[Dict[str, Any]] = None + ) -> None: + """Require a permission, raising error if not authorized. 
+ + Args: + context: Security context from authentication + permission: Required permission + resource_type: Type of resource being accessed + resource_id: ID of specific resource + request_context: Additional context for conditions + + Raises: + AuthorizationError: If not authorized + """ + if not self.authorize( + context, permission, resource_type, resource_id, request_context + ): + resource_info = "" + if resource_type and resource_id: + resource_info = f" for {resource_type}/{resource_id}" + + raise AuthorizationError( + f"Permission '{permission}' required{resource_info}" + ) + + def filter_permitted_resources( + self, + context: SecurityContext, + permission: str, + resources: List[Dict[str, Any]], + resource_type: str, + id_field: str = "id" + ) -> List[Dict[str, Any]]: + """Filter list of resources to only those permitted. + + Args: + context: Security context + permission: Required permission + resources: List of resources to filter + resource_type: Type of resources + id_field: Field containing resource ID + + Returns: + Filtered list of permitted resources + """ + permitted = [] + + for resource in resources: + resource_id = resource.get(id_field) + if resource_id and self.authorize( + context, permission, resource_type, resource_id + ): + permitted.append(resource) + + return permitted + + def add_role(self, role: Role) -> None: + """Add or update a role definition.""" + self.roles[role.name] = role + + def add_policy(self, policy: ResourcePolicy) -> None: + """Add a resource policy.""" + self.policies.append(policy) + + def get_effective_permissions( + self, + context: SecurityContext, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None + ) -> Set[str]: + """Get all effective permissions for a context. 
+ + Args: + context: Security context + resource_type: Optional resource type filter + resource_id: Optional resource ID filter + + Returns: + Set of effective permissions + """ + permissions = set() + + # Add direct permissions + permissions.update(context.permissions) + + # Add role permissions + for role_name in context.roles: + role = self.roles.get(role_name) + if role: + permissions.update(role.permissions) + + # Add policy permissions if resource specified + if resource_type and resource_id: + for policy in self.policies: + if policy.matches_resource(resource_type, resource_id): + # Check conditions + eval_context = { + "principal_id": context.principal_id, + "principal_type": context.principal_type, + "auth_method": context.auth_method, + **(context.attributes or {}) + } + + if policy.evaluate_conditions(eval_context): + permissions.update(policy.permissions) + + return permissions \ No newline at end of file diff --git a/contextframe/mcp/security/integration.py b/contextframe/mcp/security/integration.py new file mode 100644 index 0000000..c7b1682 --- /dev/null +++ b/contextframe/mcp/security/integration.py @@ -0,0 +1,418 @@ +"""Integration of security components with MCP server.""" + +import uuid +from typing import Any, Dict, Optional + +from contextframe.mcp.handlers import MessageHandler as BaseMessageHandler +from contextframe.mcp.errors import MCPError + +from .audit import AuditEventType, AuditLogger +from .auth import AuthenticationError, SecurityContext, AuthProvider +from .authorization import AuthorizationError, AccessControl, Permission +from .rate_limiting import RateLimitExceeded, RateLimiter + + +class SecurityMiddleware: + """Security middleware for MCP server. + + Provides authentication, authorization, rate limiting, and audit logging + for all MCP operations. 
+ """ + + def __init__( + self, + auth_provider: Optional[AuthProvider] = None, + access_control: Optional[AccessControl] = None, + rate_limiter: Optional[RateLimiter] = None, + audit_logger: Optional[AuditLogger] = None, + anonymous_allowed: bool = False, + anonymous_permissions: Optional[set] = None + ): + """Initialize security middleware. + + Args: + auth_provider: Authentication provider + access_control: Access control manager + rate_limiter: Rate limiter + audit_logger: Audit logger + anonymous_allowed: Allow anonymous access + anonymous_permissions: Permissions for anonymous users + """ + self.auth_provider = auth_provider + self.access_control = access_control + self.rate_limiter = rate_limiter + self.audit_logger = audit_logger + self.anonymous_allowed = anonymous_allowed + self.anonymous_permissions = anonymous_permissions or set() + + async def start(self): + """Start security components.""" + if self.rate_limiter: + await self.rate_limiter.start() + if self.audit_logger: + await self.audit_logger.start() + + async def stop(self): + """Stop security components.""" + if self.rate_limiter: + await self.rate_limiter.stop() + if self.audit_logger: + await self.audit_logger.stop() + + def _get_request_metadata(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Extract request metadata from message.""" + return { + "request_id": message.get("id"), + "method": message.get("method"), + "params": message.get("params", {}), + "client_ip": message.get("_client_ip"), + "user_agent": message.get("_user_agent"), + } + + async def authenticate(self, message: Dict[str, Any]) -> SecurityContext: + """Authenticate the request. 
+ + Args: + message: JSON-RPC message + + Returns: + Security context + + Raises: + AuthenticationError: If authentication fails + """ + # Extract credentials from message + credentials = {} + + # Check for API key in params + params = message.get("params", {}) + if isinstance(params, dict): + if "api_key" in params: + credentials["api_key"] = params["api_key"] + elif "_api_key" in params: + credentials["api_key"] = params["_api_key"] + + # Check for bearer token in metadata + if "_authorization" in message: + auth_header = message["_authorization"] + if auth_header.startswith("Bearer "): + token = auth_header[7:] + credentials["token"] = token + + # Check for OAuth code + if "code" in params: + credentials["code"] = params["code"] + credentials["code_verifier"] = params.get("code_verifier") + credentials["redirect_uri"] = params.get("redirect_uri") + + # Try authentication + if credentials and self.auth_provider: + try: + context = await self.auth_provider.authenticate(credentials) + + # Log successful authentication + if self.audit_logger: + metadata = self._get_request_metadata(message) + await self.audit_logger.log_event( + event_type=AuditEventType.AUTH_SUCCESS, + success=True, + principal_id=context.principal_id, + principal_type=context.principal_type, + principal_name=context.principal_name, + auth_method=context.auth_method, + operation=metadata["method"], + request_id=metadata["request_id"], + client_ip=metadata["client_ip"], + user_agent=metadata["user_agent"], + ) + + return context + + except AuthenticationError as e: + # Log failed authentication + if self.audit_logger: + metadata = self._get_request_metadata(message) + await self.audit_logger.log_event( + event_type=AuditEventType.AUTH_FAILURE, + success=False, + operation=metadata["method"], + request_id=metadata["request_id"], + client_ip=metadata["client_ip"], + user_agent=metadata["user_agent"], + error_code=e.code, + error_message=str(e), + details={"credentials_type": list(credentials.keys())} 
+ ) + raise + + # Check if anonymous access is allowed + if self.anonymous_allowed: + return SecurityContext( + authenticated=False, + auth_method="anonymous", + principal_id="anonymous", + principal_type="anonymous", + permissions=self.anonymous_permissions.copy(), + ) + + # No credentials and anonymous not allowed + raise AuthenticationError("Authentication required") + + async def authorize( + self, + context: SecurityContext, + operation: str, + params: Dict[str, Any] + ) -> None: + """Authorize the operation. + + Args: + context: Security context from authentication + operation: Operation being performed + params: Operation parameters + + Raises: + AuthorizationError: If not authorized + """ + if not self.access_control: + # No access control configured, allow all + return + + # Map operations to permissions + permission_map = { + # Document operations + "get_document": Permission.DOCUMENTS_READ, + "search_documents": Permission.DOCUMENTS_READ, + "add_document": Permission.DOCUMENTS_WRITE, + "update_document": Permission.DOCUMENTS_WRITE, + "delete_document": Permission.DOCUMENTS_DELETE, + + # Collection operations + "get_collection": Permission.COLLECTIONS_READ, + "list_collections": Permission.COLLECTIONS_READ, + "create_collection": Permission.COLLECTIONS_WRITE, + "update_collection": Permission.COLLECTIONS_WRITE, + "delete_collection": Permission.COLLECTIONS_DELETE, + + # Tool operations + "tools/list": Permission.TOOLS_EXECUTE, + "tools/call": Permission.TOOLS_EXECUTE, + + # System operations + "resources/list": Permission.SYSTEM_READ, + "resources/read": Permission.SYSTEM_READ, + + # Monitoring operations + "get_usage_metrics": Permission.MONITORING_READ, + "get_performance_metrics": Permission.MONITORING_READ, + "get_cost_report": Permission.MONITORING_READ, + "export_metrics": Permission.MONITORING_EXPORT, + } + + # Get required permission + permission = permission_map.get(operation) + if not permission: + # Unknown operation, check for wildcards + if 
operation.startswith("monitoring/"): + permission = Permission.MONITORING_READ + elif operation.startswith("tools/"): + permission = Permission.TOOLS_EXECUTE + else: + # Default to system read + permission = Permission.SYSTEM_READ + + # Extract resource info from params + resource_type = None + resource_id = None + + if "document_id" in params: + resource_type = "document" + resource_id = params["document_id"] + elif "collection_id" in params: + resource_type = "collection" + resource_id = params["collection_id"] + elif "name" in params and operation == "tools/call": + resource_type = "tool" + resource_id = params["name"] + elif "uri" in params and operation == "resources/read": + resource_type = "resource" + resource_id = params["uri"] + + try: + # Check authorization + self.access_control.require_permission( + context, + permission, + resource_type, + resource_id, + params + ) + + # Log successful authorization + if self.audit_logger: + await self.audit_logger.log_event( + event_type=AuditEventType.AUTHZ_GRANTED, + success=True, + principal_id=context.principal_id, + principal_type=context.principal_type, + auth_method=context.auth_method, + operation=operation, + resource_type=resource_type, + resource_id=resource_id, + details={"permission": permission} + ) + + except AuthorizationError as e: + # Log authorization denial + if self.audit_logger: + await self.audit_logger.log_event( + event_type=AuditEventType.AUTHZ_DENIED, + success=False, + principal_id=context.principal_id, + principal_type=context.principal_type, + auth_method=context.auth_method, + operation=operation, + resource_type=resource_type, + resource_id=resource_id, + error_code=e.code, + error_message=str(e), + details={"permission": permission} + ) + raise + + async def check_rate_limit( + self, + context: SecurityContext, + operation: str + ) -> None: + """Check rate limits. 
+ + Args: + context: Security context + operation: Operation being performed + + Raises: + RateLimitExceeded: If rate limit exceeded + """ + if not self.rate_limiter: + return + + try: + await self.rate_limiter.check_rate_limit( + client_id=context.principal_id, + operation=operation + ) + except RateLimitExceeded as e: + # Log rate limit exceeded + if self.audit_logger: + await self.audit_logger.log_event( + event_type=AuditEventType.RATE_LIMIT_EXCEEDED, + success=False, + principal_id=context.principal_id, + principal_type=context.principal_type, + operation=operation, + error_code=e.code, + error_message=str(e), + details={"retry_after": e.retry_after} + ) + raise + + +class SecuredMessageHandler(BaseMessageHandler): + """Message handler with integrated security.""" + + def __init__( + self, + server: Any, + security: SecurityMiddleware + ): + super().__init__(server) + self.security = security + + # Override method handlers to add security + self._secured_handlers = {} + for method, handler in self._method_handlers.items(): + self._secured_handlers[method] = self._wrap_handler(handler, method) + self._method_handlers = self._secured_handlers + + def _wrap_handler(self, handler, method: str): + """Wrap handler with security checks.""" + async def secured_handler(params: Dict[str, Any]) -> Any: + # Get current message context + message = getattr(self, "_current_message", {}) + + # Authenticate + context = await self.security.authenticate(message) + + # Store context for use in handler + self._current_context = context + + # Check rate limit + await self.security.check_rate_limit(context, method) + + # Authorize + await self.security.authorize(context, method, params) + + # Call original handler + try: + result = await handler(params) + + # Log successful operation + if self.security.audit_logger: + resource_type = None + resource_id = None + + if method.startswith("tools/"): + event_type = AuditEventType.TOOL_EXECUTED + resource_type = "tool" + resource_id = 
params.get("name") + elif method.startswith("resources/"): + event_type = AuditEventType.RESOURCE_READ + resource_type = "resource" + resource_id = params.get("uri") + else: + event_type = AuditEventType.RESOURCE_READ + + await self.security.audit_logger.log_event( + event_type=event_type, + success=True, + principal_id=context.principal_id, + principal_type=context.principal_type, + auth_method=context.auth_method, + operation=method, + resource_type=resource_type, + resource_id=resource_id, + request_id=message.get("id"), + session_id=context.session_id, + ) + + return result + + except Exception as e: + # Log operation failure + if self.security.audit_logger: + await self.security.audit_logger.log_event( + event_type=AuditEventType.TOOL_FAILED if method.startswith("tools/") else AuditEventType.SYSTEM_ERROR, + success=False, + principal_id=context.principal_id, + principal_type=context.principal_type, + auth_method=context.auth_method, + operation=method, + request_id=message.get("id"), + error_message=str(e), + ) + raise + + return secured_handler + + async def handle(self, message: dict[str, Any]) -> dict[str, Any]: + """Handle message with security context.""" + # Store message for security checks + self._current_message = message + + try: + return await super().handle(message) + finally: + # Clean up + self._current_message = None + self._current_context = None \ No newline at end of file diff --git a/contextframe/mcp/security/jwt.py b/contextframe/mcp/security/jwt.py new file mode 100644 index 0000000..1191aac --- /dev/null +++ b/contextframe/mcp/security/jwt.py @@ -0,0 +1,280 @@ +"""JWT authentication and token handling.""" + +import json +import time +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Set + +import jwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import 
default_backend + +from .auth import AuthProvider, AuthenticationError, SecurityContext + + +@dataclass +class JWTConfig: + """JWT configuration.""" + + # Signing configuration + algorithm: str = "RS256" # RS256, HS256, etc. + secret_key: Optional[str] = None # For HMAC algorithms + private_key: Optional[str] = None # For RSA/ECDSA + public_key: Optional[str] = None # For RSA/ECDSA + + # Token settings + issuer: str = "contextframe-mcp" + audience: Optional[str] = None + token_lifetime: int = 3600 # seconds + refresh_token_lifetime: int = 86400 * 7 # 7 days + + # Validation settings + verify_signature: bool = True + verify_exp: bool = True + verify_nbf: bool = True + verify_iat: bool = True + verify_aud: bool = True + verify_iss: bool = True + require_exp: bool = True + require_nbf: bool = False + require_iat: bool = True + + # Claims mapping + principal_id_claim: str = "sub" + principal_name_claim: str = "name" + principal_type_claim: str = "type" + permissions_claim: str = "permissions" + roles_claim: str = "roles" + + +class JWTHandler(AuthProvider): + """JWT token handler for authentication.""" + + def __init__(self, config: JWTConfig): + self.config = config + + # Initialize keys if not provided + if self.config.algorithm.startswith("RS") and not self.config.private_key: + # Generate RSA key pair for testing + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + self.config.private_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode() + self.config.public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ).decode() + + async def authenticate(self, credentials: Dict[str, Any]) -> SecurityContext: + """Authenticate using JWT token.""" + token = credentials.get("token") or 
credentials.get("jwt") + if not token: + raise AuthenticationError("Missing JWT token") + + try: + # Decode and verify token + payload = self._verify_token(token) + + # Build security context from claims + return self._build_security_context(payload) + + except jwt.ExpiredSignatureError: + raise AuthenticationError("JWT token expired") + except jwt.InvalidTokenError as e: + raise AuthenticationError(f"Invalid JWT token: {str(e)}") + + def _verify_token(self, token: str) -> Dict[str, Any]: + """Verify and decode JWT token.""" + # Determine verification key + if self.config.algorithm.startswith("HS"): + # HMAC algorithms use secret key + key = self.config.secret_key + else: + # RSA/ECDSA algorithms use public key + key = self.config.public_key + + if not key: + raise AuthenticationError("Missing verification key") + + # Build verification options + options = { + "verify_signature": self.config.verify_signature, + "verify_exp": self.config.verify_exp, + "verify_nbf": self.config.verify_nbf, + "verify_iat": self.config.verify_iat, + "verify_aud": self.config.verify_aud if self.config.audience else False, + "verify_iss": self.config.verify_iss, + "require_exp": self.config.require_exp, + "require_nbf": self.config.require_nbf, + "require_iat": self.config.require_iat, + } + + # Decode token + payload = jwt.decode( + token, + key, + algorithms=[self.config.algorithm], + options=options, + audience=self.config.audience, + issuer=self.config.issuer + ) + + return payload + + def _build_security_context(self, payload: Dict[str, Any]) -> SecurityContext: + """Build security context from JWT claims.""" + # Extract principal information + principal_id = payload.get(self.config.principal_id_claim) + principal_name = payload.get(self.config.principal_name_claim) + principal_type = payload.get(self.config.principal_type_claim, "user") + + # Extract permissions and roles + permissions = set(payload.get(self.config.permissions_claim, [])) + roles = 
set(payload.get(self.config.roles_claim, [])) + + # Calculate expiration + expires_at = None + if "exp" in payload: + expires_at = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + + # Extract additional claims + standard_claims = { + "sub", "name", "type", "permissions", "roles", + "exp", "nbf", "iat", "iss", "aud", "jti" + } + attributes = { + k: v for k, v in payload.items() + if k not in standard_claims + } + + return SecurityContext( + authenticated=True, + auth_method="jwt", + principal_id=principal_id, + principal_type=principal_type, + principal_name=principal_name, + permissions=permissions, + roles=roles, + session_id=payload.get("jti"), # JWT ID as session ID + expires_at=expires_at, + attributes=attributes + ) + + def get_auth_method(self) -> str: + """Get authentication method name.""" + return "jwt" + + def create_token( + self, + principal_id: str, + principal_name: Optional[str] = None, + principal_type: str = "user", + permissions: Optional[Set[str]] = None, + roles: Optional[Set[str]] = None, + additional_claims: Optional[Dict[str, Any]] = None, + lifetime: Optional[int] = None + ) -> str: + """Create a new JWT token. 
+ + Args: + principal_id: Subject/principal ID + principal_name: Human-readable name + principal_type: Type of principal (user, agent, service) + permissions: Set of permissions + roles: Set of roles + additional_claims: Extra claims to include + lifetime: Token lifetime in seconds + + Returns: + Signed JWT token + """ + now = datetime.now(timezone.utc) + lifetime = lifetime or self.config.token_lifetime + + # Build payload + payload = { + # Standard claims + "iss": self.config.issuer, + "sub": principal_id, + "iat": int(now.timestamp()), + "exp": int((now + timedelta(seconds=lifetime)).timestamp()), + + # Custom claims + self.config.principal_name_claim: principal_name, + self.config.principal_type_claim: principal_type, + self.config.permissions_claim: list(permissions or []), + self.config.roles_claim: list(roles or []), + } + + # Add audience if configured + if self.config.audience: + payload["aud"] = self.config.audience + + # Add additional claims + if additional_claims: + payload.update(additional_claims) + + # Determine signing key + if self.config.algorithm.startswith("HS"): + key = self.config.secret_key + else: + key = self.config.private_key + + if not key: + raise ValueError("Missing signing key") + + # Create token + token = jwt.encode( + payload, + key, + algorithm=self.config.algorithm + ) + + return token + + def create_refresh_token( + self, + principal_id: str, + token_id: str, + lifetime: Optional[int] = None + ) -> str: + """Create a refresh token. 
+ + Args: + principal_id: Subject/principal ID + token_id: ID of the access token this refreshes + lifetime: Token lifetime in seconds + + Returns: + Signed refresh token + """ + lifetime = lifetime or self.config.refresh_token_lifetime + + payload = { + "iss": self.config.issuer, + "sub": principal_id, + "iat": int(time.time()), + "exp": int(time.time() + lifetime), + "token_type": "refresh", + "token_id": token_id, + } + + # Determine signing key + if self.config.algorithm.startswith("HS"): + key = self.config.secret_key + else: + key = self.config.private_key + + return jwt.encode(payload, key, algorithm=self.config.algorithm) + + def decode_token_unsafe(self, token: str) -> Dict[str, Any]: + """Decode token without verification (for debugging).""" + return jwt.decode(token, options={"verify_signature": False}) \ No newline at end of file diff --git a/contextframe/mcp/security/oauth.py b/contextframe/mcp/security/oauth.py new file mode 100644 index 0000000..4b11d3e --- /dev/null +++ b/contextframe/mcp/security/oauth.py @@ -0,0 +1,399 @@ +"""OAuth 2.1 authentication provider.""" + +import base64 +import json +import secrets +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional, Set +from urllib.parse import urlencode, urlparse, parse_qs + +import httpx + +from .auth import AuthProvider, AuthenticationError, SecurityContext + + +@dataclass +class OAuth2Config: + """OAuth 2.1 configuration.""" + + # OAuth endpoints + authorization_endpoint: str + token_endpoint: str + client_id: str # Required field moved before optional fields + + # Optional endpoints + userinfo_endpoint: Optional[str] = None + jwks_uri: Optional[str] = None + + # Client credentials + client_secret: Optional[str] = None + redirect_uri: str = "urn:ietf:wg:oauth:2.0:oob" + + # Scopes and claims + scopes: list[str] = None + required_claims: list[str] = None + + # Token settings + access_token_lifetime: int = 3600 # seconds + 
refresh_token_lifetime: int = 86400 * 30 # 30 days + + # Security settings + require_pkce: bool = True + require_state: bool = True + allowed_redirect_uris: list[str] = None + + def __post_init__(self): + if self.scopes is None: + self.scopes = ["openid", "profile", "email"] + if self.required_claims is None: + self.required_claims = ["sub"] + if self.allowed_redirect_uris is None: + self.allowed_redirect_uris = [self.redirect_uri] + + +@dataclass +class OAuth2Token: + """OAuth 2.1 token response.""" + + access_token: str + token_type: str = "Bearer" + expires_in: Optional[int] = None + refresh_token: Optional[str] = None + scope: Optional[str] = None + id_token: Optional[str] = None + + # Computed fields + issued_at: datetime = None + expires_at: Optional[datetime] = None + + def __post_init__(self): + if self.issued_at is None: + self.issued_at = datetime.now(timezone.utc) + if self.expires_in and not self.expires_at: + self.expires_at = self.issued_at + timedelta(seconds=self.expires_in) + + def is_expired(self) -> bool: + """Check if token is expired.""" + if not self.expires_at: + return False + return datetime.now(timezone.utc) > self.expires_at + + +class OAuth2Provider(AuthProvider): + """OAuth 2.1 authentication provider. + + Implements OAuth 2.1 with: + - Authorization Code flow with PKCE + - Client Credentials flow + - Token introspection + - JWT validation (if JWKS provided) + """ + + def __init__(self, config: OAuth2Config): + self.config = config + self._http_client = httpx.AsyncClient() + + # Cache for authorization codes + self._auth_codes: Dict[str, Dict[str, Any]] = {} + + # Cache for access tokens (for introspection) + self._access_tokens: Dict[str, OAuth2Token] = {} + + async def authenticate(self, credentials: Dict[str, Any]) -> SecurityContext: + """Authenticate using OAuth 2.1. 
+ + Supports: + - Authorization code exchange + - Client credentials + - Access token validation + """ + # Check for authorization code + if "code" in credentials: + return await self._handle_authorization_code(credentials) + + # Check for access token + elif "access_token" in credentials: + return await self._validate_access_token(credentials["access_token"]) + + # Check for client credentials + elif "client_id" in credentials and "client_secret" in credentials: + return await self._handle_client_credentials(credentials) + + else: + raise AuthenticationError("Missing OAuth credentials") + + async def _handle_authorization_code(self, credentials: Dict[str, Any]) -> SecurityContext: + """Exchange authorization code for tokens.""" + code = credentials.get("code") + code_verifier = credentials.get("code_verifier") + redirect_uri = credentials.get("redirect_uri", self.config.redirect_uri) + + if not code: + raise AuthenticationError("Missing authorization code") + + # Validate redirect URI + if redirect_uri not in self.config.allowed_redirect_uris: + raise AuthenticationError("Invalid redirect URI") + + # Build token request + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": self.config.client_id, + } + + # Add PKCE verifier if required + if self.config.require_pkce: + if not code_verifier: + raise AuthenticationError("PKCE code verifier required") + token_data["code_verifier"] = code_verifier + + # Add client secret if available + if self.config.client_secret: + token_data["client_secret"] = self.config.client_secret + + try: + # Exchange code for token + response = await self._http_client.post( + self.config.token_endpoint, + data=token_data, + headers={"Accept": "application/json"} + ) + response.raise_for_status() + + token_response = response.json() + token = OAuth2Token(**token_response) + + # Store token for later validation + self._access_tokens[token.access_token] = token + + # Get user info if 
available + userinfo = None + if self.config.userinfo_endpoint and token.access_token: + userinfo = await self._get_userinfo(token.access_token) + + # Build security context + return self._build_security_context(token, userinfo) + + except httpx.HTTPError as e: + raise AuthenticationError(f"Token exchange failed: {str(e)}") + + async def _handle_client_credentials(self, credentials: Dict[str, Any]) -> SecurityContext: + """Authenticate using client credentials.""" + client_id = credentials.get("client_id") + client_secret = credentials.get("client_secret") + scope = credentials.get("scope", " ".join(self.config.scopes)) + + if not client_id or not client_secret: + raise AuthenticationError("Missing client credentials") + + # Verify client credentials match config + if client_id != self.config.client_id: + raise AuthenticationError("Invalid client ID") + if client_secret != self.config.client_secret: + raise AuthenticationError("Invalid client secret") + + try: + # Request token + response = await self._http_client.post( + self.config.token_endpoint, + data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + "scope": scope + }, + headers={"Accept": "application/json"} + ) + response.raise_for_status() + + token_response = response.json() + token = OAuth2Token(**token_response) + + # Store token + self._access_tokens[token.access_token] = token + + # Build security context for service account + return SecurityContext( + authenticated=True, + auth_method="oauth2_client", + principal_id=client_id, + principal_type="service", + principal_name=f"Service: {client_id}", + permissions=self._parse_scopes_to_permissions(scope), + roles={"service"}, + expires_at=token.expires_at, + attributes={ + "grant_type": "client_credentials", + "scope": scope + } + ) + + except httpx.HTTPError as e: + raise AuthenticationError(f"Client credentials flow failed: {str(e)}") + + async def _validate_access_token(self, access_token: str) -> 
SecurityContext: + """Validate an access token.""" + # Check local cache first + if access_token in self._access_tokens: + token = self._access_tokens[access_token] + if not token.is_expired(): + # Get fresh user info + userinfo = None + if self.config.userinfo_endpoint: + userinfo = await self._get_userinfo(access_token) + return self._build_security_context(token, userinfo) + + # Token not in cache or expired, validate with server + # This would typically use token introspection endpoint + raise AuthenticationError("Token validation not implemented") + + async def _get_userinfo(self, access_token: str) -> Dict[str, Any]: + """Get user info from userinfo endpoint.""" + try: + response = await self._http_client.get( + self.config.userinfo_endpoint, + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/json" + } + ) + response.raise_for_status() + return response.json() + except httpx.HTTPError: + # Userinfo fetch failed, not critical + return {} + + def _build_security_context( + self, + token: OAuth2Token, + userinfo: Optional[Dict[str, Any]] = None + ) -> SecurityContext: + """Build security context from token and user info.""" + # Extract principal info + principal_id = None + principal_name = None + email = None + + if userinfo: + principal_id = userinfo.get("sub") + principal_name = userinfo.get("name") or userinfo.get("preferred_username") + email = userinfo.get("email") + + # Parse scopes into permissions + permissions = set() + if token.scope: + permissions = self._parse_scopes_to_permissions(token.scope) + + return SecurityContext( + authenticated=True, + auth_method="oauth2", + principal_id=principal_id, + principal_type="user", + principal_name=principal_name, + permissions=permissions, + roles={"user"}, # Default role + expires_at=token.expires_at, + attributes={ + "email": email, + "scope": token.scope, + "token_type": token.token_type + } + ) + + def _parse_scopes_to_permissions(self, scope: str) -> Set[str]: + """Convert 
OAuth scopes to permissions.""" + permissions = set() + scopes = scope.split() if scope else [] + + # Map common scopes to permissions + scope_mapping = { + "read": ["documents.read", "collections.read"], + "write": ["documents.write", "collections.write"], + "admin": ["documents.*", "collections.*", "system.*"], + } + + for s in scopes: + if s in scope_mapping: + permissions.update(scope_mapping[s]) + else: + # Use scope as-is for custom scopes + permissions.add(s) + + return permissions + + def get_auth_method(self) -> str: + """Get authentication method name.""" + return "oauth2" + + def generate_authorization_url( + self, + state: Optional[str] = None, + code_challenge: Optional[str] = None, + scope: Optional[str] = None, + **kwargs + ) -> str: + """Generate OAuth authorization URL. + + Args: + state: CSRF protection state + code_challenge: PKCE challenge + scope: OAuth scopes + **kwargs: Additional OAuth parameters + + Returns: + Authorization URL + """ + params = { + "response_type": "code", + "client_id": self.config.client_id, + "redirect_uri": self.config.redirect_uri, + } + + if state: + params["state"] = state + elif self.config.require_state: + params["state"] = secrets.token_urlsafe(32) + + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + elif self.config.require_pkce: + raise ValueError("PKCE code challenge required") + + if scope: + params["scope"] = scope + else: + params["scope"] = " ".join(self.config.scopes) + + # Add any additional parameters + params.update(kwargs) + + return f"{self.config.authorization_endpoint}?{urlencode(params)}" + + @staticmethod + def generate_pkce_pair() -> tuple[str, str]: + """Generate PKCE code verifier and challenge. 
+ + Returns: + Tuple of (code_verifier, code_challenge) + """ + # Generate code verifier + code_verifier = base64.urlsafe_b64encode( + secrets.token_bytes(32) + ).decode("utf-8").rstrip("=") + + # Generate code challenge (S256) + import hashlib + challenge_bytes = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode( + challenge_bytes + ).decode("utf-8").rstrip("=") + + return code_verifier, code_challenge + + async def close(self): + """Clean up resources.""" + await self._http_client.aclose() \ No newline at end of file diff --git a/contextframe/mcp/security/rate_limiting.py b/contextframe/mcp/security/rate_limiting.py new file mode 100644 index 0000000..bb55817 --- /dev/null +++ b/contextframe/mcp/security/rate_limiting.py @@ -0,0 +1,366 @@ +"""Rate limiting for MCP server.""" + +import asyncio +import time +from collections import defaultdict, deque +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Dict, Optional, Tuple + +from contextframe.mcp.errors import MCPError + + +class RateLimitExceeded(MCPError): + """Rate limit exceeded error.""" + + def __init__( + self, + message: str = "Rate limit exceeded", + retry_after: Optional[int] = None + ): + super().__init__(code=-32003, message=message) + self.retry_after = retry_after + + +@dataclass +class RateLimitConfig: + """Rate limiting configuration.""" + + # Global limits + global_requests_per_minute: int = 600 + global_burst_size: int = 100 + + # Per-client limits + client_requests_per_minute: int = 60 + client_burst_size: int = 10 + + # Per-operation limits + operation_limits: Dict[str, Tuple[int, int]] = field(default_factory=dict) + # Format: {"operation": (requests_per_minute, burst_size)} + + # Advanced settings + cleanup_interval: int = 60 # seconds + use_sliding_window: bool = True + track_by: str = "principal_id" # or "ip_address" + + def __post_init__(self): + # Set default operation limits + if not 
self.operation_limits: + self.operation_limits = { + # Expensive operations have lower limits + "tools/call": (30, 5), + "batch/*": (10, 2), + "export/*": (5, 1), + + # Read operations have higher limits + "resources/read": (120, 20), + "resources/list": (120, 20), + + # Monitoring operations + "monitoring/*": (60, 10), + } + + +class TokenBucket: + """Token bucket rate limiter implementation.""" + + def __init__(self, capacity: int, refill_rate: float): + """Initialize token bucket. + + Args: + capacity: Maximum number of tokens + refill_rate: Tokens added per second + """ + self.capacity = capacity + self.refill_rate = refill_rate + self.tokens = float(capacity) + self.last_refill = time.monotonic() + self._lock = asyncio.Lock() + + async def consume(self, tokens: int = 1) -> Tuple[bool, Optional[float]]: + """Try to consume tokens. + + Args: + tokens: Number of tokens to consume + + Returns: + Tuple of (success, wait_time_seconds) + """ + async with self._lock: + # Refill tokens based on elapsed time + now = time.monotonic() + elapsed = now - self.last_refill + self.tokens = min( + self.capacity, + self.tokens + (elapsed * self.refill_rate) + ) + self.last_refill = now + + # Check if we have enough tokens + if self.tokens >= tokens: + self.tokens -= tokens + return True, None + else: + # Calculate wait time + deficit = tokens - self.tokens + wait_time = deficit / self.refill_rate + return False, wait_time + + async def reset(self): + """Reset bucket to full capacity.""" + async with self._lock: + self.tokens = float(self.capacity) + self.last_refill = time.monotonic() + + +class SlidingWindowCounter: + """Sliding window counter for rate limiting.""" + + def __init__(self, window_size: int, max_requests: int): + """Initialize sliding window counter. 
class SlidingWindowCounter:
    """Sliding window counter for rate limiting."""

    def __init__(self, window_size: int, max_requests: int):
        """Initialize sliding window counter.

        Args:
            window_size: Window size in seconds
            max_requests: Maximum requests in window
        """
        self.window_size = window_size
        self.max_requests = max_requests
        self.requests: deque = deque()
        self._lock = asyncio.Lock()

    async def add_request(self) -> Tuple[bool, Optional[float]]:
        """Add a request and check if within limit.

        Returns:
            Tuple of (allowed, wait_time_seconds); the wait time is ``None``
            when the request is allowed.
        """
        async with self._lock:
            now = time.time()

            # Drop timestamps that have aged out of the window.
            horizon = now - self.window_size
            while self.requests and self.requests[0] < horizon:
                self.requests.popleft()

            if len(self.requests) < self.max_requests:
                self.requests.append(now)
                return True, None

            # Window is full: the caller may retry once the oldest
            # timestamp slides out of the window.
            return False, self.requests[0] + self.window_size - now

    async def reset(self):
        """Discard all recorded requests."""
        async with self._lock:
            self.requests.clear()


class RateLimiter:
    """Rate limiter for MCP server.

    Enforces three independent layers: a global token bucket, a lazily
    created per-client limiter, and per-operation limiters keyed by exact
    method name or trailing-wildcard pattern (e.g. ``"batch/*"``).
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config

        # Global limiter is always a token bucket, regardless of
        # ``use_sliding_window`` (that flag governs client/operation limiters).
        self.global_limiter = TokenBucket(
            capacity=config.global_burst_size,
            refill_rate=config.global_requests_per_minute / 60.0,
        )

        # Per-client limiters, created on first use.
        # NOTE(review): this dict grows without bound until the cleanup
        # loop is given a real eviction policy — see _cleanup_loop.
        self.client_limiters: Dict[str, Any] = {}

        # Per-operation limiters, pre-built from configuration.
        self.operation_limiters: Dict[str, Any] = {
            pattern: self._build_limiter(rpm, burst)
            for pattern, (rpm, burst) in config.operation_limits.items()
        }

        # Background cleanup task handle (created by start()).
        self._cleanup_task = None

    def _build_limiter(self, rpm: int, burst: int) -> Any:
        """Create a limiter of the configured flavour (window or bucket)."""
        if self.config.use_sliding_window:
            return SlidingWindowCounter(window_size=60, max_requests=rpm)
        return TokenBucket(capacity=burst, refill_rate=rpm / 60.0)

    @staticmethod
    async def _consume(limiter: Any, tokens: int) -> Tuple[bool, Optional[float]]:
        """Dispatch to the limiter's check method; both kinds share a shape."""
        if isinstance(limiter, TokenBucket):
            return await limiter.consume(tokens)
        return await limiter.add_request()

    async def start(self):
        """Start the rate limiter."""
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop the rate limiter."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

    async def _cleanup_loop(self):
        """Periodic cleanup of old limiters."""
        while True:
            try:
                await asyncio.sleep(self.config.cleanup_interval)

                # Clean up inactive client limiters.
                # This is a simplified version - in production you'd track
                # last access time and remove old entries.

            except asyncio.CancelledError:
                break

    def _get_client_limiter(self, client_id: str) -> Any:
        """Get or create the limiter for *client_id*."""
        limiter = self.client_limiters.get(client_id)
        if limiter is None:
            limiter = self._build_limiter(
                self.config.client_requests_per_minute,
                self.config.client_burst_size,
            )
            self.client_limiters[client_id] = limiter
        return limiter

    def _match_operation(self, operation: str) -> Optional[str]:
        """Return the configured pattern matching *operation*, if any."""
        # Exact match wins over wildcards.
        if operation in self.operation_limiters:
            return operation

        # "prefix/*" patterns match any operation sharing the prefix.
        for pattern in self.operation_limiters:
            if pattern.endswith("/*") and operation.startswith(pattern[:-2]):
                return pattern

        return None

    async def check_rate_limit(
        self,
        client_id: Optional[str] = None,
        operation: Optional[str] = None,
        tokens: int = 1,
    ) -> None:
        """Check rate limits, raising error if exceeded.

        Args:
            client_id: Client identifier
            operation: Operation being performed
            tokens: Number of tokens to consume

        Raises:
            RateLimitExceeded: If any rate limit is exceeded

        NOTE(review): tokens consumed from the global bucket are not
        refunded when a later client/operation check rejects the request —
        confirm this accounting is intentional.
        """
        # Global limit first.
        allowed, wait_time = await self._consume(self.global_limiter, tokens)
        if not allowed:
            raise RateLimitExceeded(
                "Global rate limit exceeded",
                retry_after=int(wait_time + 1) if wait_time else 60,
            )

        # Per-client limit.
        if client_id:
            limiter = self._get_client_limiter(client_id)
            allowed, wait_time = await self._consume(limiter, tokens)
            if not allowed:
                raise RateLimitExceeded(
                    f"Client rate limit exceeded for {client_id}",
                    retry_after=int(wait_time + 1) if wait_time else 60,
                )

        # Per-operation limit.
        if operation:
            pattern = self._match_operation(operation)
            if pattern:
                limiter = self.operation_limiters[pattern]
                allowed, wait_time = await self._consume(limiter, tokens)
                if not allowed:
                    raise RateLimitExceeded(
                        f"Operation rate limit exceeded for {operation}",
                        retry_after=int(wait_time + 1) if wait_time else 60,
                    )

    async def reset_client_limit(self, client_id: str):
        """Reset rate limit for a specific client."""
        if client_id in self.client_limiters:
            await self.client_limiters[client_id].reset()

    async def reset_all_limits(self):
        """Reset all rate limits."""
        await self.global_limiter.reset()

        for limiter in self.client_limiters.values():
            await limiter.reset()

        for limiter in self.operation_limiters.values():
            await limiter.reset()

    def get_limit_status(
        self,
        client_id: Optional[str] = None,
        operation: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Get current rate limit status.

        Returns:
            Dictionary with limit information
        """
        status: Dict[str, Any] = {
            "global": {
                "requests_per_minute": self.config.global_requests_per_minute,
                "burst_size": self.config.global_burst_size,
            }
        }

        if client_id:
            status["client"] = {
                "requests_per_minute": self.config.client_requests_per_minute,
                "burst_size": self.config.client_burst_size,
            }

        if operation:
            pattern = self._match_operation(operation)
            if pattern:
                rpm, burst = self.config.operation_limits[pattern]
                status["operation"] = {
                    "pattern": pattern,
                    "requests_per_minute": rpm,
                    "burst_size": burst,
                }

        return status
oauth_config_file: str | None = None + jwt_config_file: str | None = None + audit_log_file: str | None = None + audit_retention_days: int = 90 class ContextFrameMCPServer: @@ -68,6 +89,8 @@ def __init__(self, dataset_path: str, config: MCPConfig | None = None): self.handler: MessageHandler | None = None self.tools: ToolRegistry | None = None self.resources: ResourceRegistry | None = None + self.monitoring: "MonitoringSystem" | None = None + self.security: "SecurityMiddleware" | None = None async def setup(self): """Set up server components.""" @@ -77,11 +100,62 @@ async def setup(self): except Exception as e: raise DatasetNotFound(self.dataset_path) from e + # Initialize monitoring if enabled + if self.config.monitoring_enabled: + from contextframe.mcp.monitoring.collector import MetricsConfig + from contextframe.mcp.monitoring.cost import PricingConfig + from contextframe.mcp.monitoring.integration import ( + MonitoredMessageHandler, + MonitoredToolRegistry, + MonitoringSystem, + ) + + # Create monitoring config + metrics_config = MetricsConfig( + enabled=True, + retention_days=self.config.monitoring_retention_days, + flush_interval_seconds=self.config.monitoring_flush_interval + ) + + # Load pricing config if provided + pricing_config = None + if self.config.pricing_config_path: + pricing_config = PricingConfig.from_file(self.config.pricing_config_path) + else: + pricing_config = PricingConfig() + + # Initialize monitoring system + self.monitoring = MonitoringSystem( + self.dataset, + metrics_config, + pricing_config + ) + await self.monitoring.start() + + # Initialize security if enabled + if self.config.security_enabled: + await self._setup_security() + # Initialize transport based on configuration if self.config.transport == "stdio": self.transport = StdioAdapter() - self.handler = MessageHandler(self) - self.tools = ToolRegistry(self.dataset, self.transport) + + # Use secured/monitored versions based on configuration + if self.security: + from 
contextframe.mcp.security.integration import SecuredMessageHandler + self.handler = SecuredMessageHandler(self, self.security) + elif self.monitoring: + from contextframe.mcp.monitoring.integration import MonitoredMessageHandler + self.handler = MonitoredMessageHandler(self, self.monitoring) + else: + self.handler = MessageHandler(self) + + if self.monitoring: + from contextframe.mcp.monitoring.integration import MonitoredToolRegistry + self.tools = MonitoredToolRegistry(self.dataset, self.transport, self.monitoring) + else: + self.tools = ToolRegistry(self.dataset, self.transport) + self.resources = ResourceRegistry(self.dataset) await self.transport.initialize() elif self.config.transport == "http": @@ -92,6 +166,100 @@ async def setup(self): logger.info( f"MCP server initialized for dataset: {self.dataset_path} with {self.config.transport} transport" + + (f" and monitoring enabled" if self.config.monitoring_enabled else "") + ) + + async def _setup_security(self): + """Set up security components.""" + import json + from contextframe.mcp.security import ( + APIKeyAuth, + OAuth2Provider, + OAuth2Config, + JWTHandler, + JWTConfig, + MultiAuthProvider, + AccessControl, + RateLimiter, + RateLimitConfig, + AuditLogger, + AuditConfig, + ) + from contextframe.mcp.security.integration import SecurityMiddleware + + # Build auth providers + auth_providers = [] + + if not self.config.auth_providers: + # Default to API key auth + self.config.auth_providers = ["api_key"] + + for provider_type in self.config.auth_providers: + if provider_type == "api_key" and self.config.api_keys_file: + # Load API keys from file + try: + with open(self.config.api_keys_file, "r") as f: + api_keys = json.load(f) + auth_providers.append(APIKeyAuth(api_keys)) + except Exception as e: + logger.warning(f"Failed to load API keys: {e}") + + elif provider_type == "oauth" and self.config.oauth_config_file: + # Load OAuth config + try: + with open(self.config.oauth_config_file, "r") as f: + oauth_data = 
json.load(f) + oauth_config = OAuth2Config(**oauth_data) + auth_providers.append(OAuth2Provider(oauth_config)) + except Exception as e: + logger.warning(f"Failed to load OAuth config: {e}") + + elif provider_type == "jwt" and self.config.jwt_config_file: + # Load JWT config + try: + with open(self.config.jwt_config_file, "r") as f: + jwt_data = json.load(f) + jwt_config = JWTConfig(**jwt_data) + auth_providers.append(JWTHandler(jwt_config)) + except Exception as e: + logger.warning(f"Failed to load JWT config: {e}") + + # Create multi-auth provider if multiple providers + auth_provider = None + if len(auth_providers) > 1: + auth_provider = MultiAuthProvider(auth_providers) + elif len(auth_providers) == 1: + auth_provider = auth_providers[0] + + # Create access control + access_control = AccessControl() + + # Create rate limiter + rate_limiter = RateLimiter(RateLimitConfig()) + + # Create audit logger + audit_config = AuditConfig( + storage_backend="file" if self.config.audit_log_file else "memory", + file_path=self.config.audit_log_file, + retention_days=self.config.audit_retention_days + ) + audit_logger = AuditLogger(audit_config) + + # Create security middleware + self.security = SecurityMiddleware( + auth_provider=auth_provider, + access_control=access_control, + rate_limiter=rate_limiter, + audit_logger=audit_logger, + anonymous_allowed=self.config.anonymous_allowed, + anonymous_permissions=set(self.config.anonymous_permissions or []) + ) + + await self.security.start() + + logger.info( + f"Security initialized with providers: {self.config.auth_providers}" + + f", anonymous_allowed: {self.config.anonymous_allowed}" ) async def _setup_http_transport(self): @@ -125,8 +293,23 @@ async def _setup_http_transport(self): # For compatibility, set transport to the HTTP adapter self.transport = self.http_server.adapter - self.handler = self.http_server.handler - self.tools = ToolRegistry(self.dataset, self.transport) + + # Use secured/monitored versions based on 
configuration + if self.security: + from contextframe.mcp.security.integration import SecuredMessageHandler + self.handler = SecuredMessageHandler(self, self.security) + elif self.monitoring: + from contextframe.mcp.monitoring.integration import MonitoredMessageHandler + self.handler = MonitoredMessageHandler(self, self.monitoring) + else: + self.handler = self.http_server.handler + + if self.monitoring: + from contextframe.mcp.monitoring.integration import MonitoredToolRegistry + self.tools = MonitoredToolRegistry(self.dataset, self.transport, self.monitoring) + else: + self.tools = ToolRegistry(self.dataset, self.transport) + self.resources = ResourceRegistry(self.dataset) async def run(self): @@ -225,6 +408,16 @@ async def cleanup(self): if self.transport: await self.transport.shutdown() + # Stop monitoring if enabled + if self.monitoring: + logger.info("Stopping monitoring system") + await self.monitoring.stop() + + # Stop security if enabled + if self.security: + logger.info("Stopping security system") + await self.security.stop() + # Dataset cleanup if needed if self.dataset: # FrameDataset doesn't require explicit cleanup diff --git a/contextframe/tests/test_mcp/test_monitoring.py b/contextframe/tests/test_mcp/test_monitoring.py new file mode 100644 index 0000000..435cd1d --- /dev/null +++ b/contextframe/tests/test_mcp/test_monitoring.py @@ -0,0 +1,509 @@ +"""Test monitoring system for MCP server.""" + +import asyncio +import json +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +from contextframe.mcp.monitoring.collector import MetricsCollector, MetricsConfig +from contextframe.mcp.monitoring.cost import CostCalculator, LLMPricing, PricingConfig +from contextframe.mcp.monitoring.integration import ( + MonitoredMessageHandler, + MonitoringSystem, +) +from contextframe.mcp.monitoring.performance import PerformanceMonitor +from contextframe.mcp.monitoring.usage import UsageTracker + + 
+@pytest.fixture +def metrics_config(): + """Create test metrics configuration.""" + return MetricsConfig( + enabled=True, + retention_days=7, + flush_interval_seconds=10, + max_memory_metrics=1000 + ) + + +@pytest.fixture +def pricing_config(): + """Create test pricing configuration.""" + return PricingConfig( + llm_pricing={ + "openai:gpt-4": LLMPricing("openai", "gpt-4", 0.03, 0.06), + "openai:gpt-3.5-turbo": LLMPricing("openai", "gpt-3.5-turbo", 0.0005, 0.0015), + }, + bandwidth_cost_per_gb=0.09 + ) + + +@pytest.fixture +def mock_dataset(): + """Create mock dataset.""" + return MagicMock() + + +@pytest.fixture +async def monitoring_system(mock_dataset, metrics_config, pricing_config): + """Create monitoring system for testing.""" + system = MonitoringSystem(mock_dataset, metrics_config, pricing_config) + await system.start() + yield system + await system.stop() + + +class TestMetricsCollector: + """Test metrics collection functionality.""" + + @pytest.mark.asyncio + async def test_record_usage_metric(self, mock_dataset, metrics_config): + """Test recording usage metrics.""" + collector = MetricsCollector(mock_dataset, metrics_config) + + await collector.record_usage( + metric_type="document_access", + resource_id="doc-123", + operation="read", + value=1.0, + agent_id="agent-1", + metadata={"source": "test"} + ) + + # Check metric was buffered + assert len(collector._usage_buffer) == 1 + metric = collector._usage_buffer[0] + assert metric["metric_type"] == "document_access" + assert metric["resource_id"] == "doc-123" + assert metric["operation"] == "read" + assert metric["agent_id"] == "agent-1" + + @pytest.mark.asyncio + async def test_record_performance_metric(self, mock_dataset, metrics_config): + """Test recording performance metrics.""" + collector = MetricsCollector(mock_dataset, metrics_config) + + await collector.record_performance( + operation_id="op-123", + operation_type="tool_call", + duration_ms=150.5, + status="success", + agent_id="agent-1", + 
result_size=1024 + ) + + # Check metric was buffered + assert len(collector._performance_buffer) == 1 + metric = collector._performance_buffer[0] + assert metric["operation_id"] == "op-123" + assert metric["operation_type"] == "tool_call" + assert metric["duration_ms"] == 150.5 + assert metric["status"] == "success" + + @pytest.mark.asyncio + async def test_record_cost_metric(self, mock_dataset, metrics_config): + """Test recording cost metrics.""" + collector = MetricsCollector(mock_dataset, metrics_config) + + await collector.record_cost( + operation_id="op-123", + cost_type="llm", + provider="openai", + amount_usd=0.015, + units=500, + agent_id="agent-1" + ) + + # Check metric was buffered + assert len(collector._cost_buffer) == 1 + metric = collector._cost_buffer[0] + assert metric["cost_type"] == "llm" + assert metric["provider"] == "openai" + assert metric["amount_usd"] == 0.015 + assert metric["units"] == 500 + + @pytest.mark.asyncio + async def test_metrics_disabled(self, mock_dataset): + """Test that metrics are not recorded when disabled.""" + config = MetricsConfig(enabled=False) + collector = MetricsCollector(mock_dataset, config) + + await collector.record_usage( + metric_type="test", + resource_id="test", + operation="test" + ) + + # No metrics should be recorded + assert len(collector._usage_buffer) == 0 + + +class TestUsageTracker: + """Test usage tracking functionality.""" + + @pytest.mark.asyncio + async def test_track_document_access(self, monitoring_system): + """Test tracking document access.""" + usage = monitoring_system.usage_tracker + + # Track multiple accesses + await usage.track_document_access("doc-1", "read", "agent-1") + await usage.track_document_access("doc-1", "search_hit", "agent-2") + await usage.track_document_access("doc-2", "update", "agent-1") + + # Check document cache + assert len(usage._document_cache) == 2 + + doc1_stats = usage._document_cache["doc-1"] + assert doc1_stats.access_count == 2 + assert 
doc1_stats.search_appearances == 1 + assert doc1_stats.access_by_operation["read"] == 1 + assert doc1_stats.access_by_operation["search_hit"] == 1 + + @pytest.mark.asyncio + async def test_track_query(self, monitoring_system): + """Test tracking query execution.""" + usage = monitoring_system.usage_tracker + + # Track queries + await usage.track_query( + query="test query", + query_type="vector", + result_count=10, + execution_time_ms=50.0, + agent_id="agent-1", + success=True + ) + + await usage.track_query( + query="test query", + query_type="vector", + result_count=5, + execution_time_ms=40.0, + agent_id="agent-2", + success=True + ) + + # Check query cache + query_key = "vector:test query" + assert query_key in usage._query_cache + + stats = usage._query_cache[query_key] + assert stats.count == 2 + assert stats.total_results == 15 + assert stats.avg_execution_time_ms == 45.0 # (50 + 40) / 2 + assert stats.success_rate == 1.0 + + @pytest.mark.asyncio + async def test_get_usage_stats(self, monitoring_system): + """Test getting aggregated usage statistics.""" + usage = monitoring_system.usage_tracker + + # Track some activity + await usage.track_document_access("doc-1", "read", "agent-1") + await usage.track_document_access("doc-2", "read", "agent-2") + await usage.track_query("test", "vector", 5, 10.0, "agent-1") + + # Get stats + end_time = datetime.now(timezone.utc) + start_time = end_time - timedelta(hours=1) + stats = await usage.get_usage_stats(start_time, end_time) + + assert stats.total_document_accesses >= 2 + assert stats.unique_documents_accessed == 2 + assert stats.total_queries >= 1 + assert stats.unique_agents >= 2 + + +class TestPerformanceMonitor: + """Test performance monitoring functionality.""" + + @pytest.mark.asyncio + async def test_operation_tracking(self, monitoring_system): + """Test tracking operation performance.""" + perf = monitoring_system.performance_monitor + + # Start operation + context = await perf.start_operation( + 
operation_id="op-1", + operation_type="tool_call", + agent_id="agent-1" + ) + + assert "op-1" in perf._active_operations + assert context.operation_type == "tool_call" + + # Simulate some work + await asyncio.sleep(0.01) + + # End operation + await perf.end_operation( + operation_id="op-1", + status="success", + result_size=100 + ) + + assert "op-1" not in perf._active_operations + + # Check metrics + metrics = perf.get_operation_metrics("tool_call") + assert "tool_call" in metrics + op_metrics = metrics["tool_call"] + assert op_metrics.count == 1 + assert op_metrics.error_count == 0 + assert op_metrics.avg_duration_ms > 0 + + @pytest.mark.asyncio + async def test_operation_context_manager(self, monitoring_system): + """Test operation tracking with context manager.""" + perf = monitoring_system.performance_monitor + + # Track successful operation + async with perf.track_operation("test_op", "agent-1") as ctx: + assert ctx.operation_type == "test_op" + await asyncio.sleep(0.01) + + # Track failed operation + with pytest.raises(ValueError): + async with perf.track_operation("test_op", "agent-1"): + raise ValueError("Test error") + + # Check metrics + metrics = perf.get_operation_metrics("test_op") + op_metrics = metrics["test_op"] + assert op_metrics.count == 2 + assert op_metrics.error_count == 1 + assert op_metrics.success_rate == 50.0 + + @pytest.mark.asyncio + async def test_response_percentiles(self, monitoring_system): + """Test response time percentile calculation.""" + perf = monitoring_system.performance_monitor + + # Track multiple operations with different durations + for i in range(10): + ctx = await perf.start_operation(f"op-{i}", "test_op") + await asyncio.sleep(0.001 * i) # Variable delays + await perf.end_operation(f"op-{i}", "success") + + # Get percentiles + percentiles = perf.get_response_percentiles( + "test_op", + [0.5, 0.9, 0.99] + ) + + assert 0.5 in percentiles + assert 0.9 in percentiles + assert percentiles[0.5] < percentiles[0.9] + + +class 
TestCostCalculator: + """Test cost calculation functionality.""" + + @pytest.mark.asyncio + async def test_llm_cost_tracking(self, monitoring_system): + """Test LLM usage cost tracking.""" + cost_calc = monitoring_system.cost_calculator + + # Track GPT-4 usage + cost = await cost_calc.track_llm_usage( + provider="openai", + model="gpt-4", + input_tokens=1000, + output_tokens=500, + operation_id="op-1", + agent_id="agent-1" + ) + + # GPT-4: $0.03/1k input, $0.06/1k output + expected_cost = (1000/1000 * 0.03) + (500/1000 * 0.06) + assert cost == expected_cost + assert cost == 0.06 # $0.03 + $0.03 + + # Track GPT-3.5 usage + cost = await cost_calc.track_llm_usage( + provider="openai", + model="gpt-3.5-turbo", + input_tokens=2000, + output_tokens=1000, + operation_id="op-2", + agent_id="agent-1" + ) + + # GPT-3.5: $0.0005/1k input, $0.0015/1k output + expected_cost = (2000/1000 * 0.0005) + (1000/1000 * 0.0015) + assert cost == expected_cost + assert cost == 0.0025 # $0.001 + $0.0015 + + @pytest.mark.asyncio + async def test_storage_cost_tracking(self, monitoring_system): + """Test storage operation cost tracking.""" + cost_calc = monitoring_system.cost_calculator + + # Track read operation (1GB) + cost = await cost_calc.track_storage_usage( + operation="read", + size_bytes=1024 * 1024 * 1024, # 1GB + agent_id="agent-1" + ) + + assert cost == 0.01 # $0.01 per GB read + + # Track write operation (2GB) + cost = await cost_calc.track_storage_usage( + operation="write", + size_bytes=2 * 1024 * 1024 * 1024, # 2GB + agent_id="agent-1" + ) + + assert cost == 0.04 # $0.02 per GB write * 2GB + + @pytest.mark.asyncio + async def test_bandwidth_cost_tracking(self, monitoring_system): + """Test bandwidth cost tracking.""" + cost_calc = monitoring_system.cost_calculator + + # Track egress bandwidth + cost = await cost_calc.track_bandwidth_usage( + size_bytes=10 * 1024 * 1024 * 1024, # 10GB + direction="egress", + agent_id="agent-1" + ) + + assert cost == 0.9 # $0.09 per GB * 10GB + 
+ # Track ingress (should be free) + cost = await cost_calc.track_bandwidth_usage( + size_bytes=10 * 1024 * 1024 * 1024, + direction="ingress", + agent_id="agent-1" + ) + + assert cost == 0.0 + + @pytest.mark.asyncio + async def test_cost_report(self, monitoring_system): + """Test cost report generation.""" + cost_calc = monitoring_system.cost_calculator + + # Track some costs + await cost_calc.track_llm_usage( + "openai", "gpt-4", 1000, 500, "op-1", "agent-1" + ) + await cost_calc.track_storage_usage( + "read", 1024 * 1024 * 1024, "agent-1" + ) + + # Get report + end_time = datetime.now(timezone.utc) + start_time = end_time - timedelta(days=1) + report = await cost_calc.get_cost_report(start_time, end_time, "agent") + + assert report.summary.total_cost > 0 + assert report.summary.llm_cost > 0 + assert report.summary.storage_cost > 0 + assert report.projected_monthly_cost > 0 + + +class TestMonitoredMessageHandler: + """Test monitored message handler.""" + + @pytest.mark.asyncio + async def test_message_handling_with_monitoring(self, monitoring_system): + """Test that messages are tracked by monitoring.""" + # Create mock server and handler + mock_server = MagicMock() + mock_server.tools = MagicMock() + mock_server.resources = MagicMock() + + # Create base handler mock + with patch('contextframe.mcp.handlers.MessageHandler') as MockHandler: + mock_base_handler = AsyncMock() + mock_base_handler.handle.return_value = { + "jsonrpc": "2.0", + "result": {"success": True}, + "id": 1 + } + MockHandler.return_value = mock_base_handler + + # Create monitored handler + handler = MonitoredMessageHandler(mock_server, monitoring_system) + + # Handle a message + message = { + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "test_tool"}, + "id": 1, + "agent_id": "test-agent" + } + + result = await handler.handle(message) + + # Check that base handler was called + mock_base_handler.handle.assert_called_once() + + # Check monitoring was updated + perf_metrics = 
class TestMonitoringTools:
    """Test monitoring MCP tools."""

    @pytest.mark.asyncio
    async def test_get_usage_metrics_tool(self, monitoring_system):
        """Test get_usage_metrics tool."""
        from contextframe.mcp.monitoring.tools import get_usage_metrics

        # Record one document access so the report is non-empty.
        await monitoring_system.usage_tracker.track_document_access(
            "doc-1", "read", "agent-1"
        )

        report = await get_usage_metrics({})

        assert "summary" in report
        assert "queries_by_type" in report
        assert report["summary"]["total_document_accesses"] >= 1

    @pytest.mark.asyncio
    async def test_get_performance_metrics_tool(self, monitoring_system):
        """Test get_performance_metrics tool."""
        from contextframe.mcp.monitoring.tools import get_performance_metrics

        # Record one completed operation.
        monitor = monitoring_system.performance_monitor
        await monitor.start_operation("op-1", "test_op", "agent-1")
        await monitor.end_operation("op-1", "success")

        report = await get_performance_metrics({})

        assert "operations" in report
        assert "current_snapshot" in report
        assert "test_op" in report["operations"]

    @pytest.mark.asyncio
    async def test_get_cost_report_tool(self, monitoring_system):
        """Test get_cost_report tool."""
        from contextframe.mcp.monitoring.tools import get_cost_report

        # Record an LLM charge so the report has a non-zero total.
        await monitoring_system.cost_calculator.track_llm_usage(
            "openai", "gpt-3.5-turbo", 1000, 500, "op-1", "agent-1"
        )

        report = await get_cost_report({})

        assert "total_cost" in report
        assert "breakdown" in report
        assert report["total_cost"] > 0
class TestAPIKeyAuth:
    """Test API key authentication."""

    @pytest.fixture
    def api_keys(self):
        """Sample API keys."""
        return {
            "test-key-1": {
                "principal_id": "user-1",
                "principal_name": "Test User 1",
                "permissions": ["documents.read", "collections.read"],
                "roles": ["viewer"],
            },
            "test-key-2": {
                "principal_id": "service-1",
                "principal_type": "service",
                "principal_name": "Test Service",
                "permissions": ["documents.*", "collections.*"],
                "roles": ["admin"],
                "expires_at": datetime.now(timezone.utc) + timedelta(days=30),
            },
        }

    @pytest.mark.asyncio
    async def test_api_key_authentication_success(self, api_keys):
        """Test successful API key authentication."""
        auth = APIKeyAuth(api_keys)

        ctx = await auth.authenticate({"api_key": "test-key-1"})

        # The returned security context carries the key's identity and grants.
        assert ctx.authenticated
        assert ctx.auth_method == "api_key"
        assert ctx.principal_id == "user-1"
        assert ctx.principal_name == "Test User 1"
        assert "documents.read" in ctx.permissions
        assert "viewer" in ctx.roles

    @pytest.mark.asyncio
    async def test_api_key_authentication_failure(self, api_keys):
        """Test failed API key authentication."""
        auth = APIKeyAuth(api_keys)

        with pytest.raises(AuthenticationError) as exc_info:
            await auth.authenticate({"api_key": "invalid-key"})

        assert "Invalid API key" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_api_key_expiration(self, api_keys):
        """Test expired API key rejection."""
        # Inject a key whose expiry is already in the past.
        api_keys["expired-key"] = {
            "principal_id": "user-2",
            "expires_at": datetime.now(timezone.utc) - timedelta(days=1),
        }
        auth = APIKeyAuth(api_keys)

        with pytest.raises(AuthenticationError) as exc_info:
            await auth.authenticate({"api_key": "expired-key"})

        assert "API key expired" in str(exc_info.value)

    def test_generate_api_key(self):
        """Test API key generation."""
        first = APIKeyAuth.generate_api_key()
        second = APIKeyAuth.generate_api_key()

        # Keys are long enough to be unguessable and unique per call.
        assert len(first) > 20
        assert first != second
context.roles + + @pytest.mark.asyncio + async def test_jwt_expired_token(self, jwt_config): + """Test expired JWT rejection.""" + jwt_config.token_lifetime = -1 # Immediate expiration + handler = JWTHandler(jwt_config) + + token = handler.create_token(principal_id="user-123") + + with pytest.raises(AuthenticationError) as exc_info: + await handler.authenticate({"token": token}) + + assert "expired" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_jwt_invalid_signature(self, jwt_config): + """Test JWT with invalid signature.""" + handler = JWTHandler(jwt_config) + + # Create token with different secret + jwt_config.secret_key = "different-secret" + bad_handler = JWTHandler(jwt_config) + token = bad_handler.create_token(principal_id="user-123") + + with pytest.raises(AuthenticationError) as exc_info: + await handler.authenticate({"token": token}) + + assert "Invalid JWT token" in str(exc_info.value) + + +class TestOAuth2Provider: + """Test OAuth 2.1 authentication.""" + + @pytest.fixture + def oauth_config(self): + """OAuth configuration.""" + return OAuth2Config( + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + userinfo_endpoint="https://auth.example.com/userinfo", + client_id="test-client-id", + client_secret="test-client-secret", + redirect_uri="http://localhost:8080/callback", + scopes=["openid", "profile", "email"], + ) + + def test_generate_authorization_url(self, oauth_config): + """Test authorization URL generation.""" + provider = OAuth2Provider(oauth_config) + + url = provider.generate_authorization_url( + state="test-state", + code_challenge="test-challenge", + ) + + assert oauth_config.authorization_endpoint in url + assert "client_id=test-client-id" in url + assert "state=test-state" in url + assert "code_challenge=test-challenge" in url + + def test_generate_pkce_pair(self): + """Test PKCE verifier and challenge generation.""" + verifier, challenge = 
OAuth2Provider.generate_pkce_pair() + + assert len(verifier) > 40 + assert len(challenge) > 40 + assert verifier != challenge + + +class TestAuthorization: + """Test authorization and access control.""" + + @pytest.fixture + def security_context(self): + """Sample security context.""" + return SecurityContext( + authenticated=True, + auth_method="api_key", + principal_id="user-1", + principal_type="user", + permissions={"documents.read"}, + roles={"viewer"}, + ) + + def test_role_permissions(self): + """Test role permission checks.""" + viewer_role = STANDARD_ROLES["viewer"] + admin_role = STANDARD_ROLES["admin"] + + assert viewer_role.has_permission(Permission.DOCUMENTS_READ) + assert not viewer_role.has_permission(Permission.DOCUMENTS_WRITE) + + assert admin_role.has_permission(Permission.DOCUMENTS_READ) + assert admin_role.has_permission(Permission.DOCUMENTS_WRITE) + assert admin_role.has_permission("documents.custom") # Wildcard + + def test_access_control_direct_permission(self, security_context): + """Test authorization with direct permissions.""" + access_control = AccessControl() + + # Should allow - has direct permission + assert access_control.authorize( + security_context, + Permission.DOCUMENTS_READ + ) + + # Should deny - no permission + assert not access_control.authorize( + security_context, + Permission.DOCUMENTS_WRITE + ) + + def test_access_control_role_permission(self, security_context): + """Test authorization with role-based permissions.""" + access_control = AccessControl() + + # Viewer role allows collections.read + assert access_control.authorize( + security_context, + Permission.COLLECTIONS_READ + ) + + # Viewer role doesn't allow collections.write + assert not access_control.authorize( + security_context, + Permission.COLLECTIONS_WRITE + ) + + def test_access_control_resource_policy(self, security_context): + """Test resource-level access control.""" + access_control = AccessControl() + + # Add policy allowing write to specific document + 
policy = ResourcePolicy( + resource_type="document", + resource_id="doc-123", + permissions={Permission.DOCUMENTS_WRITE}, + conditions={"principal_id": "user-1"} + ) + access_control.add_policy(policy) + + # Should allow - policy matches + assert access_control.authorize( + security_context, + Permission.DOCUMENTS_WRITE, + resource_type="document", + resource_id="doc-123" + ) + + # Should deny - different document + assert not access_control.authorize( + security_context, + Permission.DOCUMENTS_WRITE, + resource_type="document", + resource_id="doc-456" + ) + + def test_require_permission_raises(self, security_context): + """Test require_permission raises on denial.""" + access_control = AccessControl() + + with pytest.raises(AuthorizationError) as exc_info: + access_control.require_permission( + security_context, + Permission.DOCUMENTS_DELETE + ) + + assert "Permission 'documents.delete' required" in str(exc_info.value) + + +class TestRateLimiting: + """Test rate limiting functionality.""" + + @pytest.fixture + def rate_limiter(self): + """Create rate limiter with test config.""" + config = RateLimitConfig( + global_requests_per_minute=60, + global_burst_size=10, + client_requests_per_minute=30, + client_burst_size=5, + use_sliding_window=False, # Use token bucket for testing + ) + return RateLimiter(config) + + @pytest.mark.asyncio + async def test_rate_limit_allows_normal_traffic(self, rate_limiter): + """Test rate limiter allows normal traffic.""" + # Should allow first few requests + for _ in range(5): + await rate_limiter.check_rate_limit(client_id="user-1") + + # No exception means allowed + + @pytest.mark.asyncio + async def test_client_rate_limit_exceeded(self, rate_limiter): + """Test client rate limit enforcement.""" + # Exhaust client burst + for _ in range(5): + await rate_limiter.check_rate_limit(client_id="user-1") + + # Next request should fail + with pytest.raises(RateLimitExceeded) as exc_info: + await 
rate_limiter.check_rate_limit(client_id="user-1") + + assert "Client rate limit exceeded" in str(exc_info.value) + assert exc_info.value.retry_after > 0 + + @pytest.mark.asyncio + async def test_operation_rate_limit(self, rate_limiter): + """Test operation-specific rate limits.""" + # Tools have lower limits + for _ in range(5): + await rate_limiter.check_rate_limit( + client_id="user-1", + operation="tools/call" + ) + + with pytest.raises(RateLimitExceeded) as exc_info: + await rate_limiter.check_rate_limit( + client_id="user-1", + operation="tools/call" + ) + + assert "Operation rate limit exceeded" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_rate_limit_reset(self, rate_limiter): + """Test rate limit reset.""" + # Exhaust limit + for _ in range(5): + await rate_limiter.check_rate_limit(client_id="user-1") + + # Reset + await rate_limiter.reset_client_limit("user-1") + + # Should allow again + await rate_limiter.check_rate_limit(client_id="user-1") + + +class TestAuditLogging: + """Test audit logging functionality.""" + + @pytest.fixture + def audit_logger(self): + """Create audit logger with memory backend.""" + config = AuditConfig( + storage_backend="memory", + max_events_memory=100, + buffer_size=10, + ) + return AuditLogger(config) + + @pytest.mark.asyncio + async def test_log_authentication_event(self, audit_logger): + """Test logging authentication events.""" + await audit_logger.start() + + await audit_logger.log_event( + event_type=AuditEventType.AUTH_SUCCESS, + success=True, + principal_id="user-1", + principal_type="user", + auth_method="api_key", + operation="initialize", + client_ip="127.0.0.1", + ) + + # Force flush + await audit_logger._flush_buffer() + + # Check event was logged + events = await audit_logger.search_events( + event_types=[AuditEventType.AUTH_SUCCESS] + ) + + assert len(events) == 1 + assert events[0].principal_id == "user-1" + assert events[0].auth_method == "api_key" + + await audit_logger.stop() + + 
@pytest.mark.asyncio + async def test_log_authorization_event(self, audit_logger): + """Test logging authorization events.""" + await audit_logger.start() + + await audit_logger.log_event( + event_type=AuditEventType.AUTHZ_DENIED, + success=False, + principal_id="user-1", + operation="delete_document", + resource_type="document", + resource_id="doc-123", + error_message="Permission denied", + ) + + await audit_logger._flush_buffer() + + events = await audit_logger.search_events( + event_types=[AuditEventType.AUTHZ_DENIED] + ) + + assert len(events) == 1 + assert events[0].resource_id == "doc-123" + assert events[0].severity == "warning" + + await audit_logger.stop() + + @pytest.mark.asyncio + async def test_search_events_with_filters(self, audit_logger): + """Test searching events with filters.""" + await audit_logger.start() + + # Log various events + await audit_logger.log_event( + AuditEventType.AUTH_SUCCESS, + principal_id="user-1", + ) + await audit_logger.log_event( + AuditEventType.AUTH_SUCCESS, + principal_id="user-2", + ) + await audit_logger.log_event( + AuditEventType.TOOL_EXECUTED, + principal_id="user-1", + resource_type="tool", + resource_id="test_tool", + ) + + await audit_logger._flush_buffer() + + # Search by principal + events = await audit_logger.search_events(principal_id="user-1") + assert len(events) == 2 + + # Search by event type + events = await audit_logger.search_events( + event_types=[AuditEventType.TOOL_EXECUTED] + ) + assert len(events) == 1 + assert events[0].resource_id == "test_tool" + + await audit_logger.stop() + + def test_sensitive_data_redaction(self, audit_logger): + """Test sensitive data is redacted.""" + details = { + "username": "test", + "password": "secret123", + "api_key": "key-12345", + "other_data": "visible", + } + + redacted = audit_logger._redact_sensitive_data(details) + + assert redacted["username"] == "test" + assert redacted["password"] == "[REDACTED]" + assert redacted["api_key"] == "[REDACTED]" + assert 
redacted["other_data"] == "visible" + + +class TestSecurityIntegration: + """Test security middleware integration.""" + + @pytest.mark.asyncio + async def test_security_middleware_authentication(self): + """Test security middleware authentication flow.""" + from contextframe.mcp.security.integration import SecurityMiddleware + + # Create components + api_keys = { + "test-key": { + "principal_id": "user-1", + "permissions": ["documents.read"], + "roles": ["viewer"], + } + } + auth_provider = APIKeyAuth(api_keys) + + middleware = SecurityMiddleware( + auth_provider=auth_provider, + anonymous_allowed=False, + ) + + # Test successful auth + message = { + "method": "test", + "params": {"api_key": "test-key"}, + "id": 1, + } + + context = await middleware.authenticate(message) + assert context.authenticated + assert context.principal_id == "user-1" + + # Test failed auth + message["params"]["api_key"] = "wrong-key" + + with pytest.raises(AuthenticationError): + await middleware.authenticate(message) + + @pytest.mark.asyncio + async def test_security_middleware_full_flow(self): + """Test complete security flow.""" + from contextframe.mcp.security.integration import SecurityMiddleware + + # Create all components + auth_provider = APIKeyAuth({ + "test-key": { + "principal_id": "user-1", + "permissions": ["documents.read"], + "roles": ["viewer"], + } + }) + + access_control = AccessControl() + rate_limiter = RateLimiter(RateLimitConfig()) + audit_logger = AuditLogger(AuditConfig(storage_backend="memory")) + + middleware = SecurityMiddleware( + auth_provider=auth_provider, + access_control=access_control, + rate_limiter=rate_limiter, + audit_logger=audit_logger, + ) + + await middleware.start() + + # Test full security check + message = { + "method": "get_document", + "params": { + "api_key": "test-key", + "document_id": "doc-123", + }, + "id": 1, + } + + # Authenticate + context = await middleware.authenticate(message) + + # Check rate limit + await 
middleware.check_rate_limit(context, "get_document") + + # Authorize + await middleware.authorize( + context, + "get_document", + message["params"] + ) + + # Verify audit log + await audit_logger._flush_buffer() + events = await audit_logger.search_events() + assert len(events) > 0 + assert events[0].event_type == AuditEventType.AUTH_SUCCESS + + await middleware.stop() \ No newline at end of file