From 36a979959a81f2467fe8fcb6e2040d1368662646 Mon Sep 17 00:00:00 2001 From: Urmzd Mukhammadnaim Date: Sun, 29 Mar 2026 18:03:16 -0500 Subject: [PATCH 1/2] fix(agent): redesign compaction to use tree branching and fix related bugs Compaction now forks a new branch off root instead of only modifying the in-memory message slice. This preserves the original branch, prevents the SummarizeCompactor from re-firing every iteration, and sets the compacted branch as active for subsequent Invoke calls. Additional fixes: configurable KeepLast, ToolUseContent serialization, SlidingWindow tool-pair integrity, compactNow warning, remove unused tree.Store, replace joinStrings with strings.Join. --- agent/agent.go | 53 +++++++++++++++++++++++-- agent/integration_test.go | 34 +++++++++++----- agent/tree/tree.go | 6 --- agent/types/compactor.go | 60 ++++++++++++++++++++++------- go.mod | 2 +- go.sum | 4 +- rag/contextassembler/compressing.go | 16 ++------ 7 files changed, 127 insertions(+), 48 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 6bf3a87..fbb1ad4 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -327,6 +327,41 @@ func mergeConfig(rc *resolvedConfig, cc types.ConfigContent) { } } +// persistCompacted forks a new branch off the tree root and adds the compacted +// messages (skipping the first, which is the system message already on root). +// Returns the new branch ID. +func (a *Agent) persistCompacted(ctx context.Context, tr *tree.Tree, compacted []types.Message) (types.BranchID, error) { + root := tr.Root() + if len(compacted) < 2 { + return "", fmt.Errorf("compacted history too short to branch") + } + + // First compacted message is the system prompt (same as root) — skip it. + // Branch from root with the second message (the summary). + branchID, _, err := tr.Branch(ctx, root.ID, "compact", compacted[1]) + if err != nil { + return "", fmt.Errorf("branch from root: %w", err) + } + + // Add remaining compacted messages (the preserved recent context). + for _, msg := range compacted[2:] { + tip, err := tr.Tip(branchID) + if err != nil { + return "", fmt.Errorf("tip lookup: %w", err) + } + if _, err := tr.AddChild(ctx, tip.ID, msg); err != nil { + return "", fmt.Errorf("add compacted child: %w", err) + } + } + + // Set the compacted branch as active so future Invoke calls use it. + if err := tr.SetActive(branchID); err != nil { + return "", fmt.Errorf("set active: %w", err) + } + + return branchID, nil +} + // callProvider invokes the LLM, using structured output when available. func (a *Agent) callProvider(ctx context.Context, messages []types.Message, tools []types.ToolDef) (<-chan types.Delta, error) { if a.cfg.ResponseSchema != nil && len(tools) == 0 { @@ -479,13 +514,25 @@ func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types. // Resolve file URIs to data. llmMessages = a.resolveFiles(ctx, llmMessages) - // Compact if configured. + // Compact if configured: summarize, fork a new branch off root, continue. if resolved.compactNow || resolved.compactor != nil { if resolved.compactor != nil { compacted, err := resolved.compactor.Compact(ctx, llmMessages, a.cfg.Provider) - if err == nil { - llmMessages = compacted + if err != nil { + log.Warn("compaction failed, continuing with full history", "error", err) + } else if len(compacted) < len(llmMessages) { + // Compaction produced a shorter history — persist it on a new branch. + newBranch, err := a.persistCompacted(ctx, tr, compacted) + if err != nil { + log.Warn("failed to persist compacted branch", "error", err) + } else { + branch = newBranch + log.Debug("compacted to new branch", "branch", branch) + continue // re-flatten from the new branch + } } + } else if resolved.compactNow { + log.Warn("compactNow requested but no compactor configured, ignoring") } } diff --git a/agent/integration_test.go b/agent/integration_test.go index 5581be3..dd0bcbe 100644 --- a/agent/integration_test.go +++ b/agent/integration_test.go @@ -981,7 +981,7 @@ func TestSlidingWindowCompactorBelowWindow(t *testing.T) { } func TestSummarizeCompactorBelowThreshold(t *testing.T) { - compactor := types.NewSummarizeCompactor(10) + compactor := types.NewSummarizeCompactor(10, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), types.NewUserMessage("one"), @@ -998,7 +998,7 @@ func TestSummarizeCompactorBelowThreshold(t *testing.T) { func TestSummarizeCompactorAboveThreshold(t *testing.T) { provider := &mockProvider{response: "conversation summary"} - compactor := types.NewSummarizeCompactor(3) + compactor := types.NewSummarizeCompactor(3, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), @@ -1039,7 +1039,7 @@ func TestSummarizeCompactorAboveThreshold(t *testing.T) { func TestSummarizeCompactorProviderError(t *testing.T) { provider := &errorProvider{err: errors.New("provider down")} - compactor := types.NewSummarizeCompactor(2) + compactor := types.NewSummarizeCompactor(2, 0) msgs := []types.Message{ types.NewSystemMessage("sys"), @@ -2799,18 +2799,32 @@ func TestAgentWithSummarizeCompactor(t *testing.T) { Tree: tr, }) - // Multiple turns + // Multiple turns — each Invoke uses tr.Active(), which may change after compaction. for i := 0; i < 4; i++ { stream := agent.Invoke(context.Background(), []types.Message{types.NewUserMessage(fmt.Sprintf("turn-%d", i))}) collectDeltas(stream) stream.Wait() } - // Tree should have all messages - msgs, _ := tr.FlattenBranch("main") - // sys + 4*(user + asst) = 9 - if len(msgs) != 9 { - t.Errorf("tree messages = %d, want 9", len(msgs)) + // After compaction, the active branch may have been forked from root. + // The original main branch is preserved; the active branch has compacted history. + activeBranch := tr.Active() + msgs, err := tr.FlattenBranch(activeBranch) + if err != nil { + t.Fatalf("FlattenBranch(%s): %v", activeBranch, err) + } + // Active branch should have messages (at least sys + summary + recent context). + if len(msgs) < 3 { + t.Errorf("active branch messages = %d, want >= 3", len(msgs)) + } + + // Original main branch should still exist and be untouched. + mainMsgs, err := tr.FlattenBranch("main") + if err != nil { + t.Fatalf("FlattenBranch(main): %v", err) + } + if len(mainMsgs) == 0 { + t.Error("main branch should not be empty") } } @@ -2990,7 +3004,7 @@ func TestSlidingWindowPreservesSystem(t *testing.T) { // =================================================================== func TestSummarizeCompactorFewMessages(t *testing.T) { - compactor := types.NewSummarizeCompactor(2) + compactor := types.NewSummarizeCompactor(2, 0) provider := &mockProvider{response: "summary"} msgs := []types.Message{ diff --git a/agent/tree/tree.go b/agent/tree/tree.go index 76aa95c..3d8bedf 100644 --- a/agent/tree/tree.go +++ b/agent/tree/tree.go @@ -17,11 +17,6 @@ func WithWAL(wal types.WAL) Option { return func(t *Tree) { t.wal = wal } } -// WithStore sets the persistence store for the tree. -func WithStore(store types.Store) Option { - return func(t *Tree) { t.store = store } -} - // Tree is a branching conversation graph rooted at a system message. type Tree struct { mu sync.RWMutex @@ -32,7 +27,6 @@ type Tree struct { active types.BranchID // the branch Invoke reads from checkpoints map[types.CheckpointID]types.Checkpoint wal types.WAL - store types.Store } // New creates a new conversation tree rooted at the given system message. diff --git a/agent/types/compactor.go b/agent/types/compactor.go index c1e297d..5ba8d18 100644 --- a/agent/types/compactor.go +++ b/agent/types/compactor.go @@ -26,6 +26,7 @@ type CompactConfig struct { Strategy CompactStrategy WindowSize int // for sliding_window Threshold int // for summarize + KeepLast int // recent messages to preserve during summarize (default 4) } // ToCompactor converts the config into a Compactor implementation. @@ -34,7 +35,7 @@ func (cc CompactConfig) ToCompactor() Compactor { case CompactSlidingWindow: return NewSlidingWindowCompactor(cc.WindowSize) case CompactSummarize: - return NewSummarizeCompactor(cc.Threshold) + return NewSummarizeCompactor(cc.Threshold, cc.KeepLast) default: return NoopCompactor{} } @@ -60,20 +61,50 @@ func (c *SlidingWindowCompactor) Compact(_ context.Context, messages []Message, if len(messages) <= c.WindowSize+1 { return messages, nil } - // Keep first (system) + last N - result := make([]Message, 0, c.WindowSize+1) + // Keep first (system) + last N, but don't split a tool-result from its tool-call. + cut := len(messages) - c.WindowSize + if cut > 0 && cut < len(messages) && hasToolResult(messages[cut]) { + cut-- // include the preceding assistant message with the tool call + } + if cut <= 0 { + return messages, nil + } + result := make([]Message, 0, len(messages)-cut+1) result = append(result, messages[0]) - result = append(result, messages[len(messages)-c.WindowSize:]...) + result = append(result, messages[cut:]...) return result, nil } +// hasToolResult reports whether a message contains a ToolResultContent block. +func hasToolResult(msg Message) bool { + switch v := msg.(type) { + case SystemMessage: + for _, c := range v.Content { + if _, ok := c.(ToolResultContent); ok { + return true + } + } + case UserMessage: + for _, c := range v.Content { + if _, ok := c.(ToolResultContent); ok { + return true + } + } + } + return false +} + // SummarizeCompactor summarizes older messages when history exceeds a threshold. type SummarizeCompactor struct { Threshold int + KeepLast int } -func NewSummarizeCompactor(threshold int) *SummarizeCompactor { - return &SummarizeCompactor{Threshold: threshold} +func NewSummarizeCompactor(threshold, keepLast int) *SummarizeCompactor { + if keepLast <= 0 { + keepLast = 4 + } + return &SummarizeCompactor{Threshold: threshold, KeepLast: keepLast} } func (c *SummarizeCompactor) Compact(ctx context.Context, messages []Message, provider Provider) ([]Message, error) { @@ -81,11 +112,7 @@ func (c *SummarizeCompactor) Compact(ctx context.Context, messages []Message, pr return messages, nil } - // Summarize all but last 4 messages using the provider - keepLast := 4 - if keepLast > len(messages)-1 { - keepLast = len(messages) - 1 - } + keepLast := min(c.KeepLast, len(messages)-1) toSummarize := messages[1 : len(messages)-keepLast] if len(toSummarize) == 0 { @@ -161,9 +188,16 @@ func MessagesToText(msgs []Message) string { } case AssistantMessage: for _, c := range v.Content { - if tc, ok := c.(TextContent); ok { + switch bc := c.(type) { + case TextContent: b.WriteString("Assistant: ") - b.WriteString(tc.Text) + b.WriteString(bc.Text) + b.WriteByte('\n') + case ToolUseContent: + b.WriteString("Tool Call [") + b.WriteString(bc.ID) + b.WriteString("]: ") + b.WriteString(bc.Name) b.WriteByte('\n') } } diff --git a/go.mod b/go.mod index d8af5bd..1f44ab8 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/pkoukk/tiktoken-go v0.1.8 golang.org/x/net v0.52.0 golang.org/x/sync v0.20.0 - google.golang.org/genai v1.50.0 + google.golang.org/genai v1.51.0 ) require ( diff --git a/go.sum b/go.sum index 59a8341..5bac9ab 100644 --- a/go.sum +++ b/go.sum @@ -247,8 +247,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genai v1.50.0 h1:yHKV/vjoeN9PJ3iF0ur4cBZco4N3Kl7j09rMq7XSoWk= -google.golang.org/genai v1.50.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg= +google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/rag/contextassembler/compressing.go b/rag/contextassembler/compressing.go index 6fd4c17..a53f4dc 100644 --- a/rag/contextassembler/compressing.go +++ b/rag/contextassembler/compressing.go @@ -4,11 +4,12 @@ package contextassembler import ( "context" "fmt" + "strings" "golang.org/x/sync/errgroup" - "github.com/urmzd/saige/rag/types" "github.com/urmzd/saige/rag/tokenizer" + "github.com/urmzd/saige/rag/types" ) // CompressingAssembler uses an LLM to extract query-relevant sentences from each hit @@ -79,7 +80,7 @@ func (a *CompressingAssembler) Assemble(ctx context.Context, query string, hits parts = append(parts, fmt.Sprintf("%s %s (Source: %s)", citation, compressed, source)) } - promptText := fmt.Sprintf("Context for query %q:\n\n%s", query, joinStrings(parts, "\n\n")) + promptText := fmt.Sprintf("Context for query %q:\n\n%s", query, strings.Join(parts, "\n\n")) return &types.AssembledContext{ Prompt: promptText, @@ -87,14 +88,3 @@ func (a *CompressingAssembler) Assemble(ctx context.Context, query string, hits TokenCount: tokenCount, }, nil } - -func joinStrings(parts []string, sep string) string { - if len(parts) == 0 { - return "" - } - result := parts[0] - for _, p := range parts[1:] { - result += sep + p - } - return result -} From 2523714cae68481183e648cacd65c44c1c9ee377 Mon Sep 17 00:00:00 2001 From: Urmzd Mukhammadnaim Date: Sun, 29 Mar 2026 18:51:01 -0500 Subject: [PATCH 2/2] fix(agent): reduce runLoop complexity and add compaction tests Extract compaction logic from runLoop into tryCompact helper to bring cyclomatic complexity from 31 to under 30 (fixes CI lint failure). Add 13 tests covering persistCompacted, tryCompact, SlidingWindow tool-pair integrity, configurable KeepLast, MessagesToText with ToolUseContent, and CompactConfig round-trip. --- agent/agent.go | 50 ++++--- agent/integration_test.go | 304 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 335 insertions(+), 19 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index fbb1ad4..fad6a19 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -362,6 +362,34 @@ func (a *Agent) persistCompacted(ctx context.Context, tr *tree.Tree, compacted [ return branchID, nil } +// tryCompact attempts compaction if configured. Returns the new branch ID and +// true if compaction succeeded and the caller should re-flatten from the new branch. +func (a *Agent) tryCompact(ctx context.Context, log *slog.Logger, resolved resolvedConfig, llmMessages []types.Message, tr *tree.Tree) (types.BranchID, bool) { + if resolved.compactor == nil { + if resolved.compactNow { + log.Warn("compactNow requested but no compactor configured, ignoring") + } + return "", false + } + + compacted, err := resolved.compactor.Compact(ctx, llmMessages, a.cfg.Provider) + if err != nil { + log.Warn("compaction failed, continuing with full history", "error", err) + return "", false + } + if len(compacted) >= len(llmMessages) { + return "", false + } + + newBranch, err := a.persistCompacted(ctx, tr, compacted) + if err != nil { + log.Warn("failed to persist compacted branch", "error", err) + return "", false + } + log.Debug("compacted to new branch", "branch", newBranch) + return newBranch, true +} + // callProvider invokes the LLM, using structured output when available. func (a *Agent) callProvider(ctx context.Context, messages []types.Message, tools []types.ToolDef) (<-chan types.Delta, error) { if a.cfg.ResponseSchema != nil && len(tools) == 0 { @@ -515,25 +543,9 @@ func (a *Agent) runLoop(ctx context.Context, stream *EventStream, input []types. llmMessages = a.resolveFiles(ctx, llmMessages) // Compact if configured: summarize, fork a new branch off root, continue. - if resolved.compactNow || resolved.compactor != nil { - if resolved.compactor != nil { - compacted, err := resolved.compactor.Compact(ctx, llmMessages, a.cfg.Provider) - if err != nil { - log.Warn("compaction failed, continuing with full history", "error", err) - } else if len(compacted) < len(llmMessages) { - // Compaction produced a shorter history — persist it on a new branch. - newBranch, err := a.persistCompacted(ctx, tr, compacted) - if err != nil { - log.Warn("failed to persist compacted branch", "error", err) - } else { - branch = newBranch - log.Debug("compacted to new branch", "branch", branch) - continue // re-flatten from the new branch - } - } - } else if resolved.compactNow { - log.Warn("compactNow requested but no compactor configured, ignoring") - } + if newBranch, compacted := a.tryCompact(ctx, log, resolved, llmMessages, tr); compacted { + branch = newBranch + continue // re-flatten from the new branch } // Call LLM + timing. diff --git a/agent/integration_test.go b/agent/integration_test.go index dd0bcbe..53f71e5 100644 --- a/agent/integration_test.go +++ b/agent/integration_test.go @@ -3345,3 +3345,307 @@ func TestFeedbackReplay(t *testing.T) { t.Fatal("expected FeedbackDelta during replay") } } + +// =================================================================== +// persistCompacted tests +// =================================================================== + +func TestPersistCompactedCreatesBranch(t *testing.T) { + tr, _ := tree.New(types.NewSystemMessage("sys")) + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + Tree: tr, + }) + + compacted := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("Previous conversation summary: stuff happened"), + types.NewUserMessage("recent-1"), + types.AssistantMessage{Content: []types.AssistantContent{types.TextContent{Text: "recent-2"}}}, + } + + branchID, err := agent.persistCompacted(context.Background(), tr, compacted) + if err != nil { + t.Fatalf("persistCompacted: %v", err) + } + + // The new branch should be active. + if tr.Active() != branchID { + t.Errorf("active = %s, want %s", tr.Active(), branchID) + } + + // Flatten should have sys + 3 compacted messages. + msgs, err := tr.FlattenBranch(branchID) + if err != nil { + t.Fatalf("FlattenBranch: %v", err) + } + if len(msgs) != 4 { + t.Errorf("messages = %d, want 4", len(msgs)) + } +} + +func TestPersistCompactedTooShortReturnsError(t *testing.T) { + tr, _ := tree.New(types.NewSystemMessage("sys")) + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + Tree: tr, + }) + + // Only one message — too short to branch. + _, err := agent.persistCompacted(context.Background(), tr, []types.Message{ + types.NewSystemMessage("sys"), + }) + if err == nil { + t.Fatal("expected error for short compacted history") + } +} + +// =================================================================== +// tryCompact tests +// =================================================================== + +func TestTryCompactNoCompactorNoOp(t *testing.T) { + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + }) + + rc := resolvedConfig{maxIter: 10} + msgs := []types.Message{types.NewSystemMessage("sys"), types.NewUserMessage("hi")} + + _, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, agent.cfg.Tree) + if compacted { + t.Error("should not compact without compactor") + } +} + +func TestTryCompactBelowThresholdNoOp(t *testing.T) { + agent := NewAgent(AgentConfig{ + Provider: &mockProvider{response: "ok"}, + SystemPrompt: "sys", + }) + + // Compactor with high threshold — won't trigger. + rc := resolvedConfig{ + maxIter: 10, + compactor: types.NewSummarizeCompactor(100, 4), + } + msgs := []types.Message{types.NewSystemMessage("sys"), types.NewUserMessage("hi")} + + _, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, agent.cfg.Tree) + if compacted { + t.Error("should not compact below threshold") + } +} + +func TestTryCompactSuccessReturnsTrueAndNewBranch(t *testing.T) { + provider := &mockProvider{response: "summary of conversation"} + tr, _ := tree.New(types.NewSystemMessage("sys")) + + // Build up enough messages on main branch to trigger summarize. + current := tr.Root() + for i := 0; i < 6; i++ { + var msg types.Message + if i%2 == 0 { + msg = types.NewUserMessage(fmt.Sprintf("user-%d", i)) + } else { + msg = types.AssistantMessage{Content: []types.AssistantContent{types.TextContent{Text: fmt.Sprintf("asst-%d", i)}}} + } + node, _ := tr.AddChild(context.Background(), current.ID, msg) + current = node + } + + agent := NewAgent(AgentConfig{ + Provider: provider, + SystemPrompt: "sys", + Tree: tr, + }) + + msgs, _ := tr.FlattenBranch("main") + rc := resolvedConfig{ + maxIter: 10, + compactor: types.NewSummarizeCompactor(3, 2), // threshold=3, keepLast=2 + } + + newBranch, compacted := agent.tryCompact(context.Background(), agent.cfg.Logger, rc, msgs, tr) + if !compacted { + t.Fatal("expected compaction to succeed") + } + if newBranch == "" { + t.Fatal("expected non-empty branch ID") + } + if tr.Active() != newBranch { + t.Errorf("active = %s, want %s", tr.Active(), newBranch) + } +} + +// =================================================================== +// SlidingWindowCompactor tool-pair integrity +// =================================================================== + +func TestSlidingWindowPreservesToolPair(t *testing.T) { + compactor := types.NewSlidingWindowCompactor(2) + + // The cut point would land on a tool result — compactor should back up one. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old"), + types.AssistantMessage{Content: []types.AssistantContent{ + types.ToolUseContent{ID: "tc-1", Name: "search", Arguments: map[string]any{"q": "x"}}, + }}, + types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}), + types.NewUserMessage("recent"), + } + + result, err := compactor.Compact(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Should keep system + assistant(tool-call) + tool-result + recent = 4 + // (backed up to include the tool-call with its result) + if len(result) != 4 { + t.Errorf("messages = %d, want 4 (tool pair preserved)", len(result)) + } + + // First must be system. + if result[0].Role() != types.RoleSystem { + t.Error("first message should be system") + } +} + +func TestSlidingWindowNoCutWhenToolPairAtStart(t *testing.T) { + compactor := types.NewSlidingWindowCompactor(2) + + // Tool result is right after system — backing up would make cut <= 0. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}), + types.NewUserMessage("recent-1"), + types.NewUserMessage("recent-2"), + } + + result, err := compactor.Compact(context.Background(), msgs, nil) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // cut would be 2, messages[2] is not a tool result, so normal windowing. + if len(result) != 3 { + t.Errorf("messages = %d, want 3", len(result)) + } +} + +// =================================================================== +// SummarizeCompactor configurable KeepLast +// =================================================================== + +func TestSummarizeCompactorCustomKeepLast(t *testing.T) { + provider := &mockProvider{response: "summary"} + compactor := types.NewSummarizeCompactor(3, 2) + + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old-1"), + types.NewUserMessage("old-2"), + types.NewUserMessage("recent-1"), + types.NewUserMessage("recent-2"), + } + + result, err := compactor.Compact(context.Background(), msgs, provider) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // sys + summary + 2 recent = 4 + if len(result) != 4 { + t.Errorf("messages = %d, want 4", len(result)) + } +} + +func TestSummarizeCompactorDefaultKeepLast(t *testing.T) { + compactor := types.NewSummarizeCompactor(3, 0) + if compactor.KeepLast != 4 { + t.Errorf("KeepLast = %d, want 4 (default)", compactor.KeepLast) + } +} + +// =================================================================== +// CompactConfig.ToCompactor with KeepLast +// =================================================================== + +func TestCompactConfigKeepLastPassthrough(t *testing.T) { + cfg := types.CompactConfig{ + Strategy: types.CompactSummarize, + Threshold: 10, + KeepLast: 6, + } + c := cfg.ToCompactor() + sc, ok := c.(*types.SummarizeCompactor) + if !ok { + t.Fatalf("expected *SummarizeCompactor, got %T", c) + } + if sc.KeepLast != 6 { + t.Errorf("KeepLast = %d, want 6", sc.KeepLast) + } +} + +// =================================================================== +// MessagesToText with ToolUseContent +// =================================================================== + +func TestMessagesToTextToolUseContent(t *testing.T) { + msgs := []types.Message{ + types.AssistantMessage{Content: []types.AssistantContent{ + types.ToolUseContent{ID: "tc-42", Name: "web_search", Arguments: map[string]any{"q": "test"}}, + }}, + } + + text := types.MessagesToText(msgs) + if !strings.Contains(text, "Tool Call [tc-42]: web_search") { + t.Errorf("expected tool call text, got: %s", text) + } +} + +func TestMessagesToTextMixedAssistantContent(t *testing.T) { + msgs := []types.Message{ + types.AssistantMessage{Content: []types.AssistantContent{ + types.TextContent{Text: "Let me search for that."}, + types.ToolUseContent{ID: "tc-1", Name: "search", Arguments: map[string]any{}}, + }}, + } + + text := types.MessagesToText(msgs) + if !strings.Contains(text, "Assistant: Let me search for that.") { + t.Error("missing assistant text") + } + if !strings.Contains(text, "Tool Call [tc-1]: search") { + t.Error("missing tool call") + } +} + +// =================================================================== +// hasToolResult helper (tested via SlidingWindowCompactor) +// =================================================================== + +func TestHasToolResultInSystemMessage(t *testing.T) { + toolResult := types.NewToolResultMessage(types.ToolResultContent{ToolCallID: "tc-1", Text: "result"}) + compactor := types.NewSlidingWindowCompactor(2) + + // cut = len(5) - 2 = 3 → messages[3] is the tool result → should back up. + msgs := []types.Message{ + types.NewSystemMessage("sys"), + types.NewUserMessage("old-1"), + types.NewUserMessage("old-2"), + toolResult, + types.NewUserMessage("recent"), + } + + result, _ := compactor.Compact(context.Background(), msgs, nil) + // Backed up: system + old-2 + tool-result + recent = 4 (instead of 3). + if len(result) != 4 { + t.Errorf("messages = %d, want 4 (tool result preserved with preceding message)", len(result)) + } +}