From 552e7b6f641f6d3fe56c16dab252133b88638fde Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 18:30:29 +0000 Subject: [PATCH 1/6] feat: add persistent context memory store with write-time dedup Adds pkg/memory with SQLite-backed persistent storage for context that accumulates across agent sessions. Core features: - Write-time dedup via cosine distance on embeddings - Tag-based recall with relevance + recency ranking - Token budget support for recall - Hierarchical decay: full text -> summary -> keywords -> evicted - Background decay worker with configurable age thresholds Integration: - CLI: distill memory store/recall/forget/stats - API: POST /v1/memory/store, /recall, /forget, GET /stats - MCP: store_memory, recall_memory, forget_memory, memory_stats tools Uses modernc.org/sqlite (pure Go, no CGO) for zero-dependency local storage. Closes #29 Co-authored-by: Ona --- cmd/api.go | 18 ++ cmd/api_memory.go | 160 ++++++++++++++ cmd/mcp.go | 78 ++++++- cmd/mcp_memory.go | 155 +++++++++++++ cmd/memory.go | 262 ++++++++++++++++++++++ go.mod | 10 +- go.sum | 17 ++ pkg/memory/decay.go | 226 +++++++++++++++++++ pkg/memory/helpers.go | 53 +++++ pkg/memory/memory_test.go | 371 +++++++++++++++++++++++++++++++ pkg/memory/sqlite.go | 454 ++++++++++++++++++++++++++++++++++++++ pkg/memory/store.go | 181 +++++++++++++++ 12 files changed, 1983 insertions(+), 2 deletions(-) create mode 100644 cmd/api_memory.go create mode 100644 cmd/mcp_memory.go create mode 100644 cmd/memory.go create mode 100644 pkg/memory/decay.go create mode 100644 pkg/memory/helpers.go create mode 100644 pkg/memory/memory_test.go create mode 100644 pkg/memory/sqlite.go create mode 100644 pkg/memory/store.go diff --git a/cmd/api.go b/cmd/api.go index 995437a..7516ae4 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -173,10 +173,28 @@ func runAPI(cmd *cobra.Command, args []string) error { tracing: tp, } + // Setup memory store + memDBPath := viper.GetString("memory.db_path") + if memDBPath == "" { + memDBPath = "distill-memory.db" + } + memThreshold := viper.GetFloat64("memory.dedup_threshold") + if memThreshold == 0 { + memThreshold = 0.15 + } + memStore, err := memoryStoreFromConfig(memDBPath, memThreshold) + if err != nil { + return fmt.Errorf("failed to create memory store: %w", err) + } + defer memStore.Close() + + memAPI := &MemoryAPI{store: memStore, embedder: embedder} + // Setup routes mux := http.NewServeMux() mux.HandleFunc("/v1/dedupe", m.Middleware("/v1/dedupe", server.handleDedupe)) mux.HandleFunc("/v1/dedupe/stream", m.Middleware("/v1/dedupe/stream", server.handleDedupeStream)) + memAPI.RegisterMemoryRoutes(mux, m.Middleware) mux.HandleFunc("/health", server.handleHealth) mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { m.Handler().ServeHTTP(w, r) diff --git a/cmd/api_memory.go b/cmd/api_memory.go new file mode 100644 index 0000000..7caa3b8 --- /dev/null +++ b/cmd/api_memory.go @@ -0,0 +1,160 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/Siddhant-K-code/distill/pkg/memory" +) + +// MemoryAPI handles memory-related HTTP endpoints. +type MemoryAPI struct { + store *memory.SQLiteStore + embedder embeddingProvider +} + +// embeddingProvider is a minimal interface for generating embeddings. +type embeddingProvider interface { + Embed(ctx context.Context, text string) ([]float32, error) + EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) +} + +// RegisterMemoryRoutes adds memory endpoints to the given mux. +func (m *MemoryAPI) RegisterMemoryRoutes(mux *http.ServeMux, mw func(string, http.HandlerFunc) http.HandlerFunc) { + mux.HandleFunc("/v1/memory/store", mw("/v1/memory/store", m.handleStore)) + mux.HandleFunc("/v1/memory/recall", mw("/v1/memory/recall", m.handleRecall)) + mux.HandleFunc("/v1/memory/forget", mw("/v1/memory/forget", m.handleForget)) + mux.HandleFunc("/v1/memory/stats", mw("/v1/memory/stats", m.handleStats)) +} + +func (m *MemoryAPI) handleStore(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req memory.StoreRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, "invalid request body", http.StatusBadRequest) + return + } + + // Generate embeddings for entries that don't have them + if m.embedder != nil { + var textsToEmbed []string + var indices []int + for i, e := range req.Entries { + if len(e.Embedding) == 0 && e.Text != "" { + textsToEmbed = append(textsToEmbed, e.Text) + indices = append(indices, i) + } + } + if len(textsToEmbed) > 0 { + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + embeddings, err := m.embedder.EmbedBatch(ctx, textsToEmbed) + if err != nil { + writeJSONError(w, fmt.Sprintf("embedding error: %v", err), http.StatusInternalServerError) + return + } + for i, idx := range indices { + req.Entries[idx].Embedding = embeddings[i] + } + } + } + + result, err := m.store.Store(r.Context(), req) + if err != nil { + writeJSONError(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) +} + +func (m *MemoryAPI) handleRecall(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req memory.RecallRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.Query == "" && len(req.QueryEmbedding) == 0 { + writeJSONError(w, "query or query_embedding is required", http.StatusBadRequest) + return + } + + // Generate query embedding if not provided + if len(req.QueryEmbedding) == 0 && m.embedder != nil && req.Query != "" { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + emb, err := m.embedder.Embed(ctx, req.Query) + if err != nil { + writeJSONError(w, fmt.Sprintf("embedding error: %v", err), http.StatusInternalServerError) + return + } + req.QueryEmbedding = emb + } + + result, err := m.store.Recall(r.Context(), req) + if err != nil { + writeJSONError(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) +} + +func (m *MemoryAPI) handleForget(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete && r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req memory.ForgetRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, "invalid request body", http.StatusBadRequest) + return + } + + result, err := m.store.Forget(r.Context(), req) + if err != nil { + writeJSONError(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) +} + +func (m *MemoryAPI) handleStats(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + stats, err := m.store.Stats(r.Context()) + if err != nil { + writeJSONError(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(stats) +} + +func writeJSONError(w http.ResponseWriter, msg string, code int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} diff --git a/cmd/mcp.go b/cmd/mcp.go index 6395536..3d220b3 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -9,6 +9,7 @@ import ( "github.com/Siddhant-K-code/distill/pkg/contextlab" "github.com/Siddhant-K-code/distill/pkg/embedding/openai" + "github.com/Siddhant-K-code/distill/pkg/memory" "github.com/Siddhant-K-code/distill/pkg/retriever" pcretriever "github.com/Siddhant-K-code/distill/pkg/retriever/pinecone" qdretriever "github.com/Siddhant-K-code/distill/pkg/retriever/qdrant" @@ -100,6 +101,7 @@ type MCPServer struct { broker *contextlab.Broker embedder retriever.EmbeddingProvider cfg contextlab.BrokerConfig + memStore *memory.SQLiteStore } func runMCP(cmd *cobra.Command, args []string) error { @@ -141,9 +143,19 @@ func runMCP(cmd *cobra.Command, args []string) error { IncludeMetadata: true, } + // Create memory store + memCfg := memory.DefaultConfig() + memCfg.DedupThreshold = threshold + memStore, err := memory.NewSQLiteStore("distill-memory.db", memCfg) + if err != nil { + return fmt.Errorf("failed to create memory store: %w", err) + } + defer memStore.Close() + // Create MCP server wrapper mcpSrv := &MCPServer{ - cfg: brokerCfg, + cfg: brokerCfg, + memStore: memStore, } // Create embedding provider if OpenAI key is provided @@ -347,6 +359,69 @@ whether to deduplicate. Returns cluster information and redundancy percentage.`) ) s.AddTool(analyzeTool, m.handleAnalyzeRedundancy) + + // Memory tools + if m.memStore != nil { + storeMemoryTool := mcp.NewTool("store_memory", + mcp.WithDescription(`Store context into persistent memory with automatic deduplication. + +Use this to save important context that should persist across sessions. +Duplicate information is automatically detected and merged.`), + mcp.WithString("text", + mcp.Required(), + mcp.Description("Text content to store"), + ), + mcp.WithString("source", + mcp.Description("Source of the memory (e.g., code_review, docs, conversation)"), + ), + mcp.WithArray("tags", + mcp.Description("Tags for categorizing the memory"), + ), + mcp.WithString("session_id", + mcp.Description("Session ID to associate with this memory"), + ), + ) + s.AddTool(storeMemoryTool, m.handleStoreMemory) + + recallMemoryTool := mcp.NewTool("recall_memory", + mcp.WithDescription(`Recall relevant memories from persistent storage. + +Retrieves memories ranked by semantic relevance and recency. +Use this to access context from previous sessions.`), + mcp.WithString("query", + mcp.Required(), + mcp.Description("Query text to search memories"), + ), + mcp.WithArray("tags", + mcp.Description("Filter by tags"), + ), + mcp.WithNumber("max_results", + mcp.Description("Maximum number of memories to return (default: 10)"), + ), + mcp.WithNumber("max_tokens", + mcp.Description("Maximum token budget for returned memories (0 = unlimited)"), + ), + ) + s.AddTool(recallMemoryTool, m.handleRecallMemory) + + forgetMemoryTool := mcp.NewTool("forget_memory", + mcp.WithDescription(`Remove memories matching the given criteria. + +Use this to clean up outdated or incorrect memories.`), + mcp.WithArray("ids", + mcp.Description("Memory IDs to remove"), + ), + mcp.WithArray("tags", + mcp.Description("Remove all memories with these tags"), + ), + ) + s.AddTool(forgetMemoryTool, m.handleForgetMemory) + + memoryStatsTool := mcp.NewTool("memory_stats", + mcp.WithDescription("Get statistics about the persistent memory store."), + ) + s.AddTool(memoryStatsTool, m.handleMemoryStats) + } } // System prompt that guides AI assistants to use deduplication @@ -742,3 +817,4 @@ func formatChunksForResponse(chunks []types.Chunk) []map[string]interface{} { } return result } + diff --git a/cmd/mcp_memory.go b/cmd/mcp_memory.go new file mode 100644 index 0000000..a1f3d94 --- /dev/null +++ b/cmd/mcp_memory.go @@ -0,0 +1,155 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Siddhant-K-code/distill/pkg/memory" + "github.com/mark3labs/mcp-go/mcp" +) + +func (m *MCPServer) handleStoreMemory(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + text, _ := args["text"].(string) + if text == "" { + return mcp.NewToolResultError("text is required"), nil + } + + source, _ := args["source"].(string) + sessionID, _ := args["session_id"].(string) + + var tags []string + if tagsRaw, ok := args["tags"].([]interface{}); ok { + for _, t := range tagsRaw { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + } + + entry := memory.StoreEntry{ + Text: text, + Source: source, + Tags: tags, + } + + if m.embedder != nil { + emb, err := m.embedder.Embed(ctx, text) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("embedding error: %v", err)), nil + } + entry.Embedding = emb + } + + result, err := m.memStore.Store(ctx, memory.StoreRequest{ + SessionID: sessionID, + Entries: []memory.StoreEntry{entry}, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("store error: %v", err)), nil + } + + out, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(out)), nil +} + +func (m *MCPServer) handleRecallMemory(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + query, _ := args["query"].(string) + if query == "" { + return mcp.NewToolResultError("query is required"), nil + } + + var tags []string + if tagsRaw, ok := args["tags"].([]interface{}); ok { + for _, t := range tagsRaw { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + } + + maxResults := 10 + if v, ok := args["max_results"].(float64); ok && v > 0 { + maxResults = int(v) + } + + maxTokens := 0 + if v, ok := args["max_tokens"].(float64); ok && v > 0 { + maxTokens = int(v) + } + + req := memory.RecallRequest{ + Query: query, + Tags: tags, + MaxResults: maxResults, + MaxTokens: maxTokens, + RecencyWeight: 0.3, + } + + if m.embedder != nil { + emb, err := m.embedder.Embed(ctx, query) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("embedding error: %v", err)), nil + } + req.QueryEmbedding = emb + } + + result, err := m.memStore.Recall(ctx, req) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("recall error: %v", err)), nil + } + + out, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(out)), nil +} + +func (m *MCPServer) handleForgetMemory(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := request.GetArguments() + + var ids []string + if idsRaw, ok := args["ids"].([]interface{}); ok { + for _, id := range idsRaw { + if s, ok := id.(string); ok { + ids = append(ids, s) + } + } + } + + var tags []string + if tagsRaw, ok := args["tags"].([]interface{}); ok { + for _, t := range tagsRaw { + if s, ok := t.(string); ok { + tags = append(tags, s) + } + } + } + + if len(ids) == 0 && len(tags) == 0 { + return mcp.NewToolResultError("at least one of ids or tags is required"), nil + } + + result, err := m.memStore.Forget(ctx, memory.ForgetRequest{ + IDs: ids, + Tags: tags, + }) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("forget error: %v", err)), nil + } + + out, _ := json.MarshalIndent(result, "", " ") + return mcp.NewToolResultText(string(out)), nil +} + +func (m *MCPServer) handleMemoryStats(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + stats, err := m.memStore.Stats(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("stats error: %v", err)), nil + } + + out, _ := json.MarshalIndent(stats, "", " ") + return mcp.NewToolResultText(string(out)), nil +} diff --git a/cmd/memory.go b/cmd/memory.go new file mode 100644 index 0000000..77904db --- /dev/null +++ b/cmd/memory.go @@ -0,0 +1,262 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + + "github.com/Siddhant-K-code/distill/pkg/embedding/openai" + "github.com/Siddhant-K-code/distill/pkg/memory" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var memoryCmd = &cobra.Command{ + Use: "memory", + Short: "Manage the persistent context memory store", + Long: `Store, recall, and manage persistent context memories. + +Memories are deduplicated on write, compressed over time, and +ranked by relevance + recency on recall. + +Examples: + distill memory store --text "Auth uses JWT with RS256" --tags auth + distill memory recall --query "How does auth work?" --max-results 5 + distill memory forget --tags deprecated + distill memory stats`, +} + +var memoryStoreCmd = &cobra.Command{ + Use: "store", + Short: "Store a memory entry", + RunE: runMemoryStore, +} + +var memoryRecallCmd = &cobra.Command{ + Use: "recall", + Short: "Recall memories matching a query", + RunE: runMemoryRecall, +} + +var memoryForgetCmd = &cobra.Command{ + Use: "forget", + Short: "Remove memories matching criteria", + RunE: runMemoryForget, +} + +var memoryStatsCmd = &cobra.Command{ + Use: "stats", + Short: "Show memory store statistics", + RunE: runMemoryStats, +} + +func init() { + rootCmd.AddCommand(memoryCmd) + memoryCmd.AddCommand(memoryStoreCmd) + memoryCmd.AddCommand(memoryRecallCmd) + memoryCmd.AddCommand(memoryForgetCmd) + memoryCmd.AddCommand(memoryStatsCmd) + + // Shared flags + memoryCmd.PersistentFlags().String("db", "distill-memory.db", "SQLite database path") + memoryCmd.PersistentFlags().Float64("dedup-threshold", 0.15, "Cosine distance threshold for dedup") + + // Store flags + memoryStoreCmd.Flags().String("text", "", "Text to store") + memoryStoreCmd.Flags().String("source", "", "Source of the memory (e.g., code_review, docs)") + memoryStoreCmd.Flags().StringSlice("tags", nil, "Tags for the memory") + memoryStoreCmd.Flags().String("session-id", "", "Session ID") + memoryStoreCmd.Flags().String("openai-key", "", "OpenAI API key for embeddings (or OPENAI_API_KEY)") + + // Recall flags + memoryRecallCmd.Flags().String("query", "", "Query text") + memoryRecallCmd.Flags().StringSlice("tags", nil, "Filter by tags") + memoryRecallCmd.Flags().Int("max-results", 10, "Maximum results to return") + memoryRecallCmd.Flags().Int("max-tokens", 0, "Maximum token budget (0 = unlimited)") + memoryRecallCmd.Flags().Float64("recency-weight", 0.3, "Weight for recency vs relevance (0-1)") + memoryRecallCmd.Flags().String("openai-key", "", "OpenAI API key for embeddings (or OPENAI_API_KEY)") + + // Forget flags + memoryForgetCmd.Flags().StringSlice("tags", nil, "Remove memories with these tags") + memoryForgetCmd.Flags().StringSlice("ids", nil, "Remove memories with these IDs") +} + +func openMemoryStore(cmd *cobra.Command) (*memory.SQLiteStore, error) { + dbPath, _ := cmd.Flags().GetString("db") + threshold, _ := cmd.Flags().GetFloat64("dedup-threshold") + + cfg := memory.DefaultConfig() + cfg.DedupThreshold = threshold + + return memory.NewSQLiteStore(dbPath, cfg) +} + +func runMemoryStore(cmd *cobra.Command, args []string) error { + text, _ := cmd.Flags().GetString("text") + if text == "" { + return fmt.Errorf("--text is required") + } + + source, _ := cmd.Flags().GetString("source") + tags, _ := cmd.Flags().GetStringSlice("tags") + sessionID, _ := cmd.Flags().GetString("session-id") + openaiKey, _ := cmd.Flags().GetString("openai-key") + if openaiKey == "" { + openaiKey = os.Getenv("OPENAI_API_KEY") + } + + store, err := openMemoryStore(cmd) + if err != nil { + return err + } + defer store.Close() + + entry := memory.StoreEntry{ + Text: text, + Source: source, + Tags: tags, + } + + // Generate embedding if OpenAI key is available + if openaiKey != "" { + model := viper.GetString("embedding.model") + if model == "" { + model = "text-embedding-3-small" + } + embedder, err := openai.NewClient(openai.Config{APIKey: openaiKey, Model: model}) + if err != nil { + return fmt.Errorf("create embedder: %w", err) + } + emb, err := embedder.Embed(context.Background(), text) + if err != nil { + return fmt.Errorf("embed text: %w", err) + } + entry.Embedding = emb + } + + result, err := store.Store(context.Background(), memory.StoreRequest{ + SessionID: sessionID, + Entries: []memory.StoreEntry{entry}, + }) + if err != nil { + return err + } + + out, _ := json.MarshalIndent(result, "", " ") + fmt.Println(string(out)) + return nil +} + +func runMemoryRecall(cmd *cobra.Command, args []string) error { + query, _ := cmd.Flags().GetString("query") + if query == "" { + return fmt.Errorf("--query is required") + } + + tags, _ := cmd.Flags().GetStringSlice("tags") + maxResults, _ := cmd.Flags().GetInt("max-results") + maxTokens, _ := cmd.Flags().GetInt("max-tokens") + recencyWeight, _ := cmd.Flags().GetFloat64("recency-weight") + openaiKey, _ := cmd.Flags().GetString("openai-key") + if openaiKey == "" { + openaiKey = os.Getenv("OPENAI_API_KEY") + } + + store, err := openMemoryStore(cmd) + if err != nil { + return err + } + defer store.Close() + + req := memory.RecallRequest{ + Query: query, + Tags: tags, + MaxResults: maxResults, + MaxTokens: maxTokens, + RecencyWeight: recencyWeight, + } + + // Generate query embedding if OpenAI key is available + if openaiKey != "" { + model := viper.GetString("embedding.model") + if model == "" { + model = "text-embedding-3-small" + } + embedder, err := openai.NewClient(openai.Config{APIKey: openaiKey, Model: model}) + if err != nil { + return fmt.Errorf("create embedder: %w", err) + } + emb, err := embedder.Embed(context.Background(), query) + if err != nil { + return fmt.Errorf("embed query: %w", err) + } + req.QueryEmbedding = emb + } + + result, err := store.Recall(context.Background(), req) + if err != nil { + return err + } + + out, _ := json.MarshalIndent(result, "", " ") + fmt.Println(string(out)) + return nil +} + +func runMemoryForget(cmd *cobra.Command, args []string) error { + tags, _ := cmd.Flags().GetStringSlice("tags") + ids, _ := cmd.Flags().GetStringSlice("ids") + + if len(tags) == 0 && len(ids) == 0 { + return fmt.Errorf("at least one of --tags or --ids is required") + } + + store, err := openMemoryStore(cmd) + if err != nil { + return err + } + defer store.Close() + + result, err := store.Forget(context.Background(), memory.ForgetRequest{ + Tags: tags, + IDs: ids, + }) + if err != nil { + return err + } + + out, _ := json.MarshalIndent(result, "", " ") + fmt.Println(string(out)) + return nil +} + +func runMemoryStats(cmd *cobra.Command, args []string) error { + store, err := openMemoryStore(cmd) + if err != nil { + return err + } + defer store.Close() + + stats, err := store.Stats(context.Background()) + if err != nil { + return err + } + + out, _ := json.MarshalIndent(stats, "", " ") + fmt.Println(string(out)) + return nil +} + +// memoryStoreFromConfig creates a memory store from the API server config. +// Used by the API server and MCP server. +func memoryStoreFromConfig(dbPath string, threshold float64) (*memory.SQLiteStore, error) { + if dbPath == "" { + dbPath = "distill-memory.db" + } + cfg := memory.DefaultConfig() + cfg.DedupThreshold = threshold + return memory.NewSQLiteStore(dbPath, cfg) +} + + diff --git a/go.mod b/go.mod index 7331e95..a2fe129 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -40,14 +41,17 @@ require ( github.com/invopop/jsonschema v0.13.0 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -66,7 +70,7 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.49.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect @@ -75,4 +79,8 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.46.1 // indirect ) diff --git a/go.sum b/go.sum index 1e81c20..35bac8e 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -57,6 +59,7 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= @@ -64,6 +67,8 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -83,6 +88,8 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/qdrant/go-client v1.15.2 h1:3NSyxpHrfQTP6JLDAwqNUShz6V9tuRBKz0G7hSOxrac= github.com/qdrant/go-client v1.15.2/go.mod h1:iO8ts78jL4x6LDHFOViyYWELVtIBDTjOykBmiOTHLnQ= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -154,6 +161,8 @@ go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -183,3 +192,11 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= diff --git a/pkg/memory/decay.go b/pkg/memory/decay.go new file mode 100644 index 0000000..4102030 --- /dev/null +++ b/pkg/memory/decay.go @@ -0,0 +1,226 @@ +package memory + +import ( + "context" + "fmt" + "strings" + "time" +) + +// DecayWorker runs periodic compression of aging memories. +// Memories progress through decay levels based on age and access patterns: +// +// full text -> summary -> keywords -> evicted +type DecayWorker struct { + store *SQLiteStore + cfg Config + stopCh chan struct{} +} + +// NewDecayWorker creates a decay worker for the given store. +func NewDecayWorker(store *SQLiteStore, cfg Config) *DecayWorker { + return &DecayWorker{ + store: store, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +// Start begins the periodic decay loop. Call Stop() to terminate. +func (w *DecayWorker) Start() { + go w.run() +} + +// Stop terminates the decay worker. +func (w *DecayWorker) Stop() { + close(w.stopCh) +} + +func (w *DecayWorker) run() { + ticker := time.NewTicker(w.cfg.DecayInterval) + defer ticker.Stop() + + for { + select { + case <-w.stopCh: + return + case <-ticker.C: + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + _ = w.runOnce(ctx) + cancel() + } + } +} + +// RunOnce executes a single decay pass. Exported for testing. +func (w *DecayWorker) RunOnce(ctx context.Context) error { + return w.runOnce(ctx) +} + +func (w *DecayWorker) runOnce(ctx context.Context) error { + w.store.mu.Lock() + defer w.store.mu.Unlock() + + now := time.Now().UTC() + + // Evict: remove very old, unreferenced memories + if w.cfg.EvictAge > 0 { + evictBefore := now.Add(-w.cfg.EvictAge).Format(time.RFC3339Nano) + _, _ = w.store.db.ExecContext(ctx, + "DELETE FROM memories WHERE last_referenced < ? AND decay_level >= ?", + evictBefore, int(DecayKeywords), + ) + } + + // Decay to keywords: compress old summaries + if w.cfg.KeywordsAge > 0 { + keywordsBefore := now.Add(-w.cfg.KeywordsAge).Format(time.RFC3339Nano) + if err := w.decayRows(ctx, keywordsBefore, DecaySummary, DecayKeywords, extractKeywords); err != nil { + return err + } + } + + // Decay to summary: compress old full-text memories + if w.cfg.SummaryAge > 0 { + summaryBefore := now.Add(-w.cfg.SummaryAge).Format(time.RFC3339Nano) + if err := w.decayRows(ctx, summaryBefore, DecayFull, DecaySummary, extractSummary); err != nil { + return err + } + } + + return nil +} + +// decayRows queries for memories at fromLevel older than cutoff, +// applies the transform function, and updates them to toLevel. +func (w *DecayWorker) decayRows(ctx context.Context, cutoff string, fromLevel, toLevel DecayLevel, transform func(string) string) error { + rows, err := w.store.db.QueryContext(ctx, + "SELECT id, text FROM memories WHERE last_referenced < ? AND decay_level = ?", + cutoff, int(fromLevel), + ) + if err != nil { + return fmt.Errorf("query for decay level %d: %w", fromLevel, err) + } + + type entry struct { + id, text string + } + var entries []entry + for rows.Next() { + var e entry + if err := rows.Scan(&e.id, &e.text); err != nil { + continue + } + entries = append(entries, e) + } + rows.Close() + + for _, e := range entries { + compressed := transform(e.text) + _, _ = w.store.db.ExecContext(ctx, + "UPDATE memories SET text = ?, decay_level = ? WHERE id = ?", + compressed, int(toLevel), e.id, + ) + } + + return nil +} + +// extractSummary produces a shortened version of the text. +// Keeps the first and last sentences, targeting ~20% of original length. +func extractSummary(text string) string { + sentences := splitSentences(text) + if len(sentences) <= 2 { + return text + } + + // Target ~20% of sentences, minimum 2 + target := len(sentences) / 5 + if target < 2 { + target = 2 + } + + // Keep first half from the beginning, second half from the end + firstHalf := target / 2 + if firstHalf < 1 { + firstHalf = 1 + } + secondHalf := target - firstHalf + + var parts []string + parts = append(parts, sentences[:firstHalf]...) + if secondHalf > 0 && len(sentences)-secondHalf > firstHalf { + parts = append(parts, sentences[len(sentences)-secondHalf:]...) + } + + return strings.Join(parts, " ") +} + +// extractKeywords produces a keyword-only representation. +// Keeps words that appear significant (longer words, capitalized, numbers). +func extractKeywords(text string) string { + words := strings.Fields(text) + seen := make(map[string]bool) + var keywords []string + + for _, w := range words { + lower := strings.ToLower(strings.Trim(w, ".,;:!?\"'()[]{}")) + if lower == "" || len(lower) < 4 { + continue + } + if isStopWord(lower) { + continue + } + if seen[lower] { + continue + } + seen[lower] = true + keywords = append(keywords, lower) + } + + // Limit to ~20 keywords + if len(keywords) > 20 { + keywords = keywords[:20] + } + + return strings.Join(keywords, ", ") +} + +// splitSentences splits text on sentence boundaries. +func splitSentences(text string) []string { + var sentences []string + var current strings.Builder + + for _, r := range text { + current.WriteRune(r) + if r == '.' || r == '!' || r == '?' { + s := strings.TrimSpace(current.String()) + if s != "" { + sentences = append(sentences, s) + } + current.Reset() + } + } + + // Remaining text + if s := strings.TrimSpace(current.String()); s != "" { + sentences = append(sentences, s) + } + + return sentences +} + +// isStopWord returns true for common English stop words. +func isStopWord(w string) bool { + stops := map[string]bool{ + "that": true, "this": true, "with": true, "from": true, + "have": true, "been": true, "were": true, "they": true, + "their": true, "which": true, "would": true, "there": true, + "about": true, "could": true, "other": true, "into": true, + "more": true, "some": true, "than": true, "them": true, + "very": true, "when": true, "what": true, "your": true, + "also": true, "each": true, "does": true, "will": true, + "just": true, "should": true, "because": true, "these": true, + } + return stops[w] +} diff --git a/pkg/memory/helpers.go b/pkg/memory/helpers.go new file mode 100644 index 0000000..6a1d5a1 --- /dev/null +++ b/pkg/memory/helpers.go @@ -0,0 +1,53 @@ +package memory + +import ( + "crypto/rand" + "encoding/binary" + "encoding/hex" + "math" + "time" +) + +// generateID creates a random 16-char hex ID with a time prefix for ordering. +func generateID() string { + b := make([]byte, 12) + // First 4 bytes: unix timestamp for natural ordering + ts := uint32(time.Now().Unix()) + b[0] = byte(ts >> 24) + b[1] = byte(ts >> 16) + b[2] = byte(ts >> 8) + b[3] = byte(ts) + // Remaining 8 bytes: random + _, _ = rand.Read(b[4:]) + return hex.EncodeToString(b) +} + +// encodeEmbedding converts a float32 slice to a byte slice for SQLite BLOB storage. +func encodeEmbedding(emb []float32) []byte { + if len(emb) == 0 { + return nil + } + buf := make([]byte, len(emb)*4) + for i, v := range emb { + binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(v)) + } + return buf +} + +// decodeEmbedding converts a byte slice back to a float32 slice. +func decodeEmbedding(buf []byte) []float32 { + if len(buf) == 0 || len(buf)%4 != 0 { + return nil + } + emb := make([]float32, len(buf)/4) + for i := range emb { + emb[i] = math.Float32frombits(binary.LittleEndian.Uint32(buf[i*4:])) + } + return emb +} + +// estimateTokens returns a rough token count for a text string. +// Uses the same heuristic as pkg/compress: ~4 chars per token. +func estimateTokens(text string) int { + return (len(text) + 3) / 4 +} diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go new file mode 100644 index 0000000..6cb0adb --- /dev/null +++ b/pkg/memory/memory_test.go @@ -0,0 +1,371 @@ +package memory + +import ( + "context" + "math" + "testing" + "time" +) + +// makeEmbedding creates a simple unit vector for testing. +// angle controls the direction in the first two dimensions. +func makeEmbedding(angle float64, dim int) []float32 { + emb := make([]float32, dim) + emb[0] = float32(math.Cos(angle)) + emb[1] = float32(math.Sin(angle)) + return emb +} + +func newTestStore(t *testing.T) *SQLiteStore { + t.Helper() + cfg := DefaultConfig() + cfg.DedupThreshold = 0.15 + s, err := NewSQLiteStore(":memory:", cfg) + if err != nil { + t.Fatalf("NewSQLiteStore: %v", err) + } + t.Cleanup(func() { s.Close() }) + return s +} + +func TestStoreAndRecall(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Store two distinct entries + result, err := s.Store(ctx, StoreRequest{ + SessionID: "test-session", + Entries: []StoreEntry{ + {Text: "The auth service uses JWT with RS256", Embedding: makeEmbedding(0, 8), Source: "code_review", Tags: []string{"auth"}}, + {Text: "The payment service uses Stripe API", Embedding: makeEmbedding(math.Pi/2, 8), Source: "docs", Tags: []string{"payments"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + if result.Stored != 2 { + t.Errorf("expected 2 stored, got %d", result.Stored) + } + if result.TotalMemories != 2 { + t.Errorf("expected 2 total, got %d", result.TotalMemories) + } + + // Recall with embedding similar to auth entry + recall, err := s.Recall(ctx, RecallRequest{ + Query: "How does authentication work?", + QueryEmbedding: makeEmbedding(0.05, 8), // Very close to auth entry + MaxResults: 5, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if len(recall.Memories) == 0 { + t.Fatal("expected at least 1 memory") + } + // Auth entry should be most relevant (closest embedding) + if recall.Memories[0].Source != "code_review" { + t.Errorf("expected auth entry first, got source=%s", recall.Memories[0].Source) + } +} + +func TestWriteTimeDedup(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + emb := makeEmbedding(0, 8) + + // Store first entry + r1, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "JWT uses RS256 for signing", Embedding: emb, Source: "docs"}, + }, + }) + if err != nil { + t.Fatalf("Store 1: %v", err) + } + if r1.Stored != 1 { + t.Errorf("expected 1 stored, got %d", r1.Stored) + } + + // Store near-duplicate (same embedding) + r2, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Auth tokens are signed with RS256", Embedding: emb, Source: "code"}, + }, + }) + if err != nil { + t.Fatalf("Store 2: %v", err) + } + if r2.Deduplicated != 1 { + t.Errorf("expected 1 deduplicated, got %d", r2.Deduplicated) + } + if r2.Stored != 0 { + t.Errorf("expected 0 stored, got %d", r2.Stored) + } + if r2.TotalMemories != 1 { + t.Errorf("expected 1 total, got %d", r2.TotalMemories) + } +} + +func TestForget(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Old deprecated info", Tags: []string{"deprecated"}}, + {Text: "Current auth info", Tags: []string{"auth"}}, + {Text: "Another deprecated item", Tags: []string{"deprecated"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + // Forget by tag + result, err := s.Forget(ctx, ForgetRequest{Tags: []string{"deprecated"}}) + if err != nil { + t.Fatalf("Forget: %v", err) + } + if result.Removed != 2 { + t.Errorf("expected 2 removed, got %d", result.Removed) + } + if result.TotalMemories != 1 { + t.Errorf("expected 1 remaining, got %d", result.TotalMemories) + } +} + +func TestForgetByAge(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Insert an entry with a manually backdated created_at + now := time.Now().UTC() + old := now.Add(-48 * time.Hour).Format(time.RFC3339Nano) + _, err := s.db.ExecContext(ctx, + `INSERT INTO memories (id, text, source, tags, metadata, decay_level, created_at, last_referenced, access_count) + VALUES (?, ?, '', '[]', '{}', 0, ?, ?, 0)`, + "old-1", "Old memory", old, old, + ) + if err != nil { + t.Fatalf("insert old: %v", err) + } + + _, err = s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Recent memory"}}, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + // Forget entries older than 24h + result, err := s.Forget(ctx, ForgetRequest{ + OlderThan: now.Add(-24 * time.Hour), + }) + if err != nil { + t.Fatalf("Forget: %v", err) + } + if result.Removed != 1 { + t.Errorf("expected 1 removed, got %d", result.Removed) + } +} + +func TestStats(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Entry from code review", Source: "code_review"}, + {Text: "Entry from docs", Source: "docs"}, + {Text: "Another code review entry", Source: "code_review"}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + stats, err := s.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if stats.TotalMemories != 3 { + t.Errorf("expected 3 total, got %d", stats.TotalMemories) + } + if stats.BySource["code_review"] != 2 { + t.Errorf("expected 2 code_review, got %d", stats.BySource["code_review"]) + } + if stats.BySource["docs"] != 1 { + t.Errorf("expected 1 docs, got %d", stats.BySource["docs"]) + } +} + +func TestRecallWithTokenBudget(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Store entries with embeddings at different angles + _, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Short entry about auth", Embedding: makeEmbedding(0, 8)}, + {Text: "This is a much longer entry about authentication that contains many more tokens and details about how the JWT system works with RS256 signing", Embedding: makeEmbedding(0.1, 8)}, + {Text: "Another auth entry", Embedding: makeEmbedding(0.2, 8)}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + // Recall with a tight token budget + recall, err := s.Recall(ctx, RecallRequest{ + Query: "auth", + QueryEmbedding: makeEmbedding(0, 8), + MaxTokens: 20, // Very tight budget + MaxResults: 10, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if recall.Stats.TokenCount > 20 { + t.Errorf("expected token count <= 20, got %d", recall.Stats.TokenCount) + } +} + +func TestRecallWithTagFilter(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Auth uses JWT", Embedding: makeEmbedding(0, 8), Tags: []string{"auth"}}, + {Text: "Payments use Stripe", Embedding: makeEmbedding(math.Pi/2, 8), Tags: []string{"payments"}}, + {Text: "Auth also uses OAuth", Embedding: makeEmbedding(math.Pi, 8), Tags: []string{"auth"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + recall, err := s.Recall(ctx, RecallRequest{ + Query: "how does it work", + QueryEmbedding: makeEmbedding(0, 8), + Tags: []string{"auth"}, + MaxResults: 10, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if len(recall.Memories) != 2 { + t.Errorf("expected 2 auth memories, got %d", len(recall.Memories)) + } + for _, m := range recall.Memories { + found := false + for _, tag := range m.Tags { + if tag == "auth" { + found = true + break + } + } + if !found { + t.Errorf("expected auth tag, got tags=%v", m.Tags) + } + } +} + +func TestDecayWorker(t *testing.T) { + cfg := DefaultConfig() + cfg.SummaryAge = 1 * time.Millisecond // Near-immediate decay for testing + cfg.KeywordsAge = 1 * time.Millisecond // Near-immediate decay for testing + cfg.EvictAge = 0 // Disable eviction for this test + cfg.DedupThreshold = 0.15 + + s, err := NewSQLiteStore(":memory:", cfg) + if err != nil { + t.Fatalf("NewSQLiteStore: %v", err) + } + defer s.Close() + + ctx := context.Background() + + // Store a multi-sentence entry + _, err = s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours. Refresh tokens are stored in Redis with a 7-day TTL. The service also supports OAuth2 for third-party integrations."}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + // Backdate the entry so it qualifies for decay + past := time.Now().Add(-48 * time.Hour).UTC().Format(time.RFC3339Nano) + _, _ = s.db.ExecContext(ctx, "UPDATE memories SET last_referenced = ?", past) + + // Run decay + w := NewDecayWorker(s, cfg) + if err := w.RunOnce(ctx); err != nil { + t.Fatalf("RunOnce: %v", err) + } + + // Check that the entry was compressed to summary + stats, _ := s.Stats(ctx) + if stats.ByDecayLevel[int(DecaySummary)] != 1 { + t.Errorf("expected 1 summary-level memory, got decay levels: %v", stats.ByDecayLevel) + } + + // Run decay again - should compress to keywords + if err := w.RunOnce(ctx); err != nil { + t.Fatalf("RunOnce 2: %v", err) + } + + stats, _ = s.Stats(ctx) + if stats.ByDecayLevel[int(DecayKeywords)] != 1 { + t.Errorf("expected 1 keywords-level memory, got decay levels: %v", stats.ByDecayLevel) + } +} + +func TestEmbeddingRoundtrip(t *testing.T) { + original := []float32{0.1, 0.2, 0.3, -0.5, 1.0} + encoded := encodeEmbedding(original) + decoded := decodeEmbedding(encoded) + + if len(decoded) != len(original) { + t.Fatalf("length mismatch: %d vs %d", len(decoded), len(original)) + } + for i := range original { + if decoded[i] != original[i] { + t.Errorf("index %d: expected %f, got %f", i, original[i], decoded[i]) + } + } +} + +func TestEmptyStore(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + stats, err := s.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if stats.TotalMemories != 0 { + t.Errorf("expected 0 total, got %d", stats.TotalMemories) + } +} + +func TestStoreEmptyText(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + result, err := s.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: ""}, + {Text: "Valid entry"}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + if result.Stored != 1 { + t.Errorf("expected 1 stored (empty skipped), got %d", result.Stored) + } +} diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go new file mode 100644 index 0000000..2358a52 --- /dev/null +++ b/pkg/memory/sqlite.go @@ -0,0 +1,454 @@ +package memory + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + distillmath "github.com/Siddhant-K-code/distill/pkg/math" + _ "modernc.org/sqlite" +) + +// SQLiteStore implements Store using SQLite for local persistent storage. +type SQLiteStore struct { + db *sql.DB + cfg Config + mu sync.RWMutex +} + +// NewSQLiteStore creates a new SQLite-backed memory store. +// Use ":memory:" for in-memory storage or a file path for persistence. +func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) { + if dsn == "" { + dsn = ":memory:" + } + + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, fmt.Errorf("open sqlite: %w", err) + } + + // Enable WAL mode for better concurrent read performance + if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + db.Close() + return nil, fmt.Errorf("set WAL mode: %w", err) + } + + s := &SQLiteStore{db: db, cfg: cfg} + if err := s.migrate(); err != nil { + db.Close() + return nil, fmt.Errorf("migrate: %w", err) + } + + return s, nil +} + +func (s *SQLiteStore) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + embedding BLOB, + source TEXT DEFAULT '', + tags TEXT DEFAULT '[]', + session_id TEXT DEFAULT '', + metadata TEXT DEFAULT '{}', + decay_level INTEGER DEFAULT 0, + created_at TEXT NOT NULL, + last_referenced TEXT NOT NULL, + access_count INTEGER DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags); + CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); + CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); + CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); + ` + _, err := s.db.Exec(schema) + return err +} + +// Store adds entries with write-time deduplication. +func (s *SQLiteStore) Store(ctx context.Context, req StoreRequest) (*StoreResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + + result := &StoreResult{} + + for _, entry := range req.Entries { + if entry.Text == "" { + continue + } + + // Check for semantic duplicates if embedding is provided + if len(entry.Embedding) > 0 { + dupID, err := s.findDuplicate(ctx, entry.Embedding) + if err != nil { + return nil, fmt.Errorf("find duplicate: %w", err) + } + if dupID != "" { + // Update the existing memory's last_referenced and access_count + _, err := s.db.ExecContext(ctx, + `UPDATE memories SET last_referenced = ?, access_count = access_count + 1 WHERE id = ?`, + time.Now().UTC().Format(time.RFC3339Nano), dupID, + ) + if err != nil { + return nil, fmt.Errorf("update duplicate: %w", err) + } + result.Deduplicated++ + continue + } + } + + // Insert new memory + id := generateID() + now := time.Now().UTC().Format(time.RFC3339Nano) + + tagsJSON, _ := json.Marshal(entry.Tags) + metaJSON, _ := json.Marshal(entry.Metadata) + embBlob := encodeEmbedding(entry.Embedding) + + sessionID := req.SessionID + + _, err := s.db.ExecContext(ctx, + `INSERT INTO memories (id, text, embedding, source, tags, session_id, metadata, decay_level, created_at, last_referenced, access_count) + VALUES (?, ?, ?, ?, ?, ?, ?, 0, ?, ?, 0)`, + id, entry.Text, embBlob, entry.Source, string(tagsJSON), sessionID, string(metaJSON), now, now, + ) + if err != nil { + return nil, fmt.Errorf("insert memory: %w", err) + } + result.Stored++ + } + + // Get total count + var total int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM memories").Scan(&total); err != nil { + return nil, err + } + result.TotalMemories = total + + return result, nil +} + +// findDuplicate scans existing embeddings and returns the ID of the first +// entry within the dedup threshold. Returns "" if no duplicate found. +func (s *SQLiteStore) findDuplicate(ctx context.Context, embedding []float32) (string, error) { + rows, err := s.db.QueryContext(ctx, "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL") + if err != nil { + return "", err + } + defer rows.Close() + + for rows.Next() { + var id string + var embBlob []byte + if err := rows.Scan(&id, &embBlob); err != nil { + return "", err + } + + existing := decodeEmbedding(embBlob) + if len(existing) == 0 { + continue + } + + dist := distillmath.CosineDistance(embedding, existing) + if dist < s.cfg.DedupThreshold { + return id, nil + } + } + + return "", rows.Err() +} + +// Recall retrieves memories matching a query, ranked by relevance and recency. +func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if req.Query == "" && len(req.QueryEmbedding) == 0 { + return nil, ErrInvalidQuery + } + + maxResults := req.MaxResults + if maxResults <= 0 { + maxResults = 10 + } + + recencyWeight := req.RecencyWeight + if recencyWeight < 0 { + recencyWeight = 0 + } + if recencyWeight > 1 { + recencyWeight = 1 + } + + // Build query with optional tag filter + query := "SELECT id, text, embedding, source, tags, decay_level, last_referenced FROM memories" + var args []interface{} + + if len(req.Tags) > 0 { + clauses := make([]string, len(req.Tags)) + for i, tag := range req.Tags { + clauses[i] = "tags LIKE ?" + args = append(args, "%\""+tag+"\"%") + } + query += " WHERE (" + strings.Join(clauses, " OR ") + ")" + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("query memories: %w", err) + } + defer rows.Close() + + var candidates []scored + now := time.Now() + + for rows.Next() { + var ( + id, text, source, tagsStr, refStr string + embBlob []byte + decayLevel int + ) + if err := rows.Scan(&id, &text, &embBlob, &source, &tagsStr, &decayLevel, &refStr); err != nil { + return nil, err + } + + var tags []string + _ = json.Unmarshal([]byte(tagsStr), &tags) + lastRef, _ := time.Parse(time.RFC3339Nano, refStr) + + // Compute relevance score from embedding similarity + var similarity float64 + if len(req.QueryEmbedding) > 0 { + existing := decodeEmbedding(embBlob) + if len(existing) > 0 { + dist := distillmath.CosineDistance(req.QueryEmbedding, existing) + similarity = 1.0 - dist // Convert distance to similarity + } + } + + // Compute recency score (exponential decay, half-life = 24h) + age := now.Sub(lastRef).Hours() + recency := 1.0 + if age > 0 { + recency = 1.0 / (1.0 + age/24.0) + } + + // Combined score + relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency + + candidates = append(candidates, scored{ + memory: RecalledMemory{ + ID: id, + Text: text, + Source: source, + Tags: tags, + Relevance: relevance, + DecayLevel: DecayLevel(decayLevel), + LastReferenced: lastRef, + }, + relevance: relevance, + }) + } + if err := rows.Err(); err != nil { + return nil, err + } + + // Sort by relevance descending + sortByRelevance(candidates) + + // Apply token budget or max results limit + var results []RecalledMemory + tokenCount := 0 + for _, c := range candidates { + if len(results) >= maxResults { + break + } + tokens := estimateTokens(c.memory.Text) + if req.MaxTokens > 0 && tokenCount+tokens > req.MaxTokens { + break + } + results = append(results, c.memory) + tokenCount += tokens + } + + // Update last_referenced for returned memories + if len(results) > 0 { + ids := make([]string, len(results)) + for i, m := range results { + ids[i] = m.ID + } + s.touchMemories(ctx, ids) + } + + return &RecallResult{ + Memories: results, + Stats: RecallStats{ + Candidates: len(candidates), + Deduplicated: len(candidates) - len(results), + Returned: len(results), + TokenCount: tokenCount, + }, + }, nil +} + +// Forget removes memories matching the given criteria. +func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var conditions []string + var args []interface{} + + if len(req.IDs) > 0 { + placeholders := make([]string, len(req.IDs)) + for i, id := range req.IDs { + placeholders[i] = "?" + args = append(args, id) + } + conditions = append(conditions, "id IN ("+strings.Join(placeholders, ",")+")") + } + + if len(req.Tags) > 0 { + tagClauses := make([]string, len(req.Tags)) + for i, tag := range req.Tags { + tagClauses[i] = "tags LIKE ?" + args = append(args, "%\""+tag+"\"%") + } + conditions = append(conditions, "("+strings.Join(tagClauses, " OR ")+")") + } + + if !req.OlderThan.IsZero() { + conditions = append(conditions, "created_at < ?") + args = append(args, req.OlderThan.UTC().Format(time.RFC3339Nano)) + } + + if len(conditions) == 0 { + return &ForgetResult{}, nil + } + + query := "DELETE FROM memories WHERE " + strings.Join(conditions, " AND ") + res, err := s.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("delete memories: %w", err) + } + + removed, _ := res.RowsAffected() + + var total int + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM memories").Scan(&total); err != nil { + return nil, err + } + + return &ForgetResult{ + Removed: int(removed), + TotalMemories: total, + }, nil +} + +// Stats returns memory store statistics. +func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + stats := &Stats{ + ByDecayLevel: make(map[int]int), + BySource: make(map[string]int), + } + + // Total count + if err := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM memories").Scan(&stats.TotalMemories); err != nil { + return nil, err + } + + // By decay level + rows, err := s.db.QueryContext(ctx, "SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level") + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var level, count int + if err := rows.Scan(&level, &count); err != nil { + return nil, err + } + stats.ByDecayLevel[level] = count + } + + // By source + rows2, err := s.db.QueryContext(ctx, "SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source") + if err != nil { + return nil, err + } + defer rows2.Close() + for rows2.Next() { + var source string + var count int + if err := rows2.Scan(&source, &count); err != nil { + return nil, err + } + stats.BySource[source] = count + } + + // Oldest and newest + var oldest, newest sql.NullString + _ = s.db.QueryRowContext(ctx, "SELECT MIN(created_at) FROM memories").Scan(&oldest) + _ = s.db.QueryRowContext(ctx, "SELECT MAX(created_at) FROM memories").Scan(&newest) + if oldest.Valid { + stats.OldestMemory, _ = time.Parse(time.RFC3339Nano, oldest.String) + } + if newest.Valid { + stats.NewestMemory, _ = time.Parse(time.RFC3339Nano, newest.String) + } + + return stats, nil +} + +// Close closes the database connection. +func (s *SQLiteStore) Close() error { + return s.db.Close() +} + +// touchMemories updates last_referenced and access_count for the given IDs. +// Called from Recall under a read lock, so we use a separate goroutine. +func (s *SQLiteStore) touchMemories(ctx context.Context, ids []string) { + go func() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now().UTC().Format(time.RFC3339Nano) + placeholders := make([]string, len(ids)) + args := []interface{}{now} + for i, id := range ids { + placeholders[i] = "?" + args = append(args, id) + } + query := "UPDATE memories SET last_referenced = ?, access_count = access_count + 1 WHERE id IN (" + strings.Join(placeholders, ",") + ")" + _, _ = s.db.ExecContext(ctx, query, args...) + }() +} + +// scored pairs a recalled memory with its computed relevance. +type scored struct { + memory RecalledMemory + relevance float64 +} + +// sortByRelevance sorts scored candidates by relevance descending. +func sortByRelevance(candidates []scored) { + // Simple insertion sort - typically small N + for i := 1; i < len(candidates); i++ { + key := candidates[i] + j := i - 1 + for j >= 0 && candidates[j].relevance < key.relevance { + candidates[j+1] = candidates[j] + j-- + } + candidates[j+1] = key + } +} diff --git a/pkg/memory/store.go b/pkg/memory/store.go new file mode 100644 index 0000000..2ab9c8a --- /dev/null +++ b/pkg/memory/store.go @@ -0,0 +1,181 @@ +// Package memory provides a persistent context memory store with +// write-time deduplication, tag-based recall, and hierarchical decay. +package memory + +import ( + "context" + "errors" + "time" +) + +// Common errors returned by memory stores. +var ( + ErrNotFound = errors.New("memory not found") + ErrEmptyText = errors.New("entry text is empty") + ErrStoreClosed = errors.New("memory store is closed") + ErrInvalidQuery = errors.New("query text is empty") +) + +// DecayLevel represents how compressed a memory is. +// Memories decay over time: full text -> summary -> keywords -> evicted. +type DecayLevel int + +const ( + DecayFull DecayLevel = 0 // Original text, no compression + DecaySummary DecayLevel = 1 // Paragraph-level summary (~20% of original) + DecayKeywords DecayLevel = 2 // Keywords only (~5% of original) +) + +// Entry is a single memory stored in the system. +type Entry struct { + ID string `json:"id"` + Text string `json:"text"` + Embedding []float32 `json:"embedding,omitempty"` + Source string `json:"source,omitempty"` + Tags []string `json:"tags,omitempty"` + SessionID string `json:"session_id,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + DecayLevel DecayLevel `json:"decay_level"` + CreatedAt time.Time `json:"created_at"` + LastReferenced time.Time `json:"last_referenced"` + AccessCount int `json:"access_count"` +} + +// StoreRequest is the input for storing memories. +type StoreRequest struct { + SessionID string `json:"session_id,omitempty"` + Entries []StoreEntry `json:"entries"` +} + +// StoreEntry is a single entry in a store request. +type StoreEntry struct { + Text string `json:"text"` + Embedding []float32 `json:"embedding,omitempty"` + Source string `json:"source,omitempty"` + Tags []string `json:"tags,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// StoreResult is the output of a store operation. +type StoreResult struct { + Stored int `json:"stored"` + Merged int `json:"merged"` + Deduplicated int `json:"deduplicated"` + TotalMemories int `json:"total_memories"` +} + +// RecallRequest is the input for recalling memories. +type RecallRequest struct { + Query string `json:"query"` + QueryEmbedding []float32 `json:"query_embedding,omitempty"` + Tags []string `json:"tags,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxResults int `json:"max_results,omitempty"` + RecencyWeight float64 `json:"recency_weight,omitempty"` +} + +// RecallResult is the output of a recall operation. +type RecallResult struct { + Memories []RecalledMemory `json:"memories"` + Stats RecallStats `json:"stats"` +} + +// RecalledMemory is a single memory returned from recall. +type RecalledMemory struct { + ID string `json:"id"` + Text string `json:"text"` + Source string `json:"source,omitempty"` + Tags []string `json:"tags,omitempty"` + Relevance float64 `json:"relevance"` + DecayLevel DecayLevel `json:"decay_level"` + LastReferenced time.Time `json:"last_referenced"` +} + +// RecallStats contains recall operation metrics. +type RecallStats struct { + Candidates int `json:"candidates"` + Deduplicated int `json:"deduplicated"` + Returned int `json:"returned"` + TokenCount int `json:"token_count"` +} + +// ForgetRequest specifies which memories to remove. +type ForgetRequest struct { + IDs []string `json:"ids,omitempty"` + Tags []string `json:"tags,omitempty"` + OlderThan time.Time `json:"older_than,omitempty"` +} + +// ForgetResult is the output of a forget operation. +type ForgetResult struct { + Removed int `json:"removed"` + TotalMemories int `json:"total_memories"` +} + +// Stats contains memory store statistics. +type Stats struct { + TotalMemories int `json:"total_memories"` + ByDecayLevel map[int]int `json:"by_decay_level"` + BySource map[string]int `json:"by_source"` + OldestMemory time.Time `json:"oldest_memory,omitempty"` + NewestMemory time.Time `json:"newest_memory,omitempty"` +} + +// Store is the interface for persistent memory backends. +type Store interface { + // Store adds entries to memory with write-time deduplication. + Store(ctx context.Context, req StoreRequest) (*StoreResult, error) + + // Recall retrieves memories matching a query, ranked by relevance and recency. + Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) + + // Forget removes memories matching the given criteria. + Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) + + // Stats returns memory store statistics. + Stats(ctx context.Context) (*Stats, error) + + // Close releases resources held by the store. + Close() error +} + +// Config holds memory store configuration. +type Config struct { + // DedupThreshold is the cosine distance below which entries are + // considered duplicates. Default: 0.15. + DedupThreshold float64 + + // DecayEnabled enables the background decay worker. + DecayEnabled bool + + // DecayInterval is how often the decay worker runs. Default: 1h. + DecayInterval time.Duration + + // SummaryAge is the age after which memories are compressed to summaries. + // Default: 24h. + SummaryAge time.Duration + + // KeywordsAge is the age after which memories are compressed to keywords. + // Default: 168h (7 days). + KeywordsAge time.Duration + + // EvictAge is the age after which unreferenced memories are evicted. + // Default: 720h (30 days). + EvictAge time.Duration + + // MaxMemories is the maximum number of memories to store. 0 = unlimited. + MaxMemories int +} + +// DefaultConfig returns sensible defaults. +func DefaultConfig() Config { + return Config{ + DedupThreshold: 0.15, + DecayEnabled: true, + DecayInterval: 1 * time.Hour, + SummaryAge: 24 * time.Hour, + KeywordsAge: 168 * time.Hour, + EvictAge: 720 * time.Hour, + MaxMemories: 0, + } +} From e74ee4e95e864b5faf0f8bb5b1789814730f86df Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 18:41:48 +0000 Subject: [PATCH 2/6] fix: handle all error returns to satisfy errcheck linter Co-authored-by: Ona --- cmd/api.go | 2 +- cmd/api_memory.go | 10 +++++----- cmd/mcp.go | 2 +- cmd/memory.go | 8 ++++---- pkg/memory/decay.go | 2 +- pkg/memory/sqlite.go | 12 ++++++------ 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cmd/api.go b/cmd/api.go index 7516ae4..cd84beb 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -186,7 +186,7 @@ func runAPI(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("failed to create memory store: %w", err) } - defer memStore.Close() + defer func() { _ = memStore.Close() }() memAPI := &MemoryAPI{store: memStore, embedder: embedder} diff --git a/cmd/api_memory.go b/cmd/api_memory.go index 7caa3b8..12022c3 100644 --- a/cmd/api_memory.go +++ b/cmd/api_memory.go @@ -73,7 +73,7 @@ func (m *MemoryAPI) handleStore(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(result) + _ = json.NewEncoder(w).Encode(result) } func (m *MemoryAPI) handleRecall(w http.ResponseWriter, r *http.Request) { @@ -112,7 +112,7 @@ func (m *MemoryAPI) handleRecall(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(result) + _ = json.NewEncoder(w).Encode(result) } func (m *MemoryAPI) handleForget(w http.ResponseWriter, r *http.Request) { @@ -134,7 +134,7 @@ func (m *MemoryAPI) handleForget(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(result) + _ = json.NewEncoder(w).Encode(result) } func (m *MemoryAPI) handleStats(w http.ResponseWriter, r *http.Request) { @@ -150,11 +150,11 @@ func (m *MemoryAPI) handleStats(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(stats) + _ = json.NewEncoder(w).Encode(stats) } func writeJSONError(w http.ResponseWriter, msg string, code int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) - json.NewEncoder(w).Encode(map[string]string{"error": msg}) + _ = json.NewEncoder(w).Encode(map[string]string{"error": msg}) } diff --git a/cmd/mcp.go b/cmd/mcp.go index 3d220b3..9fc5512 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -150,7 +150,7 @@ func runMCP(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("failed to create memory store: %w", err) } - defer memStore.Close() + defer func() { _ = memStore.Close() }() // Create MCP server wrapper mcpSrv := &MCPServer{ diff --git a/cmd/memory.go b/cmd/memory.go index 77904db..9a91923 100644 --- a/cmd/memory.go +++ b/cmd/memory.go @@ -110,7 +110,7 @@ func runMemoryStore(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer store.Close() + defer func() { _ = store.Close() }() entry := memory.StoreEntry{ Text: text, @@ -167,7 +167,7 @@ func runMemoryRecall(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer store.Close() + defer func() { _ = store.Close() }() req := memory.RecallRequest{ Query: query, @@ -216,7 +216,7 @@ func runMemoryForget(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer store.Close() + defer func() { _ = store.Close() }() result, err := store.Forget(context.Background(), memory.ForgetRequest{ Tags: tags, @@ -236,7 +236,7 @@ func runMemoryStats(cmd *cobra.Command, args []string) error { if err != nil { return err } - defer store.Close() + defer func() { _ = store.Close() }() stats, err := store.Stats(context.Background()) if err != nil { diff --git a/pkg/memory/decay.go b/pkg/memory/decay.go index 4102030..5773ab8 100644 --- a/pkg/memory/decay.go +++ b/pkg/memory/decay.go @@ -113,7 +113,7 @@ func (w *DecayWorker) decayRows(ctx context.Context, cutoff string, fromLevel, t } entries = append(entries, e) } - rows.Close() + _ = rows.Close() for _, e := range entries { compressed := transform(e.text) diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go index 2358a52..b84d732 100644 --- a/pkg/memory/sqlite.go +++ b/pkg/memory/sqlite.go @@ -34,13 +34,13 @@ func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) { // Enable WAL mode for better concurrent read performance if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { - db.Close() + _ = db.Close() return nil, fmt.Errorf("set WAL mode: %w", err) } s := &SQLiteStore{db: db, cfg: cfg} if err := s.migrate(); err != nil { - db.Close() + _ = db.Close() return nil, fmt.Errorf("migrate: %w", err) } @@ -141,7 +141,7 @@ func (s *SQLiteStore) findDuplicate(ctx context.Context, embedding []float32) (s if err != nil { return "", err } - defer rows.Close() + defer func() { _ = rows.Close() }() for rows.Next() { var id string @@ -203,7 +203,7 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes if err != nil { return nil, fmt.Errorf("query memories: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var candidates []scored now := time.Now() @@ -371,7 +371,7 @@ func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() for rows.Next() { var level, count int if err := rows.Scan(&level, &count); err != nil { @@ -385,7 +385,7 @@ func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { if err != nil { return nil, err } - defer rows2.Close() + defer func() { _ = rows2.Close() }() for rows2.Next() { var source string var count int From da796daf443e971ef150563407009e70b4b29fa2 Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 18:42:52 +0000 Subject: [PATCH 3/6] fix: handle Close() error returns in tests Co-authored-by: Ona --- pkg/memory/memory_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go index 6cb0adb..7c82959 100644 --- a/pkg/memory/memory_test.go +++ b/pkg/memory/memory_test.go @@ -24,7 +24,7 @@ func newTestStore(t *testing.T) *SQLiteStore { if err != nil { t.Fatalf("NewSQLiteStore: %v", err) } - t.Cleanup(func() { s.Close() }) + t.Cleanup(func() { _ = s.Close() }) return s } @@ -283,7 +283,7 @@ func TestDecayWorker(t *testing.T) { if err != nil { t.Fatalf("NewSQLiteStore: %v", err) } - defer s.Close() + defer func() { _ = s.Close() }() ctx := context.Background() From e87a6ab7c9900b2196bdd9493a4355311453d22a Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 18:55:51 +0000 Subject: [PATCH 4/6] refactor: address code review findings for memory store - P1: Make touchMemories synchronous (no goroutine leak risk) - P2: Remove unused MaxMemories from config - P2: Use junction table (memory_tags) for exact tag matching - P3: Make memory store opt-in via --memory flag in api/mcp - P4: Replace local embeddingProvider with retriever.EmbeddingProvider - P4: Reuse pkg/compress extractive scorer in decay extractSummary - Fix scan-then-process pattern to avoid SQLite single-conn deadlocks Co-authored-by: Ona --- cmd/api.go | 41 +++++----- cmd/api_memory.go | 9 +-- cmd/mcp.go | 28 ++++--- pkg/memory/decay.go | 65 ++++----------- pkg/memory/memory_test.go | 4 +- pkg/memory/sqlite.go | 162 +++++++++++++++++++++++--------------- pkg/memory/store.go | 4 - 7 files changed, 157 insertions(+), 156 deletions(-) diff --git a/cmd/api.go b/cmd/api.go index cd84beb..3279951 100644 --- a/cmd/api.go +++ b/cmd/api.go @@ -45,6 +45,7 @@ func init() { apiCmd.Flags().String("openai-key", "", "OpenAI API key for embeddings (or use OPENAI_API_KEY)") apiCmd.Flags().String("embedding-model", "text-embedding-3-small", "OpenAI embedding model") apiCmd.Flags().String("api-keys", "", "Comma-separated list of valid API keys (or use DISTILL_API_KEYS)") + apiCmd.Flags().Bool("memory", false, "Enable persistent memory store") // Bind to viper for config file support _ = viper.BindPFlag("server.port", apiCmd.Flags().Lookup("port")) @@ -173,28 +174,31 @@ func runAPI(cmd *cobra.Command, args []string) error { tracing: tp, } - // Setup memory store - memDBPath := viper.GetString("memory.db_path") - if memDBPath == "" { - memDBPath = "distill-memory.db" - } - memThreshold := viper.GetFloat64("memory.dedup_threshold") - if memThreshold == 0 { - memThreshold = 0.15 - } - memStore, err := memoryStoreFromConfig(memDBPath, memThreshold) - if err != nil { - return fmt.Errorf("failed to create memory store: %w", err) - } - defer func() { _ = memStore.Close() }() - - memAPI := &MemoryAPI{store: memStore, embedder: embedder} - // Setup routes mux := http.NewServeMux() mux.HandleFunc("/v1/dedupe", m.Middleware("/v1/dedupe", server.handleDedupe)) mux.HandleFunc("/v1/dedupe/stream", m.Middleware("/v1/dedupe/stream", server.handleDedupeStream)) - memAPI.RegisterMemoryRoutes(mux, m.Middleware) + + // Setup memory store (opt-in) + enableMemory, _ := cmd.Flags().GetBool("memory") + if enableMemory { + memDBPath := viper.GetString("memory.db_path") + if memDBPath == "" { + memDBPath = "distill-memory.db" + } + memThreshold := viper.GetFloat64("memory.dedup_threshold") + if memThreshold == 0 { + memThreshold = 0.15 + } + memStore, err := memoryStoreFromConfig(memDBPath, memThreshold) + if err != nil { + return fmt.Errorf("failed to create memory store: %w", err) + } + defer func() { _ = memStore.Close() }() + + memAPI := &MemoryAPI{store: memStore, embedder: embedder} + memAPI.RegisterMemoryRoutes(mux, m.Middleware) + } mux.HandleFunc("/health", server.handleHealth) mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { m.Handler().ServeHTTP(w, r) @@ -236,6 +240,7 @@ func runAPI(cmd *cobra.Command, args []string) error { fmt.Printf("Distill API server starting on %s\n", addr) fmt.Printf(" Embeddings: %v\n", embedder != nil) fmt.Printf(" Auth: %v (%d keys)\n", server.hasAuth, len(validKeys)) + fmt.Printf(" Memory: %v\n", enableMemory) fmt.Println() fmt.Println("Endpoints:") fmt.Printf(" POST http://%s/v1/dedupe\n", addr) diff --git a/cmd/api_memory.go b/cmd/api_memory.go index 12022c3..df95f5d 100644 --- a/cmd/api_memory.go +++ b/cmd/api_memory.go @@ -8,18 +8,13 @@ import ( "time" "github.com/Siddhant-K-code/distill/pkg/memory" + "github.com/Siddhant-K-code/distill/pkg/retriever" ) // MemoryAPI handles memory-related HTTP endpoints. type MemoryAPI struct { store *memory.SQLiteStore - embedder embeddingProvider -} - -// embeddingProvider is a minimal interface for generating embeddings. -type embeddingProvider interface { - Embed(ctx context.Context, text string) ([]float32, error) - EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) + embedder retriever.EmbeddingProvider } // RegisterMemoryRoutes adds memory endpoints to the given mux. diff --git a/cmd/mcp.go b/cmd/mcp.go index 9fc5512..1393f5e 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -89,6 +89,9 @@ func init() { mcpCmd.Flags().String("openai-key", "", "OpenAI API key for embeddings (or use OPENAI_API_KEY)") mcpCmd.Flags().String("embedding-model", "text-embedding-3-small", "OpenAI embedding model") + // Memory store + mcpCmd.Flags().Bool("memory", false, "Enable persistent memory store") + // Default deduplication settings mcpCmd.Flags().Int("over-fetch-k", 50, "Default over-fetch count") mcpCmd.Flags().Int("target-k", 8, "Default target chunk count") @@ -143,19 +146,22 @@ func runMCP(cmd *cobra.Command, args []string) error { IncludeMetadata: true, } - // Create memory store - memCfg := memory.DefaultConfig() - memCfg.DedupThreshold = threshold - memStore, err := memory.NewSQLiteStore("distill-memory.db", memCfg) - if err != nil { - return fmt.Errorf("failed to create memory store: %w", err) - } - defer func() { _ = memStore.Close() }() - // Create MCP server wrapper mcpSrv := &MCPServer{ - cfg: brokerCfg, - memStore: memStore, + cfg: brokerCfg, + } + + // Create memory store (opt-in) + enableMemory, _ := cmd.Flags().GetBool("memory") + if enableMemory { + memCfg := memory.DefaultConfig() + memCfg.DedupThreshold = threshold + memStore, err := memory.NewSQLiteStore("distill-memory.db", memCfg) + if err != nil { + return fmt.Errorf("failed to create memory store: %w", err) + } + defer func() { _ = memStore.Close() }() + mcpSrv.memStore = memStore } // Create embedding provider if OpenAI key is provided diff --git a/pkg/memory/decay.go b/pkg/memory/decay.go index 5773ab8..fdee338 100644 --- a/pkg/memory/decay.go +++ b/pkg/memory/decay.go @@ -5,6 +5,9 @@ import ( "fmt" "strings" "time" + + "github.com/Siddhant-K-code/distill/pkg/compress" + "github.com/Siddhant-K-code/distill/pkg/types" ) // DecayWorker runs periodic compression of aging memories. @@ -58,8 +61,6 @@ func (w *DecayWorker) RunOnce(ctx context.Context) error { } func (w *DecayWorker) runOnce(ctx context.Context) error { - w.store.mu.Lock() - defer w.store.mu.Unlock() now := time.Now().UTC() @@ -126,34 +127,20 @@ func (w *DecayWorker) decayRows(ctx context.Context, cutoff string, fromLevel, t return nil } -// extractSummary produces a shortened version of the text. -// Keeps the first and last sentences, targeting ~20% of original length. +// extractSummary produces a shortened version of the text using the +// extractive compressor's sentence scorer for better quality summaries. func extractSummary(text string) string { - sentences := splitSentences(text) - if len(sentences) <= 2 { - return text - } - - // Target ~20% of sentences, minimum 2 - target := len(sentences) / 5 - if target < 2 { - target = 2 + c := compress.NewExtractiveCompressor() + chunks := []types.Chunk{{ID: "decay", Text: text}} + opts := compress.Options{ + TargetReduction: 0.2, // keep ~20% of content + MinChunkLength: 20, } - - // Keep first half from the beginning, second half from the end - firstHalf := target / 2 - if firstHalf < 1 { - firstHalf = 1 + result, _, _ := c.Compress(context.Background(), chunks, opts) + if len(result) > 0 && result[0].Text != "" { + return result[0].Text } - secondHalf := target - firstHalf - - var parts []string - parts = append(parts, sentences[:firstHalf]...) - if secondHalf > 0 && len(sentences)-secondHalf > firstHalf { - parts = append(parts, sentences[len(sentences)-secondHalf:]...) - } - - return strings.Join(parts, " ") + return text } // extractKeywords produces a keyword-only representation. @@ -186,30 +173,6 @@ func extractKeywords(text string) string { return strings.Join(keywords, ", ") } -// splitSentences splits text on sentence boundaries. -func splitSentences(text string) []string { - var sentences []string - var current strings.Builder - - for _, r := range text { - current.WriteRune(r) - if r == '.' || r == '!' || r == '?' { - s := strings.TrimSpace(current.String()) - if s != "" { - sentences = append(sentences, s) - } - current.Reset() - } - } - - // Remaining text - if s := strings.TrimSpace(current.String()); s != "" { - sentences = append(sentences, s) - } - - return sentences -} - // isStopWord returns true for common English stop words. func isStopWord(w string) bool { stops := map[string]bool{ diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go index 7c82959..a5f9448 100644 --- a/pkg/memory/memory_test.go +++ b/pkg/memory/memory_test.go @@ -143,8 +143,8 @@ func TestForgetByAge(t *testing.T) { now := time.Now().UTC() old := now.Add(-48 * time.Hour).Format(time.RFC3339Nano) _, err := s.db.ExecContext(ctx, - `INSERT INTO memories (id, text, source, tags, metadata, decay_level, created_at, last_referenced, access_count) - VALUES (?, ?, '', '[]', '{}', 0, ?, ?, 0)`, + `INSERT INTO memories (id, text, source, metadata, decay_level, created_at, last_referenced, access_count) + VALUES (?, ?, '', '{}', 0, ?, ?, 0)`, "old-1", "Old memory", old, old, ) if err != nil { diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go index b84d732..5c9d11c 100644 --- a/pkg/memory/sqlite.go +++ b/pkg/memory/sqlite.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "strings" - "sync" "time" distillmath "github.com/Siddhant-K-code/distill/pkg/math" @@ -14,10 +13,11 @@ import ( ) // SQLiteStore implements Store using SQLite for local persistent storage. +// Uses a single connection (SetMaxOpenConns(1)) so SQLite's internal +// serialization handles concurrency. No application-level mutex needed. type SQLiteStore struct { db *sql.DB cfg Config - mu sync.RWMutex } // NewSQLiteStore creates a new SQLite-backed memory store. @@ -32,6 +32,10 @@ func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) { return nil, fmt.Errorf("open sqlite: %w", err) } + // SQLite doesn't support concurrent connections well with in-memory DBs + // and PRAGMAs are per-connection, so pin to a single connection. + db.SetMaxOpenConns(1) + // Enable WAL mode for better concurrent read performance if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { _ = db.Close() @@ -54,7 +58,6 @@ func (s *SQLiteStore) migrate() error { text TEXT NOT NULL, embedding BLOB, source TEXT DEFAULT '', - tags TEXT DEFAULT '[]', session_id TEXT DEFAULT '', metadata TEXT DEFAULT '{}', decay_level INTEGER DEFAULT 0, @@ -62,19 +65,27 @@ func (s *SQLiteStore) migrate() error { last_referenced TEXT NOT NULL, access_count INTEGER DEFAULT 0 ); - CREATE INDEX IF NOT EXISTS idx_memories_tags ON memories(tags); + CREATE TABLE IF NOT EXISTS memory_tags ( + memory_id TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (memory_id, tag), + FOREIGN KEY (memory_id) REFERENCES memories(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag); CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); ` + // Enable foreign keys for CASCADE deletes + if _, err := s.db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return fmt.Errorf("enable foreign keys: %w", err) + } _, err := s.db.Exec(schema) return err } // Store adds entries with write-time deduplication. func (s *SQLiteStore) Store(ctx context.Context, req StoreRequest) (*StoreResult, error) { - s.mu.Lock() - defer s.mu.Unlock() result := &StoreResult{} @@ -107,20 +118,31 @@ func (s *SQLiteStore) Store(ctx context.Context, req StoreRequest) (*StoreResult id := generateID() now := time.Now().UTC().Format(time.RFC3339Nano) - tagsJSON, _ := json.Marshal(entry.Tags) metaJSON, _ := json.Marshal(entry.Metadata) embBlob := encodeEmbedding(entry.Embedding) sessionID := req.SessionID _, err := s.db.ExecContext(ctx, - `INSERT INTO memories (id, text, embedding, source, tags, session_id, metadata, decay_level, created_at, last_referenced, access_count) - VALUES (?, ?, ?, ?, ?, ?, ?, 0, ?, ?, 0)`, - id, entry.Text, embBlob, entry.Source, string(tagsJSON), sessionID, string(metaJSON), now, now, + `INSERT INTO memories (id, text, embedding, source, session_id, metadata, decay_level, created_at, last_referenced, access_count) + VALUES (?, ?, ?, ?, ?, ?, 0, ?, ?, 0)`, + id, entry.Text, embBlob, entry.Source, sessionID, string(metaJSON), now, now, ) if err != nil { return nil, fmt.Errorf("insert memory: %w", err) } + + // Insert tags into junction table + for _, tag := range entry.Tags { + _, err := s.db.ExecContext(ctx, + "INSERT OR IGNORE INTO memory_tags (memory_id, tag) VALUES (?, ?)", + id, tag, + ) + if err != nil { + return nil, fmt.Errorf("insert tag: %w", err) + } + } + result.Stored++ } @@ -166,8 +188,6 @@ func (s *SQLiteStore) findDuplicate(ctx context.Context, embedding []float32) (s // Recall retrieves memories matching a query, ranked by relevance and recency. func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) { - s.mu.RLock() - defer s.mu.RUnlock() if req.Query == "" && len(req.QueryEmbedding) == 0 { return nil, ErrInvalidQuery @@ -187,48 +207,59 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes } // Build query with optional tag filter - query := "SELECT id, text, embedding, source, tags, decay_level, last_referenced FROM memories" + query := "SELECT m.id, m.text, m.embedding, m.source, m.decay_level, m.last_referenced FROM memories m" var args []interface{} if len(req.Tags) > 0 { - clauses := make([]string, len(req.Tags)) + placeholders := make([]string, len(req.Tags)) for i, tag := range req.Tags { - clauses[i] = "tags LIKE ?" - args = append(args, "%\""+tag+"\"%") + placeholders[i] = "?" + args = append(args, tag) } - query += " WHERE (" + strings.Join(clauses, " OR ") + ")" + query += " WHERE m.id IN (SELECT memory_id FROM memory_tags WHERE tag IN (" + strings.Join(placeholders, ",") + "))" } rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("query memories: %w", err) } - defer func() { _ = rows.Close() }() - - var candidates []scored - now := time.Now() + // Scan all rows first, then close before issuing more queries. + // SQLite with MaxOpenConns(1) requires the connection to be free. + type rawRow struct { + id, text, source, refStr string + embBlob []byte + decayLevel int + } + var rawRows []rawRow for rows.Next() { - var ( - id, text, source, tagsStr, refStr string - embBlob []byte - decayLevel int - ) - if err := rows.Scan(&id, &text, &embBlob, &source, &tagsStr, &decayLevel, &refStr); err != nil { + var r rawRow + if err := rows.Scan(&r.id, &r.text, &r.embBlob, &r.source, &r.decayLevel, &r.refStr); err != nil { + _ = rows.Close() return nil, err } + rawRows = append(rawRows, r) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() + + var candidates []scored + now := time.Now() - var tags []string - _ = json.Unmarshal([]byte(tagsStr), &tags) - lastRef, _ := time.Parse(time.RFC3339Nano, refStr) + for _, r := range rawRows { + tags, _ := s.loadTags(ctx, r.id) + lastRef, _ := time.Parse(time.RFC3339Nano, r.refStr) // Compute relevance score from embedding similarity var similarity float64 if len(req.QueryEmbedding) > 0 { - existing := decodeEmbedding(embBlob) + existing := decodeEmbedding(r.embBlob) if len(existing) > 0 { dist := distillmath.CosineDistance(req.QueryEmbedding, existing) - similarity = 1.0 - dist // Convert distance to similarity + similarity = 1.0 - dist } } @@ -239,25 +270,21 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes recency = 1.0 / (1.0 + age/24.0) } - // Combined score relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency candidates = append(candidates, scored{ memory: RecalledMemory{ - ID: id, - Text: text, - Source: source, + ID: r.id, + Text: r.text, + Source: r.source, Tags: tags, Relevance: relevance, - DecayLevel: DecayLevel(decayLevel), + DecayLevel: DecayLevel(r.decayLevel), LastReferenced: lastRef, }, relevance: relevance, }) } - if err := rows.Err(); err != nil { - return nil, err - } // Sort by relevance descending sortByRelevance(candidates) @@ -299,8 +326,6 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes // Forget removes memories matching the given criteria. func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { - s.mu.Lock() - defer s.mu.Unlock() var conditions []string var args []interface{} @@ -315,12 +340,12 @@ func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetRes } if len(req.Tags) > 0 { - tagClauses := make([]string, len(req.Tags)) + placeholders := make([]string, len(req.Tags)) for i, tag := range req.Tags { - tagClauses[i] = "tags LIKE ?" - args = append(args, "%\""+tag+"\"%") + placeholders[i] = "?" + args = append(args, tag) } - conditions = append(conditions, "("+strings.Join(tagClauses, " OR ")+")") + conditions = append(conditions, "id IN (SELECT memory_id FROM memory_tags WHERE tag IN ("+strings.Join(placeholders, ",")+"))") } if !req.OlderThan.IsZero() { @@ -353,8 +378,6 @@ func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetRes // Stats returns memory store statistics. func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { - s.mu.RLock() - defer s.mu.RUnlock() stats := &Stats{ ByDecayLevel: make(map[int]int), @@ -414,23 +437,36 @@ func (s *SQLiteStore) Close() error { return s.db.Close() } -// touchMemories updates last_referenced and access_count for the given IDs. -// Called from Recall under a read lock, so we use a separate goroutine. -func (s *SQLiteStore) touchMemories(ctx context.Context, ids []string) { - go func() { - s.mu.Lock() - defer s.mu.Unlock() +// loadTags returns the tags for a given memory ID. +func (s *SQLiteStore) loadTags(ctx context.Context, memoryID string) ([]string, error) { + rows, err := s.db.QueryContext(ctx, "SELECT tag FROM memory_tags WHERE memory_id = ?", memoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() - now := time.Now().UTC().Format(time.RFC3339Nano) - placeholders := make([]string, len(ids)) - args := []interface{}{now} - for i, id := range ids { - placeholders[i] = "?" - args = append(args, id) + var tags []string + for rows.Next() { + var tag string + if err := rows.Scan(&tag); err != nil { + return nil, err } - query := "UPDATE memories SET last_referenced = ?, access_count = access_count + 1 WHERE id IN (" + strings.Join(placeholders, ",") + ")" - _, _ = s.db.ExecContext(ctx, query, args...) - }() + tags = append(tags, tag) + } + return tags, rows.Err() +} + +// touchMemories updates last_referenced and access_count for the given IDs. +func (s *SQLiteStore) touchMemories(ctx context.Context, ids []string) { + now := time.Now().UTC().Format(time.RFC3339Nano) + placeholders := make([]string, len(ids)) + args := []interface{}{now} + for i, id := range ids { + placeholders[i] = "?" + args = append(args, id) + } + query := "UPDATE memories SET last_referenced = ?, access_count = access_count + 1 WHERE id IN (" + strings.Join(placeholders, ",") + ")" + _, _ = s.db.ExecContext(ctx, query, args...) } // scored pairs a recalled memory with its computed relevance. diff --git a/pkg/memory/store.go b/pkg/memory/store.go index 2ab9c8a..e212106 100644 --- a/pkg/memory/store.go +++ b/pkg/memory/store.go @@ -162,9 +162,6 @@ type Config struct { // EvictAge is the age after which unreferenced memories are evicted. // Default: 720h (30 days). EvictAge time.Duration - - // MaxMemories is the maximum number of memories to store. 0 = unlimited. - MaxMemories int } // DefaultConfig returns sensible defaults. @@ -176,6 +173,5 @@ func DefaultConfig() Config { SummaryAge: 24 * time.Hour, KeywordsAge: 168 * time.Hour, EvictAge: 720 * time.Hour, - MaxMemories: 0, } } From 84070d5d0de9047004500f25f4dec866877be737 Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 19:02:18 +0000 Subject: [PATCH 5/6] cleanup: address remaining review nits - Add TODO to findDuplicate noting O(n) full table scan scaling limit - Refactor Stats to scan-then-close each query (consistent with Recall) - Move extractSummary compressor to package-level var (avoid per-call alloc) - Add --memory-db flag to MCP command (was hardcoded) - Move PRAGMA foreign_keys to NewSQLiteStore alongside other PRAGMAs - Move isStopWord map to package-level var (avoid per-call alloc) - Remove trailing blank line in memory.go Co-authored-by: Ona --- cmd/mcp.go | 4 +++- cmd/memory.go | 2 -- pkg/memory/decay.go | 30 +++++++++++++++++------------- pkg/memory/sqlite.go | 42 ++++++++++++++++++++++++++++++------------ 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/cmd/mcp.go b/cmd/mcp.go index 1393f5e..0e43e1f 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -91,6 +91,7 @@ func init() { // Memory store mcpCmd.Flags().Bool("memory", false, "Enable persistent memory store") + mcpCmd.Flags().String("memory-db", "distill-memory.db", "SQLite database path for memory store") // Default deduplication settings mcpCmd.Flags().Int("over-fetch-k", 50, "Default over-fetch count") @@ -154,9 +155,10 @@ func runMCP(cmd *cobra.Command, args []string) error { // Create memory store (opt-in) enableMemory, _ := cmd.Flags().GetBool("memory") if enableMemory { + memDBPath, _ := cmd.Flags().GetString("memory-db") memCfg := memory.DefaultConfig() memCfg.DedupThreshold = threshold - memStore, err := memory.NewSQLiteStore("distill-memory.db", memCfg) + memStore, err := memory.NewSQLiteStore(memDBPath, memCfg) if err != nil { return fmt.Errorf("failed to create memory store: %w", err) } diff --git a/cmd/memory.go b/cmd/memory.go index 9a91923..5120c42 100644 --- a/cmd/memory.go +++ b/cmd/memory.go @@ -258,5 +258,3 @@ func memoryStoreFromConfig(dbPath string, threshold float64) (*memory.SQLiteStor cfg.DedupThreshold = threshold return memory.NewSQLiteStore(dbPath, cfg) } - - diff --git a/pkg/memory/decay.go b/pkg/memory/decay.go index fdee338..889587a 100644 --- a/pkg/memory/decay.go +++ b/pkg/memory/decay.go @@ -127,16 +127,18 @@ func (w *DecayWorker) decayRows(ctx context.Context, cutoff string, fromLevel, t return nil } +// summaryCompressor is reused across decay passes to avoid per-call allocation. +var summaryCompressor = compress.NewExtractiveCompressor() + // extractSummary produces a shortened version of the text using the // extractive compressor's sentence scorer for better quality summaries. func extractSummary(text string) string { - c := compress.NewExtractiveCompressor() chunks := []types.Chunk{{ID: "decay", Text: text}} opts := compress.Options{ TargetReduction: 0.2, // keep ~20% of content MinChunkLength: 20, } - result, _, _ := c.Compress(context.Background(), chunks, opts) + result, _, _ := summaryCompressor.Compress(context.Background(), chunks, opts) if len(result) > 0 && result[0].Text != "" { return result[0].Text } @@ -173,17 +175,19 @@ func extractKeywords(text string) string { return strings.Join(keywords, ", ") } +// stopWords is the set of common English stop words filtered during keyword extraction. +var stopWords = map[string]bool{ + "that": true, "this": true, "with": true, "from": true, + "have": true, "been": true, "were": true, "they": true, + "their": true, "which": true, "would": true, "there": true, + "about": true, "could": true, "other": true, "into": true, + "more": true, "some": true, "than": true, "them": true, + "very": true, "when": true, "what": true, "your": true, + "also": true, "each": true, "does": true, "will": true, + "just": true, "should": true, "because": true, "these": true, +} + // isStopWord returns true for common English stop words. func isStopWord(w string) bool { - stops := map[string]bool{ - "that": true, "this": true, "with": true, "from": true, - "have": true, "been": true, "were": true, "they": true, - "their": true, "which": true, "would": true, "there": true, - "about": true, "could": true, "other": true, "into": true, - "more": true, "some": true, "than": true, "them": true, - "very": true, "when": true, "what": true, "your": true, - "also": true, "each": true, "does": true, "will": true, - "just": true, "should": true, "because": true, "these": true, - } - return stops[w] + return stopWords[w] } diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go index 5c9d11c..d82a8e6 100644 --- a/pkg/memory/sqlite.go +++ b/pkg/memory/sqlite.go @@ -36,12 +36,18 @@ func NewSQLiteStore(dsn string, cfg Config) (*SQLiteStore, error) { // and PRAGMAs are per-connection, so pin to a single connection. db.SetMaxOpenConns(1) - // Enable WAL mode for better concurrent read performance + // WAL mode for better read performance if pool size increases later. if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { _ = db.Close() return nil, fmt.Errorf("set WAL mode: %w", err) } + // Enable foreign keys for CASCADE deletes on memory_tags. + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + _ = db.Close() + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + s := &SQLiteStore{db: db, cfg: cfg} if err := s.migrate(); err != nil { _ = db.Close() @@ -76,10 +82,6 @@ func (s *SQLiteStore) migrate() error { CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); ` - // Enable foreign keys for CASCADE deletes - if _, err := s.db.Exec("PRAGMA foreign_keys = ON"); err != nil { - return fmt.Errorf("enable foreign keys: %w", err) - } _, err := s.db.Exec(schema) return err } @@ -158,6 +160,10 @@ func (s *SQLiteStore) Store(ctx context.Context, req StoreRequest) (*StoreResult // findDuplicate scans existing embeddings and returns the ID of the first // entry within the dedup threshold. Returns "" if no duplicate found. +// +// TODO: This does a full table scan (O(n) per insert). Fine for < 10K entries. +// At larger scale, consider an approximate nearest-neighbor index or caching +// embeddings in memory. func (s *SQLiteStore) findDuplicate(ctx context.Context, embedding []float32) (string, error) { rows, err := s.db.QueryContext(ctx, "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL") if err != nil { @@ -377,6 +383,8 @@ func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetRes } // Stats returns memory store statistics. +// Each query is scanned and closed before the next to avoid holding +// the single SQLite connection across multiple result sets. func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { stats := &Stats{ @@ -389,34 +397,44 @@ func (s *SQLiteStore) Stats(ctx context.Context) (*Stats, error) { return nil, err } - // By decay level + // By decay level - scan and close before next query rows, err := s.db.QueryContext(ctx, "SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level") if err != nil { return nil, err } - defer func() { _ = rows.Close() }() for rows.Next() { var level, count int if err := rows.Scan(&level, &count); err != nil { + _ = rows.Close() return nil, err } stats.ByDecayLevel[level] = count } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() - // By source - rows2, err := s.db.QueryContext(ctx, "SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source") + // By source - scan and close before next query + rows, err = s.db.QueryContext(ctx, "SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source") if err != nil { return nil, err } - defer func() { _ = rows2.Close() }() - for rows2.Next() { + for rows.Next() { var source string var count int - if err := rows2.Scan(&source, &count); err != nil { + if err := rows.Scan(&source, &count); err != nil { + _ = rows.Close() return nil, err } stats.BySource[source] = count } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + _ = rows.Close() // Oldest and newest var oldest, newest sql.NullString From 977a44236bb821fc23b63b3af2eaf5e99f8e5f76 Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sun, 22 Feb 2026 19:05:42 +0000 Subject: [PATCH 6/6] docs: add context memory section to README - Add Context Memory section with CLI, API, and MCP usage examples - Add memory endpoints to API Endpoints table - Add memory command to CLI Commands list - Update architecture diagram: Memory Store is shipped, not planned - Update roadmap: mark Context Memory Store as shipped - Update intro blurb to reflect memory is available Co-authored-by: Ona --- README.md | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index f0246cb..4bb413a 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ **Context intelligence layer for AI agents.** -Deduplicates, compresses, and manages context across sessions - so your agents produce reliable, deterministic outputs. Today: a dedup pipeline with ~12ms overhead. Next: persistent context memory, code change impact graphs, and session-aware context windows. +Deduplicates, compresses, and manages context across sessions - so your agents produce reliable, deterministic outputs. Includes a dedup pipeline with ~12ms overhead and persistent context memory with write-time dedup and hierarchical decay. Less redundant data. Lower costs. Faster responses. Deterministic results. @@ -201,12 +201,82 @@ Add to Claude Desktop (`~/Library/Application Support/Claude/claude_desktop_conf See [mcp/README.md](mcp/README.md) for more configuration options. +## Context Memory + +Persistent memory that accumulates knowledge across agent sessions. Memories are deduplicated on write, ranked by relevance + recency on recall, and compressed over time through hierarchical decay. + +Enable with the `--memory` flag on `api` or `mcp` commands. + +### CLI + +```bash +# Store a memory +distill memory store --text "Auth uses JWT with RS256 signing" --tags auth --source docs + +# Recall relevant memories +distill memory recall --query "How does authentication work?" --max-results 5 + +# Remove outdated memories +distill memory forget --tags deprecated + +# View statistics +distill memory stats +``` + +### API + +```bash +# Start API with memory enabled +distill api --port 8080 --memory + +# Store +curl -X POST http://localhost:8080/v1/memory/store \ + -H "Content-Type: application/json" \ + -d '{ + "session_id": "session-1", + "entries": [{"text": "Auth uses JWT with RS256", "tags": ["auth"], "source": "docs"}] + }' + +# Recall +curl -X POST http://localhost:8080/v1/memory/recall \ + -H "Content-Type: application/json" \ + -d '{"query": "How does auth work?", "max_results": 5}' +``` + +### MCP + +Memory tools are available in Claude Desktop, Cursor, and other MCP clients when `--memory` is enabled: + +```bash +distill mcp --memory +``` + +Tools exposed: `store_memory`, `recall_memory`, `forget_memory`, `memory_stats`. + +### How Decay Works + +Memories compress over time based on access patterns: + +``` +Full text → Summary (~20%) → Keywords (~5%) → Evicted + (24h) (7 days) (30 days) +``` + +Accessing a memory resets its decay clock. Configure ages via `distill.yaml`: + +```yaml +memory: + db_path: distill-memory.db + dedup_threshold: 0.15 +``` + ## CLI Commands ```bash distill api # Start standalone API server distill serve # Start server with vector DB connection distill mcp # Start MCP server for AI assistants +distill memory # Store, recall, and manage persistent context memories distill analyze # Analyze a file for duplicates distill sync # Upload vectors to Pinecone with dedup distill query # Test a query from command line @@ -220,6 +290,10 @@ distill config # Manage configuration files | POST | `/v1/dedupe` | Deduplicate chunks | | POST | `/v1/dedupe/stream` | SSE streaming dedup with per-stage progress | | POST | `/v1/retrieve` | Query vector DB with dedup (requires backend) | +| POST | `/v1/memory/store` | Store memories with write-time dedup (requires `--memory`) | +| POST | `/v1/memory/recall` | Recall memories by relevance + recency (requires `--memory`) | +| POST | `/v1/memory/forget` | Remove memories by ID, tag, or age (requires `--memory`) | +| GET | `/v1/memory/stats` | Memory store statistics (requires `--memory`) | | GET | `/health` | Health check | | GET | `/metrics` | Prometheus metrics | @@ -489,10 +563,10 @@ KV cache for repeated context patterns (system prompts, tool definitions, boiler │ └─────────┘ └─────────┘ └─────────┘ └──────────┘ └─────────┘ │ │ <1ms 6ms <1ms 2ms 3ms │ │ │ -│ Context Intelligence (planned) │ +│ Context Intelligence │ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────────────────┐ │ │ │ Memory Store │ │ Impact Graph │ │ Session Context Windows │ │ -│ │ (#29) │ │ (#30) │ │ (#31) │ │ +│ │ (shipped) │ │ (#30) │ │ (#31) │ │ │ └──────────────┘ └──────────────┘ └──────────────────────────┘ │ │ │ │ ┌──────────────────────────────────────────────────────────────┐ │ @@ -527,10 +601,10 @@ Distill is evolving from a dedup utility into a context intelligence layer. Here ### Context Memory -| Feature | Issue | Description | -|---------|-------|-------------| -| **Context Memory Store** | [#29](https://github.com/Siddhant-K-code/distill/issues/29) | Persistent, deduplicated memory across sessions. Write-time dedup, hierarchical decay (full text -> summary -> keywords -> evicted), token-budgeted recall. | -| **Session Management** | [#31](https://github.com/Siddhant-K-code/distill/issues/31) | Stateful context windows for long-running agents. Push context incrementally, Distill keeps it deduplicated and within budget. | +| Feature | Issue | Status | Description | +|---------|-------|--------|-------------| +| **Context Memory Store** | [#29](https://github.com/Siddhant-K-code/distill/issues/29) | Shipped | Persistent, deduplicated memory across sessions. Write-time dedup, hierarchical decay, token-budgeted recall. See [Context Memory](#context-memory). | +| **Session Management** | [#31](https://github.com/Siddhant-K-code/distill/issues/31) | Planned | Stateful context windows for long-running agents. Push context incrementally, Distill keeps it deduplicated and within budget. | ### Code Intelligence