diff --git a/agent/agent.go b/agent/agent.go index 6bf3a87..fad6a19 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -327,6 +327,69 @@ func mergeConfig(rc *resolvedConfig, cc types.ConfigContent) { } } +// persistCompacted forks a new branch off the tree root and adds the compacted +// messages (skipping the first, which is the system message already on root). +// Returns the new branch ID. +func (a *Agent) persistCompacted(ctx context.Context, tr *tree.Tree, compacted []types.Message) (types.BranchID, error) { + root := tr.Root() + if len(compacted) < 2 { + return "", fmt.Errorf("compacted history too short to branch") + } + + // First compacted message is the system prompt (same as root) — skip it. + // Branch from root with the second message (the summary). + branchID, _, err := tr.Branch(ctx, root.ID, "compact", compacted[1]) + if err != nil { + return "", fmt.Errorf("branch from root: %w", err) + } + + // Add remaining compacted messages (the preserved recent context). + for _, msg := range compacted[2:] { + tip, err := tr.Tip(branchID) + if err != nil { + return "", fmt.Errorf("tip lookup: %w", err) + } + if _, err := tr.AddChild(ctx, tip.ID, msg); err != nil { + return "", fmt.Errorf("add compacted child: %w", err) + } + } + + // Set the compacted branch as active so future Invoke calls use it. + if err := tr.SetActive(branchID); err != nil { + return "", fmt.Errorf("set active: %w", err) + } + + return branchID, nil +} + +// tryCompact attempts compaction if configured. Returns the new branch ID and +// true if compaction succeeded and the caller should re-flatten from the new branch. +func (a *Agent) tryCompact(ctx context.Context, log *slog.Logger, resolved resolvedConfig, llmMessages []types.Message, tr *tree.Tree) (types.BranchID, bool) { + if resolved.compactor == nil { + if resolved.compactNow { + log.Warn("compactNow requested but no compactor configured, ignoring") + } + return "", false + } + + compacted, err := resolved.compactor.Compact(ctx, llmMessages, a.cfg.Provider) + if err != nil { + log.Warn("compaction failed, continuing with full history", "error", err) + return "", false + } + if len(compacted) >= len(llmMessages) { + return "", false + } + + newBranch, err := a.persistCompacted(ctx, tr, compacted) + if err != nil { + log.Warn("failed to persist compacted branch", "error", err) + return "", false + } + log.Debug("compacted to new branch", "branch", newBranch) + return newBranch, true +} + // callProvider invokes the LLM, using structured output when available. func (a *Agent) callProvider(ctx context.Context, messages []types.Message, tools []types.ToolDef) (<-chan types.Delta, error) { if a.cfg.ResponseSchema != nil && len(tools) == 0 { @@ -479,14 +542,10 @@ func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types. // Resolve file URIs to data. llmMessages = a.resolveFiles(ctx, llmMessages) - // Compact if configured. - if resolved.compactNow || resolved.compactor != nil { - if resolved.compactor != nil { - compacted, err := resolved.compactor.Compact(ctx, llmMessages, a.cfg.Provider) - if err == nil { - llmMessages = compacted - } - } + // Compact if configured: summarize, fork a new branch off root, continue. + if newBranch, compacted := a.tryCompact(ctx, log, resolved, llmMessages, tr); compacted { + branch = newBranch + continue // re-flatten from the new branch } // Call LLM + timing. diff --git a/agent/integration_test.go b/agent/integration_test.go index 5581be3..53f71e5 100644 --- a/agent/integration_test.go +++ b/agent/integration_test.go @@ -981,7 +981,7 @@ func TestSlidingWindowCompactorBelowWindow(t *testing.T) { } func TestSummarizeCompactorBelowThreshold(t *testing.T) { - compactor := types.NewSummarizeCompactor(10) + compactor := types.NewSummarizeCompactor(10, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), types.NewUserMessage("one"), @@ -998,7 +998,7 @@ func TestSummarizeCompactorBelowThreshold(t *testing.T) { func TestSummarizeCompactorAboveThreshold(t *testing.T) { provider := &mockProvider{response: "conversation summary"} - compactor := types.NewSummarizeCompactor(3) + compactor := types.NewSummarizeCompactor(3, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), @@ -1039,7 +1039,7 @@ func TestSummarizeCompactorAboveThreshold(t *testing.T) { func TestSummarizeCompactorProviderError(t *testing.T) { provider := &errorProvider{err: errors.New("provider down")} - compactor := types.NewSummarizeCompactor(2) + compactor := types.NewSummarizeCompactor(2, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), @@ -2799,18 +2799,32 @@ func TestAgentWithSummarizeCompactor(t *testing.T) { Tree: tr, }) - // Multiple turns + // Multiple turns — each Invoke uses tr.Active(), which may change after compaction. for i := 0; i < 4; i++ { stream := agent.Invoke(context.Background(), []types.Message{types.NewUserMessage(fmt.Sprintf("turn-%d", i))}) collectDeltas(stream) stream.Wait() } - // Tree should have all messages - msgs, _ := tr.FlattenBranch("main") - // sys + 4*(user + asst) = 9 - if len(msgs) != 9 { - t.Errorf("tree messages = %d, want 9", len(msgs)) + // After compaction, the active branch may have been forked from root. + // The original main branch is preserved; the active branch has compacted history. + activeBranch := tr.Active() + msgs, err := tr.FlattenBranch(activeBranch) + if err != nil { + t.Fatalf("FlattenBranch(%s): %v", activeBranch, err) + } + // Active branch should have messages (at least sys + summary + recent context). + if len(msgs) < 3 { + t.Errorf("active branch messages = %d, want >= 3", len(msgs)) + } + + // Original main branch should still exist and be untouched. + mainMsgs, err := tr.FlattenBranch("main") + if err != nil { + t.Fatalf("FlattenBranch(main): %v", err) + } + if len(mainMsgs) == 0 { + t.Error("main branch should not be empty") } } @@ -2990,7 +3004,7 @@ func TestSlidingWindowPreservesSystem(t *testing.T) { // =================================================================== func TestSummarizeCompactorFewMessages(t *testing.T) { - compactor := types.NewSummarizeCompactor(2) + compactor := types.NewSummarizeCompactor(2, 0) provider := &mockProvider{response: "summary"} msgs := []types.Message{ @@ -3331,3 +3345,307 @@ func TestFeedbackReplay(t *testing.T) { t.Fatal("expected FeedbackDelta during replay") } } + +// =================================================================== +// persistCompacted tests +// =================================================================== + +func TestPersistCompactedCreatesBranch(t *testing.T) { + tr, _ := tree.New(types.NewSystemMessage("sys")) + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + Tree: tr, + }) + + compacted := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("Previous conversation summary: stuff happened"), + types.NewUserMessage("recent-1"), + types.AssistantMessage{Content: []types.AssistantContent{types.TextContent{Text: "recent-2"}}}, + } + + branchID, err := agent.persistCompacted(context.Background(), tr, compacted) + if err != nil { + t.Fatalf("persistCompacted: %v", err) + } + + // The new branch should be active. + if tr.Active() != branchID { + t.Errorf("active = %s, want %s", tr.Active(), branchID) + } + + // Flatten should have sys + 3 compacted messages. + msgs, err := tr.FlattenBranch(branchID) + if err != nil { + t.Fatalf("FlattenBranch: %v", err) + } + if len(msgs) != 4 { + t.Errorf("messages = %d, want 4", len(msgs)) + } +} + +func TestPersistCompactedTooShortReturnsError(t *testing.T) { + tr, _ := tree.New(types.NewSystemMessage("sys")) + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + Tree: tr, + }) + + // Only one message — too short to branch. + _, err := agent.persistCompacted(context.Background(), tr, []types.Message{ + types.NewSystemMessage("sys"), + }) + if err == nil { + t.Fatal("expected error for short compacted history") + } +} + +// =================================================================== +// tryCompact tests +// =================================================================== + +func TestTryCompactNoCompactorNoOp(t *testing.T) { + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + }) + + rc := resolvedConfig{maxIter: 10} + msgs := []types.Message{types.NewSystemMessage("sys"), types.NewUserMessage("hi")} + + _, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, agent.cfg.Tree) + if compacted { + t.Error("should not compact without compactor") + } +} + +func TestTryCompactBelowThresholdNoOp(t *testing.T) { + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + }) + + // Compactor with high threshold — won't trigger. + rc := resolvedConfig{ + maxIter: 10, + compactor: types.NewSummarizeCompactor(100, 4), + } + msgs := []types.Message{types.NewSystemMessage("sys"), types.NewUserMessage("hi")} + + _, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, agent.cfg.Tree) + if compacted { + t.Error("should not compact below threshold") + } +} + +func TestTryCompactSuccessReturnsTrueAndNewBranch(t *testing.T) { + provider := &mockProvider{response: "summary of conversation"} + tr, _ := tree.New(types.NewSystemMessage("sys")) + + // Build up enough messages on main branch to trigger summarize. + current := tr.Root() + for i := 0; i < 6; i++ { + var msg types.Message + if i%2 == 0 { + msg = types.NewUserMessage(fmt.Sprintf("user-%d", i)) + } else { + msg = types.AssistantMessage{Content: []types.AssistantContent{types.TextContent{Text: fmt.Sprintf("asst-%d", i)}}} + } + node, _ := tr.AddChild(context.Background(), current.ID, msg) + current = node + } + + agent := NewAgent(AgentConfig{ + Provider: provider, + SystemPrompt: "sys", + Tree: tr, + }) + + msgs, _ := tr.FlattenBranch("main") + rc := resolvedConfig{ + maxIter: 10, + compactor: types.NewSummarizeCompactor(3, 2), // threshold=3, keepLast=2 + } + + newBranch, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, tr) + if !compacted { + t.Fatal("expected compaction to succeed") + } + if newBranch == "" { + t.Fatal("expected non-empty branch ID") + } + if tr.Active() != newBranch { + t.Errorf("active = %s, want %s", tr.Active(), newBranch) + } +} + +// =================================================================== +// SlidingWindowCompactor tool-pair integrity +// =================================================================== + +func TestSlidingWindowPreservesToolPair(t *testing.T) { + compactor := types.NewSlidingWindowCompactor(2) + + // The cut point would land on a tool result — compactor should back up one. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old"), + types.AssistantMessage{Content: []types.AssistantContent{ + types.ToolUseContent{ID: "tc-1", Name: "search", Arguments: map[string]any{"q": "x"}}, + }}, + types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}), + types.NewUserMessage("recent"), + } + + result, err := compactor.Compact(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Should keep system + assistant(tool-call) + tool-result + recent = 4 + // (backed up to include the tool-call with its result) + if len(result) != 4 { + t.Errorf("messages = %d, want 4 (tool pair preserved)", len(result)) + } + + // First must be system. + if result[0].Role() != types.RoleSystem { + t.Error("first message should be system") + } +} + +func TestSlidingWindowNoCutWhenToolPairAtStart(t *testing.T) { + compactor := types.NewSlidingWindowCompactor(2) + + // Tool result is right after system — backing up would make cut <= 0. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}), + types.NewUserMessage("recent-1"), + types.NewUserMessage("recent-2"), + } + + result, err := compactor.Compact(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // cut would be 2, messages[2] is not a tool result, so normal windowing. + if len(result) != 3 { + t.Errorf("messages = %d, want 3", len(result)) + } +} + +// =================================================================== +// SummarizeCompactor configurable KeepLast +// =================================================================== + +func TestSummarizeCompactorCustomKeepLast(t *testing.T) { + provider := &mockProvider{response: "summary"} + compactor := types.NewSummarizeCompactor(3, 2) + + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old-1"), + types.NewUserMessage("old-2"), + types.NewUserMessage("recent-1"), + types.NewUserMessage("recent-2"), + } + + result, err := compactor.Compact(context.Background(), msgs, provider) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // sys + summary + 2 recent = 4 + if len(result) != 4 { + t.Errorf("messages = %d, want 4", len(result)) + } +} + +func TestSummarizeCompactorDefaultKeepLast(t *testing.T) { + compactor := types.NewSummarizeCompactor(3, 0) + if compactor.KeepLast != 4 { + t.Errorf("KeepLast = %d, want 4 (default)", compactor.KeepLast) + } +} + +// =================================================================== +// CompactConfig.ToCompactor with KeepLast +// =================================================================== + +func TestCompactConfigKeepLastPassthrough(t *testing.T) { + cfg := types.CompactConfig{ + Strategy: types.CompactSummarize, + Threshold: 10, + KeepLast: 6, + } + c := cfg.ToCompactor() + sc, ok := c.(*types.SummarizeCompactor) + if !ok { + t.Fatalf("expected *SummarizeCompactor, got %T", c) + } + if sc.KeepLast != 6 { + t.Errorf("KeepLast = %d, want 6", sc.KeepLast) + } +} + +// =================================================================== +// MessagesToText with ToolUseContent +// =================================================================== + +func TestMessagesToTextToolUseContent(t *testing.T) { + msgs := []types.Message{ + types.AssistantMessage{Content: []types.AssistantContent{ + types.ToolUseContent{ID: "tc-42", Name: "web_search", Arguments: map[string]any{"q": "test"}}, + }}, + } + + text := types.MessagesToText(msgs) + if !strings.Contains(text, "Tool Call [tc-42]: web_search") { + t.Errorf("expected tool call text, got: %s", text) + } +} + +func TestMessagesToTextMixedAssistantContent(t *testing.T) { + msgs := []types.Message{ + types.AssistantMessage{Content: []types.AssistantContent{ + types.TextContent{Text: "Let me search for that."}, + types.ToolUseContent{ID: "tc-1", Name: "search", Arguments: map[string]any{}}, + }}, + } + + text := types.MessagesToText(msgs) + if !strings.Contains(text, "Assistant: Let me search for that.") { + t.Error("missing assistant text") + } + if !strings.Contains(text, "Tool Call [tc-1]: search") { + t.Error("missing tool call") + } +} + +// =================================================================== +// hasToolResult helper (tested via SlidingWindowCompactor) +// =================================================================== + +func TestHasToolResultInSystemMessage(t *testing.T) { + toolResult := types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}) + compactor := types.NewSlidingWindowCompactor(2) + + // cut = len(5) - 2 = 3 → messages[3] is the tool result → should back up. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old-1"), + types.NewUserMessage("old-2"), + toolResult, + types.NewUserMessage("recent"), + } + + result, _ := compactor.Compact(context.Background(), msgs, nil) + // Backed up: system + old-2 + tool-result + recent = 4 (instead of 3). + if len(result) != 4 { + t.Errorf("messages = %d, want 4 (tool result preserved with preceding message)", len(result)) + } +} diff --git a/agent/tree/tree.go b/agent/tree/tree.go index 76aa95c..3d8bedf 100644 --- a/agent/tree/tree.go +++ b/agent/tree/tree.go @@ -17,11 +17,6 @@ func WithWAL(wal types.WAL) Option { return func(t *Tree) { t.wal = wal } } -// WithStore sets the persistence store for the tree. -func WithStore(store types.Store) Option { - return func(t *Tree) { t.store = store } -} - // Tree is a branching conversation graph rooted at a system message. type Tree struct { mu sync.RWMutex @@ -32,7 +27,6 @@ type Tree struct { active types.BranchID // the branch Invoke reads from checkpoints map[types.CheckpointID]types.Checkpoint wal types.WAL - store types.Store } // New creates a new conversation tree rooted at the given system message. diff --git a/agent/types/compactor.go b/agent/types/compactor.go index c1e297d..5ba8d18 100644 --- a/agent/types/compactor.go +++ b/agent/types/compactor.go @@ -26,6 +26,7 @@ type CompactConfig struct { Strategy CompactStrategy WindowSize int // for sliding_window Threshold int // for summarize + KeepLast int // recent messages to preserve during summarize (default 4) } // ToCompactor converts the config into a Compactor implementation. @@ -34,7 +35,7 @@ func (cc CompactConfig) ToCompactor() Compactor { case CompactSlidingWindow: return NewSlidingWindowCompactor(cc.WindowSize) case CompactSummarize: - return NewSummarizeCompactor(cc.Threshold) + return NewSummarizeCompactor(cc.Threshold, cc.KeepLast) default: return NoopCompactor{} } @@ -60,20 +61,50 @@ func (c *SlidingWindowCompactor) Compact(_ context.Context, messages []Message, if len(messages) <= c.WindowSize+1 { return messages, nil } - // Keep first (system) + last N - result := make([]Message, 0, c.WindowSize+1) + // Keep first (system) + last N, but don't split a tool-result from its tool-call. + cut := len(messages) - c.WindowSize + if cut > 0 && cut < len(messages) && hasToolResult(messages[cut]) { + cut-- // include the preceding assistant message with the tool call + } + if cut <= 0 { + return messages, nil + } + result := make([]Message, 0, len(messages)-cut+1) result = append(result, messages[0]) - result = append(result, messages[len(messages)-c.WindowSize:]...) + result = append(result, messages[cut:]...) return result, nil } +// hasToolResult reports whether a message contains a ToolResultContent block. +func hasToolResult(msg Message) bool { + switch v := msg.(type) { + case SystemMessage: + for _, c := range v.Content { + if _, ok := c.(ToolResultContent); ok { + return true + } + } + case UserMessage: + for _, c := range v.Content { + if _, ok := c.(ToolResultContent); ok { + return true + } + } + } + return false +} + // SummarizeCompactor summarizes older messages when history exceeds a threshold. type SummarizeCompactor struct { Threshold int + KeepLast int } -func NewSummarizeCompactor(threshold int) *SummarizeCompactor { - return &SummarizeCompactor{Threshold: threshold} +func NewSummarizeCompactor(threshold, keepLast int) *SummarizeCompactor { + if keepLast <= 0 { + keepLast = 4 + } + return &SummarizeCompactor{Threshold: threshold, KeepLast: keepLast} } func (c *SummarizeCompactor) Compact(ctx context.Context, messages []Message, provider Provider) ([]Message, error) { @@ -81,11 +112,7 @@ func (c *SummarizeCompactor) Compact(ctx context.Context, messages []Message, pr return messages, nil } - // Summarize all but last 4 messages using the provider - keepLast := 4 - if keepLast > len(messages)-1 { - keepLast = len(messages) - 1 - } + keepLast := min(c.KeepLast, len(messages)-1) toSummarize := messages[1 : len(messages)-keepLast] if len(toSummarize) == 0 { @@ -161,9 +188,16 @@ func MessagesToText(msgs []Message) string { } case AssistantMessage: for _, c := range v.Content { - if tc, ok := c.(TextContent); ok { + switch bc := c.(type) { + case TextContent: b.WriteString("Assistant: ") - b.WriteString(tc.Text) + b.WriteString(bc.Text) + b.WriteByte('\n') + case ToolUseContent: + b.WriteString("Tool Call [") + b.WriteString(bc.ID) + b.WriteString("]: ") + b.WriteString(bc.Name) b.WriteByte('\n') } } diff --git a/go.mod b/go.mod index d8af5bd..1f44ab8 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/pkoukk/tiktoken-go v0.1.8 golang.org/x/net v0.52.0 golang.org/x/sync v0.20.0 - google.golang.org/genai v1.50.0 + google.golang.org/genai v1.51.0 ) require ( diff --git a/go.sum b/go.sum index 59a8341..5bac9ab 100644 --- a/go.sum +++ b/go.sum @@ -247,8 +247,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genai v1.50.0 h1:yHKV/vjoeN9PJ3iF0ur4cBZco4N3Kl7j09rMq7XSoWk= -google.golang.org/genai v1.50.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg= +google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/rag/contextassembler/compressing.go b/rag/contextassembler/compressing.go index 6fd4c17..a53f4dc 100644 --- a/rag/contextassembler/compressing.go +++ b/rag/contextassembler/compressing.go @@ -4,11 +4,12 @@ package contextassembler import ( "context" "fmt" + "strings" "golang.org/x/sync/errgroup" - "github.com/urmzd/saige/rag/types" "github.com/urmzd/saige/rag/tokenizer" + "github.com/urmzd/saige/rag/types" ) // CompressingAssembler uses an LLM to extract query-relevant sentences from each hit @@ -79,7 +80,7 @@ func (a *CompressingAssembler) Assemble(ctx context.Context, query string, hits parts = append(parts, fmt.Sprintf("%s %s (Source: %s)", citation, compressed, source)) } - promptText := fmt.Sprintf("Context for query %q:\n\n%s", query, joinStrings(parts, "\n\n")) + promptText := fmt.Sprintf("Context for query %q:\n\n%s", query, strings.Join(parts, "\n\n")) return &types.AssembledContext{ Prompt: promptText, @@ -87,14 +88,3 @@ func (a *CompressingAssembler) Assemble(ctx context.Context, query string, hits TokenCount: tokenCount, }, nil } - -func joinStrings(parts []string, sep string) string { - if len(parts) == 0 { - return "" - } - result := parts[0] - for _, p := range parts[1:] { - result += sep + p - } - return result -}