diff --git a/CHANGELOG.md b/CHANGELOG.md index 006e353..174ae6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,23 @@ All notable changes to Distill are documented here. +## [Unreleased] + +### Added + +- **Session-based context window management** (`pkg/session`) — Token-budgeted context windows for long-running agent sessions. Entries are deduplicated on push, compressed through hierarchical levels (full text → summary → sentence → keywords), and evicted when the budget is exceeded. Lowest-importance entries are compressed first. ([#38](https://github.com/Siddhant-K-code/distill/pull/38), closes [#31](https://github.com/Siddhant-K-code/distill/issues/31)) +- **Session CLI** — `distill session create/push/context/delete` commands. ([#38](https://github.com/Siddhant-K-code/distill/pull/38)) +- **Session HTTP API** — `/v1/session/create`, `/push`, `/context`, `/delete`, `/get` endpoints. Opt-in via `--session` flag. ([#38](https://github.com/Siddhant-K-code/distill/pull/38)) +- **Session MCP tools** — `create_session`, `push_session`, `session_context`, `delete_session` for Claude Desktop, Cursor, and Amp. Opt-in via `--session` flag. ([#38](https://github.com/Siddhant-K-code/distill/pull/38)) + +### Stats + +- 9 files changed, 1,928 insertions, 6 deletions +- 1 new package: `pkg/session` +- 13 new tests + +--- + ## [v0.3.0] - 2026-02-23 Feature release adding persistent context memory, SSE streaming, OpenTelemetry tracing, and project documentation. diff --git a/FAQ.md b/FAQ.md index 6d719f1..e753cd3 100644 --- a/FAQ.md +++ b/FAQ.md @@ -20,6 +20,18 @@ LLMs are non-deterministic. The same input can produce different compressed outp --- +### What is Context Memory? + +Persistent memory that accumulates knowledge across agent sessions. Store context once, recall it later by semantic similarity + recency. Memories are deduplicated on write and compressed over time through hierarchical decay (full text → summary → keywords → evicted). 
Enable with `--memory` on the `api` or `mcp` commands. + +### What are Sessions? + +Token-budgeted context windows for long-running agent tasks. Push context incrementally as the agent works - Distill deduplicates entries, compresses aging ones, and evicts when the budget is exceeded. The `preserve_recent` setting keeps the N most recent entries at full fidelity. Enable with `--session` on the `api` or `mcp` commands. + +### How is Context Memory different from Sessions? + +Memory is cross-session: knowledge persists after a session ends and can be recalled in future sessions. Sessions are within-task: a bounded context window that tracks what the agent has seen during a single task, enforcing a token budget. Use memory for long-term knowledge, sessions for working context. + ## Algorithms ### Why agglomerative clustering instead of K-Means? @@ -108,6 +120,10 @@ Yes. The HTTP API is framework-agnostic. MCP works with any MCP-compatible clien LangChain's `search_type="mmr"` applies MMR at the vector DB level - a single re-ranking step. Distill runs a multi-stage pipeline: cache lookup, agglomerative clustering (groups similar chunks), representative selection (picks the best from each group), compression (reduces token count), then MMR (diversity re-ranking). The clustering step is the key difference - it understands group structure, not just pairwise similarity. +### What MCP tools does Distill expose? + +The base MCP server exposes `deduplicate_context` and `analyze_redundancy`. With `--memory`, it adds `store_memory`, `recall_memory`, `forget_memory`, `memory_stats`. With `--session`, it adds `create_session`, `push_session`, `session_context`, `delete_session`. Enable both with `distill mcp --memory --session`. + ### Can I use Distill with local models (Ollama, vLLM)? The dedup pipeline itself doesn't call any LLM - it's pure math (cosine distance, clustering). 
The only external dependency is for embedding generation when you send text without pre-computed embeddings. Multi-provider embedding support (Ollama, Azure, Cohere, HuggingFace) is planned in [#33](https://github.com/Siddhant-K-code/distill/issues/33). @@ -180,8 +196,10 @@ Yes, AGPL-3.0. The full pipeline, CLI, API server, MCP server, and all algorithm ### What's on the roadmap? -Three pillars: +**Shipped:** +- **Context Memory** - Persistent deduplicated memory across sessions with hierarchical decay ([#29](https://github.com/Siddhant-K-code/distill/issues/29)) +- **Session Management** - Token-budgeted context windows with compression and eviction ([#31](https://github.com/Siddhant-K-code/distill/issues/31)) -1. **Context Memory** - Persistent deduplicated memory across agent sessions with hierarchical decay ([#29](https://github.com/Siddhant-K-code/distill/issues/29), [#31](https://github.com/Siddhant-K-code/distill/issues/31)) -2. **Code Intelligence** - Dependency graphs, co-change patterns, blast radius analysis ([#30](https://github.com/Siddhant-K-code/distill/issues/30), [#32](https://github.com/Siddhant-K-code/distill/issues/32)) -3. **Platform** - Python SDK, multi-provider embeddings, batch API ([#5](https://github.com/Siddhant-K-code/distill/issues/5), [#33](https://github.com/Siddhant-K-code/distill/issues/33), [#11](https://github.com/Siddhant-K-code/distill/issues/11)) +**Upcoming:** +1. **Code Intelligence** - Dependency graphs, co-change patterns, blast radius analysis ([#30](https://github.com/Siddhant-K-code/distill/issues/30), [#32](https://github.com/Siddhant-K-code/distill/issues/32)) +2. 
**Platform** - Python SDK, multi-provider embeddings, batch API ([#5](https://github.com/Siddhant-K-code/distill/issues/5), [#33](https://github.com/Siddhant-K-code/distill/issues/33), [#11](https://github.com/Siddhant-K-code/distill/issues/11)) diff --git a/README.md b/README.md index 4bb413a..59586fa 100644 --- a/README.md +++ b/README.md @@ -183,7 +183,11 @@ curl -X POST http://localhost:8080/v1/retrieve \ Works with Claude, Cursor, Amp, and other MCP-compatible assistants: ```bash +# Dedup only distill mcp + +# With memory and sessions +distill mcp --memory --session ``` Add to Claude Desktop (`~/Library/Application Support/Claude/claude_desktop_config.json`): @@ -193,7 +197,10 @@ Add to Claude Desktop (`~/Library/Application Support/Claude/claude_desktop_conf "mcpServers": { "distill": { "command": "/path/to/distill", - "args": ["mcp"] + "args": ["mcp", "--memory", "--session"], + "env": { + "OPENAI_API_KEY": "your-key" + } } } } @@ -270,6 +277,77 @@ memory: dedup_threshold: 0.15 ``` +## Session Management + +Token-budgeted context windows for long-running agent sessions. Push context incrementally - Distill deduplicates, compresses aging entries, and evicts when the budget is exceeded. + +Enable with the `--session` flag on `api` or `mcp` commands. 
+ +### CLI + +```bash +# Create a session with 128K token budget +distill session create --session-id task-42 --max-tokens 128000 + +# Push context as the agent works +distill session push --session-id task-42 --role user --content "Fix the JWT validation bug" +distill session push --session-id task-42 --role tool --content "$(cat auth/jwt.go)" --source file_read --importance 0.8 + +# Read the current context window +distill session context --session-id task-42 + +# Clean up when done +distill session delete --session-id task-42 +``` + +### API + +```bash +# Start API with sessions enabled +distill api --port 8080 --session + +# Create session +curl -X POST http://localhost:8080/v1/session/create \ + -H "Content-Type: application/json" \ + -d '{"session_id": "task-42", "max_tokens": 128000}' + +# Push entries +curl -X POST http://localhost:8080/v1/session/push \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "task-42", + "entries": [ + {"role": "tool", "content": "file contents...", "source": "file_read", "importance": 0.8} + ] + }' + +# Read context window +curl -X POST http://localhost:8080/v1/session/context \ + -H "Content-Type: application/json" \ + -d '{"session_id": "task-42"}' +``` + +### MCP + +Session tools are available when `--session` is enabled: + +```bash +distill mcp --session +``` + +Tools exposed: `create_session`, `push_session`, `session_context`, `delete_session`. + +### How Budget Enforcement Works + +When a push exceeds the token budget: + +1. **Compress** oldest entries (outside the `preserve_recent` window) through levels: + - Full text → Summary (~20%) → Single sentence (~5%) → Keywords (~1%) +2. **Evict** entries that are already at keyword level +3. Lowest-importance entries are compressed/evicted first + +The `preserve_recent` setting (default: 10) keeps the most recent entries at full fidelity. 
+ ## CLI Commands ```bash @@ -277,6 +355,7 @@ distill api # Start standalone API server distill serve # Start server with vector DB connection distill mcp # Start MCP server for AI assistants distill memory # Store, recall, and manage persistent context memories +distill session # Manage token-budgeted context windows for agent sessions distill analyze # Analyze a file for duplicates distill sync # Upload vectors to Pinecone with dedup distill query # Test a query from command line @@ -294,6 +373,11 @@ distill config # Manage configuration files | POST | `/v1/memory/recall` | Recall memories by relevance + recency (requires `--memory`) | | POST | `/v1/memory/forget` | Remove memories by ID, tag, or age (requires `--memory`) | | GET | `/v1/memory/stats` | Memory store statistics (requires `--memory`) | +| POST | `/v1/session/create` | Create a session with token budget (requires `--session`) | +| POST | `/v1/session/push` | Push entries with dedup + budget enforcement (requires `--session`) | +| POST | `/v1/session/context` | Read current context window (requires `--session`) | +| POST | `/v1/session/delete` | Delete a session (requires `--session`) | +| GET | `/v1/session/get` | Get session metadata (requires `--session`) | | GET | `/health` | Health check | | GET | `/metrics` | Prometheus metrics | @@ -345,6 +429,15 @@ retriever: auth: api_keys: - ${DISTILL_API_KEY} + +memory: + db_path: distill-memory.db + dedup_threshold: 0.15 + +session: + db_path: distill-sessions.db + dedup_threshold: 0.15 + max_tokens: 128000 ``` Environment variables can be referenced using `${VAR}` or `${VAR:-default}` syntax. @@ -537,6 +630,14 @@ Reduces token count while preserving meaning. Three strategies: Strategies can be chained via `compress.Pipeline`. Configure with target reduction ratio (e.g., 0.3 = keep 30% of original). +### Memory (`pkg/memory`) + +Persistent context memory across agent sessions. SQLite-backed with write-time deduplication via cosine similarity. 
Memories decay over time: full text → summary → keywords → evicted. Recall ranked by `(1-w)*similarity + w*recency`. Enable with `--memory` flag. + +### Session (`pkg/session`) + +Token-budgeted context windows for long-running tasks. Entries are deduplicated on push, compressed through hierarchical levels when the budget is exceeded, and evicted by importance. The `preserve_recent` setting keeps the N most recent entries at full fidelity. Enable with `--session` flag. + ### Cache (`pkg/cache`) KV cache for repeated context patterns (system prompts, tool definitions, boilerplate). Sub-millisecond retrieval for cache hits. @@ -566,7 +667,7 @@ KV cache for repeated context patterns (system prompts, tool definitions, boiler │ Context Intelligence │ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────────────────┐ │ │ │ Memory Store │ │ Impact Graph │ │ Session Context Windows │ │ -│ │ (shipped) │ │ (#30) │ │ (#31) │ │ +│ │ (shipped) │ │ (#30) │ │ (shipped) │ │ │ └──────────────┘ └──────────────┘ └──────────────────────────┘ │ │ │ │ ┌──────────────────────────────────────────────────────────────┐ │ @@ -604,7 +705,7 @@ Distill is evolving from a dedup utility into a context intelligence layer. Here | Feature | Issue | Status | Description | |---------|-------|--------|-------------| | **Context Memory Store** | [#29](https://github.com/Siddhant-K-code/distill/issues/29) | Shipped | Persistent, deduplicated memory across sessions. Write-time dedup, hierarchical decay, token-budgeted recall. See [Context Memory](#context-memory). | -| **Session Management** | [#31](https://github.com/Siddhant-K-code/distill/issues/31) | Planned | Stateful context windows for long-running agents. Push context incrementally, Distill keeps it deduplicated and within budget. | +| **Session Management** | [#31](https://github.com/Siddhant-K-code/distill/issues/31) | Shipped | Stateful context windows with token budgets, hierarchical compression, and importance-based eviction. 
See [Session Management](#session-management). | ### Code Intelligence diff --git a/cmd/api.go b/cmd/api.go index 3279951..6e221e5 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -46,6 +46,8 @@ func init() { apiCmd.Flags().String("embedding-model", "text-embedding-3-small", "OpenAI embedding model") apiCmd.Flags().String("api-keys", "", "Comma-separated list of valid API keys (or use DISTILL_API_KEYS)") apiCmd.Flags().Bool("memory", false, "Enable persistent memory store") + apiCmd.Flags().Bool("session", false, "Enable session management") + apiCmd.Flags().String("session-db", "distill-sessions.db", "SQLite database path for session store") // Bind to viper for config file support _ = viper.BindPFlag("server.port", apiCmd.Flags().Lookup("port")) @@ -199,6 +201,24 @@ func runAPI(cmd *cobra.Command, args []string) error { memAPI := &MemoryAPI{store: memStore, embedder: embedder} memAPI.RegisterMemoryRoutes(mux, m.Middleware) } + + // Setup session store (opt-in) + enableSession, _ := cmd.Flags().GetBool("session") + if enableSession { + sessDBPath, _ := cmd.Flags().GetString("session-db") + if sessDBPath == "" { + sessDBPath = "distill-sessions.db" + } + sessStore, err := newSessionStore(sessDBPath) + if err != nil { + return fmt.Errorf("failed to create session store: %w", err) + } + defer func() { _ = sessStore.Close() }() + + sessAPI := &SessionAPI{store: sessStore} + sessAPI.RegisterSessionRoutes(mux, m.Middleware) + } + mux.HandleFunc("/health", server.handleHealth) mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { m.Handler().ServeHTTP(w, r) @@ -241,6 +261,7 @@ func runAPI(cmd *cobra.Command, args []string) error { fmt.Printf(" Embeddings: %v\n", embedder != nil) fmt.Printf(" Auth: %v (%d keys)\n", server.hasAuth, len(validKeys)) fmt.Printf(" Memory: %v\n", enableMemory) + fmt.Printf(" Sessions: %v\n", enableSession) fmt.Println() fmt.Println("Endpoints:") fmt.Printf(" POST http://%s/v1/dedupe\n", addr) diff --git a/cmd/api_session.go 
b/cmd/api_session.go new file mode 100644 index 0000000..d291196 --- /dev/null +++ b/cmd/api_session.go @@ -0,0 +1,188 @@ +package cmd + +import ( + "encoding/json" + "net/http" + + "github.com/Siddhant-K-code/distill/pkg/session" +) + +// SessionAPI handles session-related HTTP endpoints. +type SessionAPI struct { + store *session.SQLiteStore +} + +// RegisterSessionRoutes adds session endpoints to the given mux. +func (s *SessionAPI) RegisterSessionRoutes(mux *http.ServeMux, mw func(string, http.HandlerFunc) http.HandlerFunc) { + mux.HandleFunc("/v1/session/create", mw("/v1/session/create", s.handleCreate)) + mux.HandleFunc("/v1/session/push", mw("/v1/session/push", s.handlePush)) + mux.HandleFunc("/v1/session/context", mw("/v1/session/context", s.handleContext)) + mux.HandleFunc("/v1/session/delete", mw("/v1/session/delete", s.handleDelete)) + mux.HandleFunc("/v1/session/get", mw("/v1/session/get", s.handleGet)) +} + +func (s *SessionAPI) handleCreate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req session.CreateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON: "+err.Error(), http.StatusBadRequest) + return + } + + sess, err := s.store.Create(r.Context(), req) + if err != nil { + if err == session.ErrSessionExists { + http.Error(w, err.Error(), http.StatusConflict) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(sess) +} + +func (s *SessionAPI) handlePush(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req session.PushRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON: "+err.Error(), 
http.StatusBadRequest) + return + } + + if req.SessionID == "" { + http.Error(w, "session_id is required", http.StatusBadRequest) + return + } + + result, err := s.store.Push(r.Context(), req) + if err != nil { + if err == session.ErrSessionNotFound { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + if err == session.ErrOverBudget { + http.Error(w, err.Error(), http.StatusRequestEntityTooLarge) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(result) +} + +func (s *SessionAPI) handleContext(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost && r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req session.ContextRequest + + if r.Method == http.MethodPost { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON: "+err.Error(), http.StatusBadRequest) + return + } + } else { + req.SessionID = r.URL.Query().Get("session_id") + req.Role = r.URL.Query().Get("role") + } + + if req.SessionID == "" { + http.Error(w, "session_id is required", http.StatusBadRequest) + return + } + + result, err := s.store.Context(r.Context(), req) + if err != nil { + if err == session.ErrSessionNotFound { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(result) +} + +func (s *SessionAPI) handleGet(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" && r.Method == http.MethodPost { + var body struct { + SessionID string `json:"session_id"` + } + if err := 
json.NewDecoder(r.Body).Decode(&body); err == nil { + sessionID = body.SessionID + } + } + + if sessionID == "" { + http.Error(w, "session_id is required", http.StatusBadRequest) + return + } + + sess, err := s.store.Get(r.Context(), sessionID) + if err != nil { + if err == session.ErrSessionNotFound { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(sess) +} + +func (s *SessionAPI) handleDelete(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost && r.Method != http.MethodDelete { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + SessionID string `json:"session_id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON: "+err.Error(), http.StatusBadRequest) + return + } + + if req.SessionID == "" { + http.Error(w, "session_id is required", http.StatusBadRequest) + return + } + + result, err := s.store.Delete(r.Context(), req.SessionID) + if err != nil { + if err == session.ErrSessionNotFound { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(result) +} diff --git a/cmd/mcp.go b/cmd/mcp.go index 0e43e1f..659f4ac 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -11,6 +11,7 @@ import ( "github.com/Siddhant-K-code/distill/pkg/embedding/openai" "github.com/Siddhant-K-code/distill/pkg/memory" "github.com/Siddhant-K-code/distill/pkg/retriever" + "github.com/Siddhant-K-code/distill/pkg/session" pcretriever "github.com/Siddhant-K-code/distill/pkg/retriever/pinecone" qdretriever "github.com/Siddhant-K-code/distill/pkg/retriever/qdrant" "github.com/Siddhant-K-code/distill/pkg/types" @@ -92,6 +93,8 @@ func init() { // 
Memory store mcpCmd.Flags().Bool("memory", false, "Enable persistent memory store") mcpCmd.Flags().String("memory-db", "distill-memory.db", "SQLite database path for memory store") + mcpCmd.Flags().Bool("session", false, "Enable session management") + mcpCmd.Flags().String("session-db", "distill-sessions.db", "SQLite database path for session store") // Default deduplication settings mcpCmd.Flags().Int("over-fetch-k", 50, "Default over-fetch count") @@ -102,10 +105,11 @@ func init() { // MCPServer wraps the MCP server with Distill capabilities type MCPServer struct { - broker *contextlab.Broker - embedder retriever.EmbeddingProvider - cfg contextlab.BrokerConfig - memStore *memory.SQLiteStore + broker *contextlab.Broker + embedder retriever.EmbeddingProvider + cfg contextlab.BrokerConfig + memStore *memory.SQLiteStore + sessStore *session.SQLiteStore } func runMCP(cmd *cobra.Command, args []string) error { @@ -166,6 +170,20 @@ func runMCP(cmd *cobra.Command, args []string) error { mcpSrv.memStore = memStore } + // Create session store (opt-in) + enableSession, _ := cmd.Flags().GetBool("session") + if enableSession { + sessDBPath, _ := cmd.Flags().GetString("session-db") + sessCfg := session.DefaultConfig() + sessCfg.DefaultDedupThreshold = threshold + sessStore, err := session.NewSQLiteStore(sessDBPath, sessCfg) + if err != nil { + return fmt.Errorf("failed to create session store: %w", err) + } + defer func() { _ = sessStore.Close() }() + mcpSrv.sessStore = sessStore + } + // Create embedding provider if OpenAI key is provided if openaiKey != "" { embedder, err := openai.NewClient(openai.Config{ @@ -430,6 +448,70 @@ Use this to clean up outdated or incorrect memories.`), ) s.AddTool(memoryStatsTool, m.handleMemoryStats) } + + // Session tools (opt-in via --session) + if m.sessStore != nil { + createSessionTool := mcp.NewTool("create_session", + mcp.WithDescription(`Create a new context window session with a token budget. 
+ +Use this at the start of a task to track context incrementally.`), + mcp.WithString("session_id", + mcp.Description("Session ID (auto-generated if empty)"), + ), + mcp.WithNumber("max_tokens", + mcp.Description("Token budget for the session (default: 128000)"), + ), + ) + s.AddTool(createSessionTool, m.handleCreateSession) + + pushSessionTool := mcp.NewTool("push_session", + mcp.WithDescription(`Push context entries to a session. Entries are deduplicated +and the token budget is enforced via compression and eviction.`), + mcp.WithString("session_id", + mcp.Description("Session ID"), + mcp.Required(), + ), + mcp.WithString("content", + mcp.Description("Entry content"), + mcp.Required(), + ), + mcp.WithString("role", + mcp.Description("Entry role: user, assistant, tool, system (default: tool)"), + ), + mcp.WithString("source", + mcp.Description("Entry source (e.g. file_read, search)"), + ), + mcp.WithNumber("importance", + mcp.Description("Entry importance 0-1 (default: 0.5, higher = harder to evict)"), + ), + ) + s.AddTool(pushSessionTool, m.handlePushSession) + + sessionContextTool := mcp.NewTool("session_context", + mcp.WithDescription(`Read the current context window for a session. 
+Returns entries in push order with compression levels and token counts.`), + mcp.WithString("session_id", + mcp.Description("Session ID"), + mcp.Required(), + ), + mcp.WithNumber("max_tokens", + mcp.Description("Max tokens to return (0 = all)"), + ), + mcp.WithString("role", + mcp.Description("Filter by role"), + ), + ) + s.AddTool(sessionContextTool, m.handleSessionContext) + + deleteSessionTool := mcp.NewTool("delete_session", + mcp.WithDescription("Delete a session and all its entries."), + mcp.WithString("session_id", + mcp.Description("Session ID"), + mcp.Required(), + ), + ) + s.AddTool(deleteSessionTool, m.handleDeleteSession) + } } // System prompt that guides AI assistants to use deduplication diff --git a/cmd/mcp_session.go b/cmd/mcp_session.go new file mode 100644 index 0000000..6c54a10 --- /dev/null +++ b/cmd/mcp_session.go @@ -0,0 +1,120 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Siddhant-K-code/distill/pkg/session" + "github.com/mark3labs/mcp-go/mcp" +) + +func (m *MCPServer) handleCreateSession(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + sessionID, _ := args["session_id"].(string) + maxTokens := 128000 + if v, ok := args["max_tokens"].(float64); ok && v > 0 { + maxTokens = int(v) + } + + sess, err := m.sessStore.Create(ctx, session.CreateRequest{ + SessionID: sessionID, + MaxTokens: maxTokens, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("create session: %v", err)), nil + } + + data, _ := json.MarshalIndent(sess, "", " ") + return mcp.NewToolResultText(string(data)), nil +} + +func (m *MCPServer) handlePushSession(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + sessionID, _ := args["session_id"].(string) + if sessionID == "" { + return mcp.NewToolResultError("session_id is required"), nil + } + + role, _ := args["role"].(string) + if role == "" { + 
role = "tool" + } + + content, _ := args["content"].(string) + if content == "" { + return mcp.NewToolResultError("content is required"), nil + } + + source, _ := args["source"].(string) + + importance := 0.5 + if v, ok := args["importance"].(float64); ok && v > 0 { + importance = v + } + + result, err := m.sessStore.Push(ctx, session.PushRequest{ + SessionID: sessionID, + Entries: []session.PushEntry{ + { + Role: role, + Content: content, + Source: source, + Importance: importance, + }, + }, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("push: %v", err)), nil + } + + data, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(data)), nil +} + +func (m *MCPServer) handleSessionContext(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + sessionID, _ := args["session_id"].(string) + if sessionID == "" { + return mcp.NewToolResultError("session_id is required"), nil + } + + maxTokens := 0 + if v, ok := args["max_tokens"].(float64); ok && v > 0 { + maxTokens = int(v) + } + + role, _ := args["role"].(string) + + result, err := m.sessStore.Context(ctx, session.ContextRequest{ + SessionID: sessionID, + MaxTokens: maxTokens, + Role: role, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("context: %v", err)), nil + } + + data, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(data)), nil +} + +func (m *MCPServer) handleDeleteSession(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + sessionID, _ := args["session_id"].(string) + if sessionID == "" { + return mcp.NewToolResultError("session_id is required"), nil + } + + result, err := m.sessStore.Delete(ctx, sessionID) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("delete: %v", err)), nil + } + + data, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(data)), 
nil +} diff --git a/cmd/session.go b/cmd/session.go new file mode 100644 index 0000000..33080e5 --- /dev/null +++ b/cmd/session.go @@ -0,0 +1,209 @@ +package cmd + +import ( + "context" + "encoding/json" + "os" + + "github.com/Siddhant-K-code/distill/pkg/session" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var sessionCmd = &cobra.Command{ + Use: "session", + Short: "Manage stateful context windows", + Long: `Create and manage token-budgeted context windows for AI agent sessions. + +Entries are deduplicated on push, compressed as they age, and evicted +when the token budget is exceeded. + +Examples: + distill session create --max-tokens 128000 + distill session push --session-id abc --role user --content "Fix the bug" + distill session context --session-id abc + distill session delete --session-id abc`, +} + +var sessionCreateCmd = &cobra.Command{ + Use: "create", + Short: "Create a new session", + RunE: runSessionCreate, +} + +var sessionPushCmd = &cobra.Command{ + Use: "push", + Short: "Push entries to a session", + RunE: runSessionPush, +} + +var sessionContextCmd = &cobra.Command{ + Use: "context", + Short: "Read the current context window", + RunE: runSessionContext, +} + +var sessionDeleteCmd = &cobra.Command{ + Use: "delete", + Short: "Delete a session and all its entries", + RunE: runSessionDelete, +} + +func init() { + rootCmd.AddCommand(sessionCmd) + sessionCmd.AddCommand(sessionCreateCmd, sessionPushCmd, sessionContextCmd, sessionDeleteCmd) + + // Shared flags + sessionCmd.PersistentFlags().String("db", "", "SQLite database path (default: distill-sessions.db)") + + // Create flags + sessionCreateCmd.Flags().String("session-id", "", "Session ID (auto-generated if empty)") + sessionCreateCmd.Flags().Int("max-tokens", 128000, "Token budget for the session") + sessionCreateCmd.Flags().Float64("dedup-threshold", 0.15, "Cosine distance threshold for dedup") + sessionCreateCmd.Flags().Int("preserve-recent", 10, "Always keep last N entries 
uncompressed") + + // Push flags + sessionPushCmd.Flags().String("session-id", "", "Session ID") + sessionPushCmd.Flags().String("role", "user", "Entry role (user, assistant, tool, system)") + sessionPushCmd.Flags().String("content", "", "Entry content") + sessionPushCmd.Flags().String("source", "", "Entry source (e.g. file_read, search)") + sessionPushCmd.Flags().Float64("importance", 0.5, "Entry importance (0-1)") + _ = sessionPushCmd.MarkFlagRequired("session-id") + _ = sessionPushCmd.MarkFlagRequired("content") + + // Context flags + sessionContextCmd.Flags().String("session-id", "", "Session ID") + sessionContextCmd.Flags().Int("max-tokens", 0, "Max tokens to return (0 = all)") + sessionContextCmd.Flags().String("role", "", "Filter by role") + _ = sessionContextCmd.MarkFlagRequired("session-id") + + // Delete flags + sessionDeleteCmd.Flags().String("session-id", "", "Session ID") + _ = sessionDeleteCmd.MarkFlagRequired("session-id") +} + +func runSessionCreate(cmd *cobra.Command, _ []string) error { + store, err := sessionStoreFromFlags(cmd) + if err != nil { + return err + } + defer func() { _ = store.Close() }() + + sessionID, _ := cmd.Flags().GetString("session-id") + maxTokens, _ := cmd.Flags().GetInt("max-tokens") + threshold, _ := cmd.Flags().GetFloat64("dedup-threshold") + preserveRecent, _ := cmd.Flags().GetInt("preserve-recent") + + sess, err := store.Create(context.Background(), session.CreateRequest{ + SessionID: sessionID, + MaxTokens: maxTokens, + DedupThreshold: threshold, + PreserveRecent: preserveRecent, + }) + if err != nil { + return err + } + + return json.NewEncoder(os.Stdout).Encode(sess) +} + +func runSessionPush(cmd *cobra.Command, _ []string) error { + store, err := sessionStoreFromFlags(cmd) + if err != nil { + return err + } + defer func() { _ = store.Close() }() + + sessionID, _ := cmd.Flags().GetString("session-id") + role, _ := cmd.Flags().GetString("role") + content, _ := cmd.Flags().GetString("content") + source, _ := 
cmd.Flags().GetString("source") + importance, _ := cmd.Flags().GetFloat64("importance") + + result, err := store.Push(context.Background(), session.PushRequest{ + SessionID: sessionID, + Entries: []session.PushEntry{ + { + Role: role, + Content: content, + Source: source, + Importance: importance, + }, + }, + }) + if err != nil { + return err + } + + return json.NewEncoder(os.Stdout).Encode(result) +} + +func runSessionContext(cmd *cobra.Command, _ []string) error { + store, err := sessionStoreFromFlags(cmd) + if err != nil { + return err + } + defer func() { _ = store.Close() }() + + sessionID, _ := cmd.Flags().GetString("session-id") + maxTokens, _ := cmd.Flags().GetInt("max-tokens") + role, _ := cmd.Flags().GetString("role") + + result, err := store.Context(context.Background(), session.ContextRequest{ + SessionID: sessionID, + MaxTokens: maxTokens, + Role: role, + }) + if err != nil { + return err + } + + return json.NewEncoder(os.Stdout).Encode(result) +} + +func runSessionDelete(cmd *cobra.Command, _ []string) error { + store, err := sessionStoreFromFlags(cmd) + if err != nil { + return err + } + defer func() { _ = store.Close() }() + + sessionID, _ := cmd.Flags().GetString("session-id") + + result, err := store.Delete(context.Background(), sessionID) + if err != nil { + return err + } + + return json.NewEncoder(os.Stdout).Encode(result) +} + +// sessionStoreFromFlags creates a session store from CLI flags. +func sessionStoreFromFlags(cmd *cobra.Command) (*session.SQLiteStore, error) { + dbPath, _ := cmd.Flags().GetString("db") + if dbPath == "" { + dbPath = viper.GetString("session.db_path") + } + return newSessionStore(dbPath) +} + +// newSessionStore creates a session store with the given DB path, +// applying viper config overrides. Used by CLI, API, and MCP. 
+func newSessionStore(dbPath string) (*session.SQLiteStore, error) { + if dbPath == "" { + dbPath = "distill-sessions.db" + } + cfg := session.DefaultConfig() + + threshold := viper.GetFloat64("session.dedup_threshold") + if threshold > 0 { + cfg.DefaultDedupThreshold = threshold + } + + maxTokens := viper.GetInt("session.max_tokens") + if maxTokens > 0 { + cfg.DefaultMaxTokens = maxTokens + } + + return session.NewSQLiteStore(dbPath, cfg) +} diff --git a/examples/session_api.sh b/examples/session_api.sh new file mode 100755 index 0000000..3e39903 --- /dev/null +++ b/examples/session_api.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Example: Session-based context window management via the API +# +# Start the server (in another terminal): +# distill api --port 8080 --session +# +# Sessions track context for long-running agent tasks with a token budget. +# Entries are deduplicated on push, compressed as they age, and evicted +# when the budget is exceeded. + +BASE="http://localhost:8080" + +echo "=== Create session ===" +curl -s -X POST "$BASE/v1/session/create" \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "demo-task", + "max_tokens": 50000 + }' | jq . 
+ +echo "" +echo "=== Push context entries ===" +curl -s -X POST "$BASE/v1/session/push" \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "demo-task", + "entries": [ + { + "role": "user", + "content": "Fix the JWT validation bug in the auth service", + "importance": 1.0 + }, + { + "role": "tool", + "content": "File: auth/jwt.go\n\npackage auth\n\nimport (\n\t\"crypto/rsa\"\n\t\"time\"\n)\n\nfunc ValidateToken(token string, key *rsa.PublicKey) error {\n\t// BUG: not checking expiry\n\treturn nil\n}", + "source": "file_read", + "importance": 0.8 + }, + { + "role": "tool", + "content": "File: auth/jwt_test.go\n\npackage auth\n\nimport \"testing\"\n\nfunc TestValidateToken(t *testing.T) {\n\t// No expiry test\n}", + "source": "file_read", + "importance": 0.6 + }, + { + "role": "assistant", + "content": "The ValidateToken function is missing expiry checks. I will add time.Now().After(claims.ExpiresAt) validation.", + "importance": 0.9 + } + ] + }' | jq . + +echo "" +echo "=== Read context window ===" +curl -s -X POST "$BASE/v1/session/context" \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "demo-task" + }' | jq . + +echo "" +echo "=== Read only tool entries ===" +curl -s -X POST "$BASE/v1/session/context" \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "demo-task", + "role": "tool" + }' | jq . + +echo "" +echo "=== Get session metadata ===" +curl -s "$BASE/v1/session/get?session_id=demo-task" | jq . + +echo "" +echo "=== Clean up ===" +curl -s -X DELETE "$BASE/v1/session/delete" \ + -H "Content-Type: application/json" \ + -d '{"session_id": "demo-task"}' | jq . diff --git a/mcp/README.md b/mcp/README.md index 1898b2a..7d71abe 100644 --- a/mcp/README.md +++ b/mcp/README.md @@ -24,8 +24,11 @@ It clusters semantically similar chunks, picks the best representative from each # Build go build -o distill . 
-# Start MCP server +# Start MCP server (dedup only) ./distill mcp + +# With memory and sessions enabled +./distill mcp --memory --session ``` ### Remote (HTTP) - Hosted deployment @@ -34,6 +37,9 @@ go build -o distill . # Start HTTP server ./distill mcp --transport http --port 8081 +# With all features +./distill mcp --transport http --port 8081 --memory --session + # Or deploy to Fly.io fly deploy -c fly.mcp.toml ``` @@ -74,6 +80,78 @@ Query a vector database with automatic deduplication. Requires `--backend` flag. Analyze chunks for redundancy without removing any. Use to understand overlap before deduplicating. +### `store_memory` (requires `--memory`) + +Store context that should persist across sessions. Memories are deduplicated on write. + +```json +{ + "text": "Auth service uses JWT with RS256 signing", + "tags": ["auth", "jwt"], + "source": "code_review" +} +``` + +### `recall_memory` (requires `--memory`) + +Recall relevant memories by semantic similarity + recency. + +```json +{ + "query": "How does authentication work?", + "max_results": 5, + "tags": ["auth"] +} +``` + +### `forget_memory` (requires `--memory`) + +Remove memories by tag or age. + +### `memory_stats` (requires `--memory`) + +Get memory store statistics (total count, by decay level, by source). + +### `create_session` (requires `--session`) + +Create a token-budgeted context window for a task. + +```json +{ + "session_id": "fix-auth-bug", + "max_tokens": 128000 +} +``` + +### `push_session` (requires `--session`) + +Push context entries to a session. Entries are deduplicated and the token budget is enforced via compression and eviction. + +```json +{ + "session_id": "fix-auth-bug", + "content": "File: auth/jwt.go\n...", + "role": "tool", + "source": "file_read", + "importance": 0.8 +} +``` + +### `session_context` (requires `--session`) + +Read the current context window. Returns entries in push order with compression levels and token counts. 
+ +```json +{ + "session_id": "fix-auth-bug", + "max_tokens": 50000 +} +``` + +### `delete_session` (requires `--session`) + +Delete a session and all its entries. + ## Resources ### `distill://system-prompt` @@ -102,7 +180,7 @@ Arguments: Add to `~/Library/Application Support/Claude/claude_desktop_config.json`: -**Local (stdio):** +**Local (stdio) - dedup only:** ```json { "mcpServers": { @@ -114,6 +192,21 @@ Add to `~/Library/Application Support/Claude/claude_desktop_config.json`: } ``` +**With memory and sessions:** +```json +{ + "mcpServers": { + "distill": { + "command": "/path/to/distill", + "args": ["mcp", "--memory", "--session"], + "env": { + "OPENAI_API_KEY": "your-openai-key" + } + } + } +} +``` + **Remote (HTTP):** ```json { @@ -131,7 +224,7 @@ Add to `~/Library/Application Support/Claude/claude_desktop_config.json`: "mcpServers": { "distill": { "command": "/path/to/distill", - "args": ["mcp", "--backend", "pinecone", "--index", "my-index"], + "args": ["mcp", "--backend", "pinecone", "--index", "my-index", "--memory", "--session"], "env": { "PINECONE_API_KEY": "your-api-key", "OPENAI_API_KEY": "your-openai-key" @@ -219,7 +312,30 @@ AI: [calls analyze_redundancy] AI: "Found 40% redundancy across 3 clusters. Want me to deduplicate?" ``` -### Pattern 4: Direct Vector DB Query +### Pattern 4: Session-Based Context Tracking + +Track context across a multi-step task: + +``` +1. AI creates a session: create_session("fix-auth-bug", 128000) +2. AI reads files: push_session(role="tool", content=file, source="file_read") +3. AI reads tests: push_session(role="tool", content=tests, source="file_read") +4. Budget exceeded → oldest low-importance entries compressed automatically +5. AI reads context: session_context() → deduplicated, budget-aware window +6. Task done: delete_session() +``` + +### Pattern 5: Cross-Session Memory + +Persist knowledge that should survive across sessions: + +``` +1. 
AI discovers a pattern: store_memory("Auth uses JWT with RS256", tags=["auth"]) +2. Next session, different task: recall_memory("How does auth work?") +3. AI gets relevant memories without re-reading files +``` + +### Pattern 6: Direct Vector DB Query If backend is configured, query with automatic deduplication: diff --git a/mcp/claude_desktop_config_full.example.json b/mcp/claude_desktop_config_full.example.json new file mode 100644 index 0000000..b669838 --- /dev/null +++ b/mcp/claude_desktop_config_full.example.json @@ -0,0 +1,18 @@ +{ + "mcpServers": { + "distill": { + "command": "/path/to/distill", + "args": [ + "mcp", + "--memory", + "--session", + "--backend", "pinecone", + "--index", "your-index-name" + ], + "env": { + "PINECONE_API_KEY": "your-pinecone-api-key", + "OPENAI_API_KEY": "your-openai-api-key" + } + } + } +} diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 0000000..dc8078b --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,157 @@ +// Package session provides stateful context window management for AI agent +// sessions. Entries are deduplicated on push, compressed as they age, and +// evicted when the token budget is exceeded. +package session + +import ( + "context" + "errors" + "time" +) + +// Common errors. +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrOverBudget = errors.New("single entry exceeds token budget") +) + +// CompressionLevel indicates how compressed an entry is. +type CompressionLevel int + +const ( + LevelFull CompressionLevel = 0 // Original content + LevelSummary CompressionLevel = 1 // Paragraph summary (~20%) + LevelSentence CompressionLevel = 2 // Single sentence (~5%) + LevelKeywords CompressionLevel = 3 // Keywords only (~1%) +) + +// Entry is a single item in a session's context window. 
+type Entry struct { + ID string `json:"id"` + Role string `json:"role"` // user, assistant, tool, system + Content string `json:"content"` // current (possibly compressed) text + OriginalContent string `json:"-"` // full text, kept for re-compression + Source string `json:"source,omitempty"` // e.g. file_read, search, conversation + Embedding []float32 `json:"-"` // for dedup comparison + Importance float64 `json:"importance"` // 0-1, higher = harder to evict + Level CompressionLevel `json:"level"` // current compression level + Tokens int `json:"tokens"` // token count of current content + CreatedAt time.Time `json:"created_at"` + CompressedAt time.Time `json:"compressed_at,omitempty"` +} + +// Session holds the state of a single context window. +type Session struct { + ID string `json:"session_id"` + MaxTokens int `json:"max_tokens"` + CurrentTokens int `json:"current_tokens"` + EntryCount int `json:"entry_count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// CreateRequest is the input for creating a session. +type CreateRequest struct { + SessionID string `json:"session_id,omitempty"` // auto-generated if empty + MaxTokens int `json:"max_tokens"` + DedupThreshold float64 `json:"dedup_threshold,omitempty"` + PreserveRecent int `json:"preserve_recent,omitempty"` // always keep last N at full fidelity +} + +// PushRequest is the input for adding entries to a session. +type PushRequest struct { + SessionID string `json:"session_id"` + Entries []PushEntry `json:"entries"` +} + +// PushEntry is a single entry in a push request. +type PushEntry struct { + Role string `json:"role"` + Content string `json:"content"` + Source string `json:"source,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` + Importance float64 `json:"importance,omitempty"` // default 0.5 +} + +// PushResult is the output of a push operation. 
+type PushResult struct { + SessionID string `json:"session_id"` + Accepted int `json:"accepted"` + Deduplicated int `json:"deduplicated"` + Compressed int `json:"compressed"` + Evicted int `json:"evicted"` + CurrentTokens int `json:"current_tokens"` + BudgetRemaining int `json:"budget_remaining"` +} + +// ContextRequest is the input for reading a session's context window. +type ContextRequest struct { + SessionID string `json:"session_id"` + MaxTokens int `json:"max_tokens,omitempty"` // 0 = return full window + Role string `json:"role,omitempty"` // filter by role +} + +// ContextResult is the output of a context read. +type ContextResult struct { + Entries []ContextEntry `json:"entries"` + Stats ContextStats `json:"stats"` +} + +// ContextEntry is a single entry returned from a context read. +type ContextEntry struct { + ID string `json:"id"` + Role string `json:"role"` + Content string `json:"content"` + Source string `json:"source,omitempty"` + Level CompressionLevel `json:"level"` + Tokens int `json:"tokens"` + Age string `json:"age"` +} + +// ContextStats contains context window metrics. +type ContextStats struct { + TotalEntries int `json:"total_entries"` + TotalTokens int `json:"total_tokens"` + CompressionLevels map[int]int `json:"compression_levels"` + CompressionSavings int `json:"compression_savings"` // tokens saved by compression +} + +// DeleteResult is the output of deleting a session. +type DeleteResult struct { + SessionID string `json:"session_id"` + EntriesRemoved int `json:"entries_removed"` +} + +// Store is the interface for session backends. 
+type Store interface { + Create(ctx context.Context, req CreateRequest) (*Session, error) + Push(ctx context.Context, req PushRequest) (*PushResult, error) + Context(ctx context.Context, req ContextRequest) (*ContextResult, error) + Get(ctx context.Context, sessionID string) (*Session, error) + Delete(ctx context.Context, sessionID string) (*DeleteResult, error) + Close() error +} + +// Config holds session store configuration. +type Config struct { + // DefaultMaxTokens is the default token budget for new sessions. + DefaultMaxTokens int + + // DefaultDedupThreshold is the cosine distance below which entries + // are considered duplicates. Default: 0.15. + DefaultDedupThreshold float64 + + // DefaultPreserveRecent is how many recent entries to keep uncompressed. + // Default: 10. + DefaultPreserveRecent int +} + +// DefaultConfig returns sensible defaults. +func DefaultConfig() Config { + return Config{ + DefaultMaxTokens: 128000, + DefaultDedupThreshold: 0.15, + DefaultPreserveRecent: 10, + } +} diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go new file mode 100644 index 0000000..9a919f6 --- /dev/null +++ b/pkg/session/session_test.go @@ -0,0 +1,348 @@ +package session + +import ( + "context" + "math" + "strings" + "testing" +) + +func makeEmbedding(angle float64, dim int) []float32 { + emb := make([]float32, dim) + emb[0] = float32(math.Cos(angle)) + emb[1] = float32(math.Sin(angle)) + return emb +} + +func newTestStore(t *testing.T) *SQLiteStore { + t.Helper() + cfg := DefaultConfig() + cfg.DefaultMaxTokens = 1000 // small budget for testing + cfg.DefaultPreserveRecent = 2 + s, err := NewSQLiteStore(":memory:", cfg) + if err != nil { + t.Fatalf("NewSQLiteStore: %v", err) + } + t.Cleanup(func() { _ = s.Close() }) + return s +} + +func TestCreateAndGet(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + sess, err := s.Create(ctx, CreateRequest{ + SessionID: "test-1", + MaxTokens: 5000, + }) + if err != nil { + 
t.Fatalf("Create: %v", err) + } + if sess.ID != "test-1" { + t.Errorf("expected id test-1, got %s", sess.ID) + } + if sess.MaxTokens != 5000 { + t.Errorf("expected 5000 max tokens, got %d", sess.MaxTokens) + } + + got, err := s.Get(ctx, "test-1") + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.MaxTokens != 5000 { + t.Errorf("expected 5000, got %d", got.MaxTokens) + } +} + +func TestCreateAutoID(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + sess, err := s.Create(ctx, CreateRequest{MaxTokens: 1000}) + if err != nil { + t.Fatalf("Create: %v", err) + } + if sess.ID == "" { + t.Error("expected auto-generated ID") + } +} + +func TestCreateDuplicate(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Create(ctx, CreateRequest{SessionID: "dup"}) + if err != nil { + t.Fatalf("Create: %v", err) + } + + _, err = s.Create(ctx, CreateRequest{SessionID: "dup"}) + if err != ErrSessionExists { + t.Errorf("expected ErrSessionExists, got %v", err) + } +} + +func TestPushAndContext(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "s1", MaxTokens: 50000}) + + result, err := s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "user", Content: "Fix the JWT validation bug", Importance: 1.0}, + {Role: "tool", Content: "File: auth/jwt.go\nfunc ValidateToken()...", Source: "file_read"}, + }, + }) + if err != nil { + t.Fatalf("Push: %v", err) + } + if result.Accepted != 2 { + t.Errorf("expected 2 accepted, got %d", result.Accepted) + } + if result.CurrentTokens <= 0 { + t.Error("expected positive token count") + } + + // Read context + ctxResult, err := s.Context(ctx, ContextRequest{SessionID: "s1"}) + if err != nil { + t.Fatalf("Context: %v", err) + } + if len(ctxResult.Entries) != 2 { + t.Errorf("expected 2 entries, got %d", len(ctxResult.Entries)) + } + // Entries should be in push order + if ctxResult.Entries[0].Role != 
"user" { + t.Errorf("expected first entry role=user, got %s", ctxResult.Entries[0].Role) + } +} + +func TestPushDedup(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "s1", MaxTokens: 50000}) + + emb := makeEmbedding(0, 8) + + r1, err := s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "tool", Content: "File: auth/jwt.go contents...", Embedding: emb}, + }, + }) + if err != nil { + t.Fatalf("Push 1: %v", err) + } + if r1.Accepted != 1 { + t.Errorf("expected 1 accepted, got %d", r1.Accepted) + } + + // Push same embedding again - should be deduped + r2, err := s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "tool", Content: "File: auth/jwt.go (re-read)", Embedding: emb}, + }, + }) + if err != nil { + t.Fatalf("Push 2: %v", err) + } + if r2.Deduplicated != 1 { + t.Errorf("expected 1 deduplicated, got %d", r2.Deduplicated) + } + if r2.Accepted != 0 { + t.Errorf("expected 0 accepted, got %d", r2.Accepted) + } +} + +func TestBudgetEnforcement(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Create session with very tight budget + _, _ = s.Create(ctx, CreateRequest{ + SessionID: "tight", + MaxTokens: 50, // ~200 chars + PreserveRecent: 1, + }) + + // Push entries that exceed budget + _, err := s.Push(ctx, PushRequest{ + SessionID: "tight", + Entries: []PushEntry{ + {Role: "user", Content: "First message about authentication and JWT tokens.", Importance: 0.3}, + {Role: "tool", Content: "Second message with file contents from the auth module.", Importance: 0.5}, + {Role: "user", Content: "Third message asking about the bug fix.", Importance: 1.0}, + }, + }) + if err != nil { + t.Fatalf("Push: %v", err) + } + + // Check that budget is enforced + sess, _ := s.Get(ctx, "tight") + if sess.CurrentTokens > 50 { + t.Errorf("expected tokens <= 50, got %d", sess.CurrentTokens) + } +} + +func TestContextWithRoleFilter(t *testing.T) 
{ + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "s1", MaxTokens: 50000}) + + _, _ = s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "user", Content: "Fix the bug"}, + {Role: "tool", Content: "File contents..."}, + {Role: "assistant", Content: "I'll look at that"}, + {Role: "tool", Content: "Test results..."}, + }, + }) + + // Filter by tool role + result, err := s.Context(ctx, ContextRequest{SessionID: "s1", Role: "tool"}) + if err != nil { + t.Fatalf("Context: %v", err) + } + if len(result.Entries) != 2 { + t.Errorf("expected 2 tool entries, got %d", len(result.Entries)) + } + for _, e := range result.Entries { + if e.Role != "tool" { + t.Errorf("expected role=tool, got %s", e.Role) + } + } +} + +func TestContextWithTokenLimit(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "s1", MaxTokens: 50000}) + + _, _ = s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "user", Content: "Short message"}, + {Role: "tool", Content: "This is a much longer message that contains many more tokens and should push us over a small token limit when combined with the first entry"}, + }, + }) + + // Request with tight token limit + result, err := s.Context(ctx, ContextRequest{SessionID: "s1", MaxTokens: 10}) + if err != nil { + t.Fatalf("Context: %v", err) + } + if result.Stats.TotalTokens > 10 { + t.Errorf("expected tokens <= 10, got %d", result.Stats.TotalTokens) + } +} + +func TestDelete(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "del"}) + _, _ = s.Push(ctx, PushRequest{ + SessionID: "del", + Entries: []PushEntry{{Role: "user", Content: "test"}}, + }) + + result, err := s.Delete(ctx, "del") + if err != nil { + t.Fatalf("Delete: %v", err) + } + if result.EntriesRemoved != 1 { + t.Errorf("expected 1 removed, got %d", 
result.EntriesRemoved) + } + + _, err = s.Get(ctx, "del") + if err != ErrSessionNotFound { + t.Errorf("expected ErrSessionNotFound, got %v", err) + } +} + +func TestDeleteNotFound(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Delete(ctx, "nonexistent") + if err != ErrSessionNotFound { + t.Errorf("expected ErrSessionNotFound, got %v", err) + } +} + +func TestPushToNonexistentSession(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Push(ctx, PushRequest{ + SessionID: "nope", + Entries: []PushEntry{{Role: "user", Content: "test"}}, + }) + if err != ErrSessionNotFound { + t.Errorf("expected ErrSessionNotFound, got %v", err) + } +} + +func TestPushEmptyContent(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.Create(ctx, CreateRequest{SessionID: "s1"}) + + result, err := s.Push(ctx, PushRequest{ + SessionID: "s1", + Entries: []PushEntry{ + {Role: "user", Content: ""}, + {Role: "user", Content: "Valid"}, + }, + }) + if err != nil { + t.Fatalf("Push: %v", err) + } + if result.Accepted != 1 { + t.Errorf("expected 1 accepted (empty skipped), got %d", result.Accepted) + } +} + +func TestCompressToLevel(t *testing.T) { + text := "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours. Refresh tokens are stored in Redis with a 7-day TTL. The service also supports OAuth2 for third-party integrations." 
+ + summary := compressToLevel(text, LevelSummary) + if len(summary) >= len(text) { + t.Errorf("summary should be shorter than original: %d >= %d", len(summary), len(text)) + } + if summary == "" { + t.Error("summary should not be empty") + } + + sentence := compressToLevel(text, LevelSentence) + if len(sentence) >= len(text) { + t.Errorf("sentence should be shorter than original: %d >= %d", len(sentence), len(text)) + } + // Should end with a sentence delimiter + last := sentence[len(sentence)-1] + if last != '.' && last != '!' && last != '?' { + t.Errorf("sentence should end with delimiter, got %q", string(last)) + } + + keywords := compressToLevel(text, LevelKeywords) + if keywords == "" { + t.Error("keywords should not be empty") + } + // Keywords should be comma-separated + if !strings.Contains(keywords, ",") { + t.Errorf("keywords should be comma-separated, got %q", keywords) + } + // Keywords should differ from full text + if keywords == text { + t.Error("keywords should differ from original text") + } +} diff --git a/pkg/session/sqlite.go b/pkg/session/sqlite.go new file mode 100644 index 0000000..0217cba --- /dev/null +++ b/pkg/session/sqlite.go @@ -0,0 +1,707 @@ +package session + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "sort" + "strings" + "time" + + "github.com/Siddhant-K-code/distill/pkg/compress" + distillmath "github.com/Siddhant-K-code/distill/pkg/math" + "github.com/Siddhant-K-code/distill/pkg/types" + _ "modernc.org/sqlite" +) + +// SQLiteStore implements Store using SQLite. +// Single connection (SetMaxOpenConns(1)) - SQLite handles serialization. +type SQLiteStore struct { + db *sql.DB + cfg Config +} + +// NewSQLiteStore creates a new SQLite-backed session store. 
+func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) { + if dsn == "" { + dsn = ":memory:" + } + + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, fmt.Errorf("open sqlite: %w", err) + } + + db.SetMaxOpenConns(1) + + if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + _ = db.Close() + return nil, fmt.Errorf("set WAL mode: %w", err) + } + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + _ = db.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + + s := &SQLiteStore{db: db, cfg: cfg} + if err := s.migrate(); err != nil { + _ = db.Close() + return nil, fmt.Errorf("migrate: %w", err) + } + + return s, nil +} + +func (s *SQLiteStore) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + max_tokens INTEGER NOT NULL, + dedup_threshold REAL NOT NULL DEFAULT 0.15, + preserve_recent INTEGER NOT NULL DEFAULT 10, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS session_entries ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + role TEXT NOT NULL DEFAULT '', + content TEXT NOT NULL, + original_content TEXT NOT NULL, + source TEXT DEFAULT '', + embedding BLOB, + importance REAL NOT NULL DEFAULT 0.5, + compression_level INTEGER NOT NULL DEFAULT 0, + tokens INTEGER NOT NULL DEFAULT 0, + seq INTEGER NOT NULL, + created_at TEXT NOT NULL, + compressed_at TEXT DEFAULT '', + FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_entries_session ON session_entries(session_id); + CREATE INDEX IF NOT EXISTS idx_entries_seq ON session_entries(session_id, seq); + ` + _, err := s.db.Exec(schema) + return err +} + +// Create creates a new session. 
+func (s *SQLiteStore) Create(ctx context.Context, req CreateRequest) (*Session, error) { + id := req.SessionID + if id == "" { + id = generateID() + } + + maxTokens := req.MaxTokens + if maxTokens <= 0 { + maxTokens = s.cfg.DefaultMaxTokens + } + + threshold := req.DedupThreshold + if threshold <= 0 { + threshold = s.cfg.DefaultDedupThreshold + } + + preserveRecent := req.PreserveRecent + if preserveRecent <= 0 { + preserveRecent = s.cfg.DefaultPreserveRecent + } + + nowTime := time.Now().UTC() + now := nowTime.Format(time.RFC3339Nano) + + _, err := s.db.ExecContext(ctx, + `INSERT INTO sessions (id, max_tokens, dedup_threshold, preserve_recent, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)`, + id, maxTokens, threshold, preserveRecent, now, now, + ) + if err != nil { + if strings.Contains(err.Error(), "UNIQUE constraint") { + return nil, ErrSessionExists + } + return nil, fmt.Errorf("insert session: %w", err) + } + + return &Session{ + ID: id, + MaxTokens: maxTokens, + CurrentTokens: 0, + EntryCount: 0, + CreatedAt: nowTime, + UpdatedAt: nowTime, + }, nil +} + +// Push adds entries to a session with dedup and budget enforcement. 
+func (s *SQLiteStore) Push(ctx context.Context, req PushRequest) (*PushResult, error) { + // Load session config + sess, err := s.loadSessionConfig(ctx, req.SessionID) + if err != nil { + return nil, err + } + + result := &PushResult{SessionID: req.SessionID} + + // Get current max seq + var maxSeq int + _ = s.db.QueryRowContext(ctx, + "SELECT COALESCE(MAX(seq), 0) FROM session_entries WHERE session_id = ?", + req.SessionID, + ).Scan(&maxSeq) + + for _, entry := range req.Entries { + if entry.Content == "" { + continue + } + + importance := entry.Importance + if importance <= 0 { + importance = 0.5 + } + + // Check for duplicates + if len(entry.Embedding) > 0 { + isDup, err := s.isDuplicate(ctx, req.SessionID, entry.Embedding, sess.dedupThreshold) + if err != nil { + return nil, fmt.Errorf("dedup check: %w", err) + } + if isDup { + result.Deduplicated++ + continue + } + } + + tokens := estimateTokens(entry.Content) + + // Reject single entries that exceed the entire budget + if tokens > sess.maxTokens { + return nil, ErrOverBudget + } + + maxSeq++ + id := generateID() + now := time.Now().UTC().Format(time.RFC3339Nano) + + _, err := s.db.ExecContext(ctx, + `INSERT INTO session_entries + (id, session_id, role, content, original_content, source, embedding, importance, compression_level, tokens, seq, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0, ?, ?, ?)`, + id, req.SessionID, entry.Role, entry.Content, entry.Content, + entry.Source, encodeEmbedding(entry.Embedding), importance, + tokens, maxSeq, now, + ) + if err != nil { + return nil, fmt.Errorf("insert entry: %w", err) + } + + result.Accepted++ + } + + // Enforce token budget - loop until within budget or no progress + for { + c, e, err := s.enforceBudget(ctx, req.SessionID, sess) + if err != nil { + return nil, fmt.Errorf("enforce budget: %w", err) + } + result.Compressed += c + result.Evicted += e + if c == 0 && e == 0 { + break // no progress possible + } + } + + // Update session timestamp + now := 
time.Now().UTC().Format(time.RFC3339Nano) + _, _ = s.db.ExecContext(ctx, + "UPDATE sessions SET updated_at = ? WHERE id = ?", + now, req.SessionID, + ) + + // Compute current tokens + var currentTokens int + _ = s.db.QueryRowContext(ctx, + "SELECT COALESCE(SUM(tokens), 0) FROM session_entries WHERE session_id = ?", + req.SessionID, + ).Scan(&currentTokens) + + result.CurrentTokens = currentTokens + result.BudgetRemaining = sess.maxTokens - currentTokens + + return result, nil +} + +// Context returns the current context window for a session. +func (s *SQLiteStore) Context(ctx context.Context, req ContextRequest) (*ContextResult, error) { + // Verify session exists + var exists int + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sessions WHERE id = ?", req.SessionID, + ).Scan(&exists) + if err != nil || exists == 0 { + return nil, ErrSessionNotFound + } + + query := "SELECT id, role, content, source, compression_level, tokens, created_at FROM session_entries WHERE session_id = ?" + args := []interface{}{req.SessionID} + + if req.Role != "" { + query += " AND role = ?" + args = append(args, req.Role) + } + + query += " ORDER BY seq ASC" + + rows, err := s.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("query entries: %w", err) + } + + type rawEntry struct { + id, role, content, source, createdAt string + level, tokens int + } + var raw []rawEntry + for rows.Next() { + var r rawEntry + if err := rows.Scan(&r.id, &r.role, &r.content, &r.source, &r.level, &r.tokens, &r.createdAt); err != nil { + _ = rows.Close() + return nil, err + } + raw = append(raw, r) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() + + now := time.Now() + levels := make(map[int]int) + var entries []ContextEntry + tokenCount := 0 + + for _, r := range raw { + if req.MaxTokens > 0 && tokenCount+r.tokens > req.MaxTokens { + break + } + + created, _ := time.Parse(time.RFC3339Nano, r.createdAt) + age := formatAge(now.Sub(created)) + + entries = append(entries, ContextEntry{ + ID: r.id, + Role: r.role, + Content: r.content, + Source: r.source, + Level: CompressionLevel(r.level), + Tokens: r.tokens, + Age: age, + }) + tokenCount += r.tokens + levels[r.level]++ + } + + // Compute compression savings (original tokens - current tokens) + var totalOriginalTokens int + _ = s.db.QueryRowContext(ctx, + "SELECT COALESCE(SUM(LENGTH(original_content)+3)/4, 0) FROM session_entries WHERE session_id = ?", + req.SessionID, + ).Scan(&totalOriginalTokens) + + return &ContextResult{ + Entries: entries, + Stats: ContextStats{ + TotalEntries: len(entries), + TotalTokens: tokenCount, + CompressionLevels: levels, + CompressionSavings: totalOriginalTokens - tokenCount, + }, + }, nil +} + +// Get returns session metadata. 
+func (s *SQLiteStore) Get(ctx context.Context, sessionID string) (*Session, error) { + var sess Session + var createdStr, updatedStr string + + err := s.db.QueryRowContext(ctx, + "SELECT id, max_tokens, created_at, updated_at FROM sessions WHERE id = ?", + sessionID, + ).Scan(&sess.ID, &sess.MaxTokens, &createdStr, &updatedStr) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrSessionNotFound + } + return nil, err + } + + sess.CreatedAt, _ = time.Parse(time.RFC3339Nano, createdStr) + sess.UpdatedAt, _ = time.Parse(time.RFC3339Nano, updatedStr) + + _ = s.db.QueryRowContext(ctx, + "SELECT COALESCE(SUM(tokens), 0), COUNT(*) FROM session_entries WHERE session_id = ?", + sessionID, + ).Scan(&sess.CurrentTokens, &sess.EntryCount) + + return &sess, nil +} + +// Delete removes a session and all its entries. +func (s *SQLiteStore) Delete(ctx context.Context, sessionID string) (*DeleteResult, error) { + var count int + _ = s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM session_entries WHERE session_id = ?", + sessionID, + ).Scan(&count) + + res, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE id = ?", sessionID) + if err != nil { + return nil, fmt.Errorf("delete session: %w", err) + } + + affected, _ := res.RowsAffected() + if affected == 0 { + return nil, ErrSessionNotFound + } + + return &DeleteResult{ + SessionID: sessionID, + EntriesRemoved: count, + }, nil +} + +// Close closes the database connection. 
+func (s *SQLiteStore) Close() error { + return s.db.Close() +} + +// --- internal --- + +type sessionConfig struct { + maxTokens int + dedupThreshold float64 + preserveRecent int +} + +func (s *SQLiteStore) loadSessionConfig(ctx context.Context, sessionID string) (*sessionConfig, error) { + var cfg sessionConfig + err := s.db.QueryRowContext(ctx, + "SELECT max_tokens, dedup_threshold, preserve_recent FROM sessions WHERE id = ?", + sessionID, + ).Scan(&cfg.maxTokens, &cfg.dedupThreshold, &cfg.preserveRecent) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrSessionNotFound + } + return nil, err + } + return &cfg, nil +} + +// isDuplicate checks if an embedding is within threshold of any existing entry. +// +// TODO: Full table scan (O(n) per entry). Fine for typical session sizes +// (< 1K entries). For larger sessions, consider caching embeddings in memory. +func (s *SQLiteStore) isDuplicate(ctx context.Context, sessionID string, embedding []float32, threshold float64) (bool, error) { + rows, err := s.db.QueryContext(ctx, + "SELECT embedding FROM session_entries WHERE session_id = ? AND embedding IS NOT NULL", + sessionID, + ) + if err != nil { + return false, err + } + + // Scan all then close - single connection pattern. + var blobs [][]byte + for rows.Next() { + var blob []byte + if err := rows.Scan(&blob); err != nil { + _ = rows.Close() + return false, err + } + blobs = append(blobs, blob) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return false, err + } + _ = rows.Close() + + for _, blob := range blobs { + existing := decodeEmbedding(blob) + if len(existing) == 0 { + continue + } + dist := distillmath.CosineDistance(embedding, existing) + if dist < threshold { + return true, nil + } + } + return false, nil +} + +// compressor reused across calls. +var compressor = compress.NewExtractiveCompressor() + +// enforceBudget compresses and evicts entries until within token budget. +// Returns (compressed count, evicted count). 
+func (s *SQLiteStore) enforceBudget(ctx context.Context, sessionID string, cfg *sessionConfig) (int, int, error) { + var currentTokens int + _ = s.db.QueryRowContext(ctx, + "SELECT COALESCE(SUM(tokens), 0) FROM session_entries WHERE session_id = ?", + sessionID, + ).Scan(¤tTokens) + + if currentTokens <= cfg.maxTokens { + return 0, 0, nil + } + + compressed := 0 + evicted := 0 + + // Get total entry count to determine which are "recent" + var totalEntries int + _ = s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM session_entries WHERE session_id = ?", + sessionID, + ).Scan(&totalEntries) + + // Load compressible entries (oldest first, skip the most recent N) + // We need entries ordered by seq, and we skip the last preserveRecent. + limit := totalEntries - cfg.preserveRecent + if limit <= 0 { + // All entries are "recent" - nothing to compress, but still over budget. + // Evict the oldest non-recent entry as a last resort. + return s.evictOldest(ctx, sessionID, cfg, currentTokens) + } + + rows, err := s.db.QueryContext(ctx, + `SELECT id, original_content, compression_level, importance, tokens + FROM session_entries WHERE session_id = ? + ORDER BY seq ASC LIMIT ?`, + sessionID, limit, + ) + if err != nil { + return 0, 0, err + } + + var candidates []compressCandidate + for rows.Next() { + var c compressCandidate + if err := rows.Scan(&c.id, &c.originalContent, &c.level, &c.importance, &c.tokens); err != nil { + _ = rows.Close() + return 0, 0, err + } + candidates = append(candidates, c) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return 0, 0, err + } + _ = rows.Close() + + // Process candidates from oldest, lowest importance first. + // Sort: by importance ASC, then by position (already ordered by seq ASC). 
+ sortCandidates(candidates) + + for _, c := range candidates { + if currentTokens <= cfg.maxTokens { + break + } + + nextLevel := c.level + 1 + + if nextLevel > int(LevelKeywords) { + // Already at keywords - evict + _, err := s.db.ExecContext(ctx, + "DELETE FROM session_entries WHERE id = ?", c.id, + ) + if err != nil { + return compressed, evicted, err + } + currentTokens -= c.tokens + evicted++ + continue + } + + // Compress to next level + newContent := compressToLevel(c.originalContent, CompressionLevel(nextLevel)) + newTokens := estimateTokens(newContent) + now := time.Now().UTC().Format(time.RFC3339Nano) + + _, err := s.db.ExecContext(ctx, + `UPDATE session_entries SET content = ?, compression_level = ?, tokens = ?, compressed_at = ? WHERE id = ?`, + newContent, nextLevel, newTokens, now, c.id, + ) + if err != nil { + return compressed, evicted, err + } + + currentTokens -= (c.tokens - newTokens) + compressed++ + } + + return compressed, evicted, nil +} + +// evictOldest is a fallback when all entries are "recent" but still over budget. +func (s *SQLiteStore) evictOldest(ctx context.Context, sessionID string, cfg *sessionConfig, currentTokens int) (int, int, error) { + evicted := 0 + for currentTokens > cfg.maxTokens { + var id string + var tokens int + err := s.db.QueryRowContext(ctx, + "SELECT id, tokens FROM session_entries WHERE session_id = ? ORDER BY seq ASC LIMIT 1", + sessionID, + ).Scan(&id, &tokens) + if err != nil { + break + } + _, _ = s.db.ExecContext(ctx, "DELETE FROM session_entries WHERE id = ?", id) + currentTokens -= tokens + evicted++ + } + return 0, evicted, nil +} + +// compressToLevel applies compression for the given level. 
+func compressToLevel(text string, level CompressionLevel) string { + switch level { + case LevelSummary: + // Use extractive compressor to keep ~20% + chunks := []types.Chunk{{ID: "sess", Text: text}} + opts := compress.Options{TargetReduction: 0.2, MinChunkLength: 20} + result, _, _ := compressor.Compress(context.Background(), chunks, opts) + if len(result) > 0 && result[0].Text != "" { + return result[0].Text + } + return text + case LevelSentence: + // Keep first sentence only + for i, r := range text { + if r == '.' || r == '!' || r == '?' { + return text[:i+1] + } + } + // No sentence boundary - truncate at word boundary near 50 chars + if len(text) > 50 { + cut := 50 + for cut > 0 && text[cut] != ' ' { + cut-- + } + if cut == 0 { + cut = 50 // no space found, hard cut + } + return strings.TrimSpace(text[:cut]) + "..." + } + return text + case LevelKeywords: + return extractKeywords(text) + default: + return text + } +} + +// extractKeywords produces a keyword-only representation. 
func extractKeywords(text string) string {
	seen := make(map[string]bool)
	keywords := make([]string, 0, 15)

	for _, raw := range strings.Fields(text) {
		// Normalize: strip surrounding punctuation, lowercase.
		word := strings.ToLower(strings.Trim(raw, ".,;:!?\"'()[]{}"))
		// Skip short words (< 4 bytes), stop words, and repeats.
		if len(word) < 4 || stopWords[word] || seen[word] {
			continue
		}
		seen[word] = true
		keywords = append(keywords, word)
		// Cap the representation at 15 keywords.
		if len(keywords) == 15 {
			break
		}
	}

	return strings.Join(keywords, ", ")
}

// stopWords are common English words excluded from keyword extraction.
// Only words of length >= 4 are listed; shorter ones are filtered by the
// length check above.
var stopWords = map[string]bool{
	"that": true, "this": true, "with": true, "from": true,
	"have": true, "been": true, "were": true, "they": true,
	"their": true, "which": true, "would": true, "there": true,
	"about": true, "could": true, "other": true, "into": true,
	"more": true, "some": true, "than": true, "them": true,
	"very": true, "when": true, "what": true, "your": true,
	"also": true, "each": true, "does": true, "will": true,
	"just": true, "should": true, "because": true, "these": true,
}

// compressCandidate is an entry eligible for compression or eviction.
type compressCandidate struct {
	id              string  // entry primary key
	originalContent string  // full-fidelity text compression is derived from
	level           int     // current compression level
	importance      float64 // lower values are compressed/evicted first
	tokens          int     // current token count of the stored content
}

// sortCandidates sorts by importance ASC (least important first).
+func sortCandidates(c []compressCandidate) { + sort.Slice(c, func(i, j int) bool { + return c[i].importance < c[j].importance + }) +} + +// --- helpers --- + +func generateID() string { + b := make([]byte, 12) + ts := uint32(time.Now().Unix()) + b[0] = byte(ts >> 24) + b[1] = byte(ts >> 16) + b[2] = byte(ts >> 8) + b[3] = byte(ts) + _, _ = rand.Read(b[4:]) + return hex.EncodeToString(b) +} + +func encodeEmbedding(emb []float32) []byte { + if len(emb) == 0 { + return nil + } + buf := make([]byte, len(emb)*4) + for i, v := range emb { + binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(v)) + } + return buf +} + +func decodeEmbedding(buf []byte) []float32 { + if len(buf) == 0 || len(buf)%4 != 0 { + return nil + } + emb := make([]float32, len(buf)/4) + for i := range emb { + emb[i] = math.Float32frombits(binary.LittleEndian.Uint32(buf[i*4:])) + } + return emb +} + +func estimateTokens(text string) int { + return (len(text) + 3) / 4 +} + +func formatAge(d time.Duration) string { + switch { + case d < time.Minute: + return fmt.Sprintf("%ds", int(d.Seconds())) + case d < time.Hour: + return fmt.Sprintf("%dm", int(d.Minutes())) + case d < 24*time.Hour: + return fmt.Sprintf("%dh", int(d.Hours())) + default: + return fmt.Sprintf("%dd", int(d.Hours()/24)) + } +}