From 83c68c1f0a63efcd2456d0a06782d605fd26f698 Mon Sep 17 00:00:00 2001 From: Wes Date: Sat, 21 Mar 2026 11:19:41 -0700 Subject: [PATCH 1/2] feat(acp): session rotation on context window exhaustion When an agent hits MaxTokens or MaxTurnRequests, the channel session is now invalidated so the next prompt creates a fresh one instead of reusing the bloated session. Also adds proactive turn-based rotation via SPROUT_ACP_MAX_TURNS_PER_SESSION (default 0 = disabled). When set, sessions are rotated after N successful turns before hitting the context limit. Refactored session invalidation into OwnedAgent::invalidate_session() and invalidate_all_sessions() helpers for consistency across all code paths. Changes: - pool.rs: Session rotation logic in run_prompt_task Ok(Ok()) arm, turn counters on OwnedAgent, invalidation helpers, PromptContext field - config.rs: New max_turns_per_session CLI arg and Config field - main.rs: Wire PromptContext and OwnedAgent construction sites --- crates/sprout-acp/src/config.rs | 13 ++++- crates/sprout-acp/src/main.rs | 7 +++ crates/sprout-acp/src/pool.rs | 90 ++++++++++++++++++++++++++++----- 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/crates/sprout-acp/src/config.rs b/crates/sprout-acp/src/config.rs index e088dc1d..f05aef8c 100644 --- a/crates/sprout-acp/src/config.rs +++ b/crates/sprout-acp/src/config.rs @@ -185,6 +185,12 @@ pub struct CliArgs { value_parser = clap::value_parser!(u32).range(0..=100))] pub context_message_limit: u32, + /// Maximum turns per session before proactive rotation. 0 = disabled + /// (rotate only on MaxTokens / MaxTurnRequests). + #[arg(long, env = "SPROUT_ACP_MAX_TURNS_PER_SESSION", default_value_t = 0, + value_parser = clap::value_parser!(u32))] + pub max_turns_per_session: u32, + /// Disable automatic presence (online/offline) status. #[arg(long, env = "SPROUT_ACP_NO_PRESENCE")] pub no_presence: bool, @@ -234,6 +240,8 @@ pub struct Config { pub no_mention_filter: bool, pub config_path: PathBuf, pub context_message_limit: u32, + /// Maximum turns per session before proactive rotation. 0 = disabled. + pub max_turns_per_session: u32, pub presence_enabled: bool, pub typing_enabled: bool, /// Desired LLM model ID. Applied after every `session_new_full()`. @@ -373,6 +381,7 @@ impl Config { no_mention_filter: args.no_mention_filter, config_path: args.config, context_message_limit: args.context_message_limit, + max_turns_per_session: args.max_turns_per_session, presence_enabled: !args.no_presence, typing_enabled: !args.no_typing, model: args.model, @@ -382,7 +391,7 @@ impl Config { /// Human-readable summary (no secrets). pub fn summary(&self) -> String { format!( - "relay={} pubkey={} agent_cmd={} {} mcp_cmd={} timeout={}s agents={} heartbeat={}s subscribe={:?} dedup={:?} ignore_self={} context_limit={} presence={} typing={} model={}", + "relay={} pubkey={} agent_cmd={} {} mcp_cmd={} timeout={}s agents={} heartbeat={}s subscribe={:?} dedup={:?} ignore_self={} context_limit={} max_turns_per_session={} presence={} typing={} model={}", self.relay_url, self.keys.public_key().to_hex(), self.agent_command, @@ -395,6 +404,7 @@ impl Config { self.dedup_mode, self.ignore_self, self.context_message_limit, + self.max_turns_per_session, self.presence_enabled, self.typing_enabled, self.model.as_deref().unwrap_or("(agent default)"), @@ -692,6 +702,7 @@ mod tests { no_mention_filter: false, config_path: PathBuf::from("./sprout-acp.toml"), context_message_limit: 12, + max_turns_per_session: 0, presence_enabled: true, typing_enabled: true, model: None, diff --git a/crates/sprout-acp/src/main.rs b/crates/sprout-acp/src/main.rs index a4b5afac..9ec13782 100644 --- a/crates/sprout-acp/src/main.rs +++ b/crates/sprout-acp/src/main.rs @@ -84,6 +84,8 @@ async fn main() -> Result<()> { acp, sessions: HashMap::new(), heartbeat_session: None, + turn_counts: HashMap::new(), + heartbeat_turn_count: 0, model_capabilities: None, desired_model: config.model.clone(), }); @@ -228,6 +230,7 @@ async fn main() -> Result<()> { rest_client: relay.rest_client(), channel_info: channel_info_map, context_message_limit: config.context_message_limit, + max_turns_per_session: config.max_turns_per_session, }); // ── Step 6: Heartbeat timer ─────────────────────────────────────────────── @@ -901,6 +904,8 @@ async fn recover_panicked_agent( acp, sessions: HashMap::new(), heartbeat_session: None, + turn_counts: HashMap::new(), + heartbeat_turn_count: 0, model_capabilities: None, desired_model: config.model.clone(), }); @@ -1013,6 +1018,8 @@ async fn respawn_agent_into(old_agent: OwnedAgent, config: &Config) -> Result, pub heartbeat_session: Option, + /// Per-channel turn counters for proactive session rotation. + /// Incremented on each successful prompt; reset when the session is rotated. + pub turn_counts: HashMap, + /// Turn counter for the heartbeat session. + pub heartbeat_turn_count: u32, /// Model catalog from first session/new. None until first session created. pub model_capabilities: Option, /// Desired model ID (from `Config.model`). Applied after every `session_new_full()`. pub desired_model: Option, } +impl OwnedAgent { + /// Invalidate the session (and turn counter) for a specific prompt source. + pub fn invalidate_session(&mut self, source: &PromptSource) { + match source { + PromptSource::Channel(cid) => { + self.sessions.remove(cid); + self.turn_counts.remove(cid); + } + PromptSource::Heartbeat => { + self.heartbeat_session = None; + self.heartbeat_turn_count = 0; + } + } + } + + /// Invalidate all sessions and turn counters (e.g. after agent exit). + pub fn invalidate_all_sessions(&mut self) { + self.sessions.clear(); + self.turn_counts.clear(); + self.heartbeat_session = None; + self.heartbeat_turn_count = 0; + } +} + /// Pool of agents with take-and-return ownership semantics. /// /// Agents are either idle (sitting in `agents[i]`) or checked out @@ -97,6 +126,7 @@ pub struct PromptResult { } /// Whether the prompt came from a channel event or a heartbeat. +#[derive(Debug)] pub enum PromptSource { Channel(Uuid), Heartbeat, @@ -129,6 +159,8 @@ pub struct PromptContext { pub channel_info: std::collections::HashMap, /// Max messages to include in thread/DM context. 0 = disabled. pub context_message_limit: u32, + /// Max turns per session before proactive rotation. 0 = disabled. + pub max_turns_per_session: u32, } // ── AgentPool impl ──────────────────────────────────────────────────────────── @@ -248,6 +280,7 @@ impl AgentPool { for slot in &mut self.agents { if let Some(agent) = slot.as_mut() { if agent.sessions.remove(&channel_id).is_some() { + agent.turn_counts.remove(&channel_id); count += 1; } } @@ -415,8 +448,7 @@ pub async fn run_prompt_task( (sid, true) } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.invalidate_all_sessions(); let _ = result_tx.send(PromptResult { agent, source, @@ -452,8 +484,7 @@ pub async fn run_prompt_task( (sid, true) } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.invalidate_all_sessions(); let _ = result_tx.send(PromptResult { agent, source, @@ -499,8 +530,7 @@ pub async fn run_prompt_task( ); } Ok(Err(AcpError::AgentExited)) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.invalidate_all_sessions(); let _ = result_tx.send(PromptResult { agent, source, @@ -514,7 +544,7 @@ pub async fn run_prompt_task( target: "pool::session", "initial_message failed for channel {cid}: {e} — invalidating session" ); - agent.sessions.remove(cid); + agent.invalidate_session(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -530,11 +560,10 @@ pub async fn run_prompt_task( ); match agent.acp.cancel_with_cleanup(&session_id).await { Ok(_) => { - agent.sessions.remove(cid); + agent.invalidate_session(&source); } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.invalidate_all_sessions(); let _ = result_tx.send(PromptResult { agent, source, @@ -548,7 +577,7 @@ pub async fn run_prompt_task( target: "pool::session", "cancel_with_cleanup failed during initial_message timeout: {e}" ); - agent.sessions.remove(cid); + agent.invalidate_session(&source); } } let _ = result_tx.send(PromptResult { @@ -626,6 +655,41 @@ pub async fn run_prompt_task( match prompt_result { Ok(Ok(stop_reason)) => { log_stop_reason(&source, &stop_reason); + + // ── Session rotation on context exhaustion ──────────────── + let should_rotate = matches!( + stop_reason, + StopReason::MaxTokens | StopReason::MaxTurnRequests + ); + + // ── Proactive turn-based rotation ───────────────────────── + let should_rotate = should_rotate || { + let limit = ctx.max_turns_per_session; + if limit > 0 { + match &source { + PromptSource::Channel(cid) => { + let count = agent.turn_counts.entry(*cid).or_insert(0); + *count += 1; + *count >= limit + } + PromptSource::Heartbeat => { + agent.heartbeat_turn_count += 1; + agent.heartbeat_turn_count >= limit + } + } + } else { + false + } + }; + + if should_rotate { + tracing::info!( + target: "pool::session", + "rotating session for {source:?} after {stop_reason:?}", + ); + agent.invalidate_session(&source); + } + let _ = result_tx.send(PromptResult { agent, source, @@ -1005,10 +1069,10 @@ fn log_stop_reason(source: &PromptSource, stop_reason: &StopReason) { tracing::warn!(target: "pool::prompt", "turn cancelled for {label}"); } StopReason::MaxTokens => { - tracing::warn!(target: "pool::prompt", "turn hit max_tokens for {label}"); + tracing::warn!(target: "pool::prompt", "turn hit max_tokens for {label} — session will be rotated"); } StopReason::MaxTurnRequests => { - tracing::warn!(target: "pool::prompt", "turn hit max_turn_requests for {label}"); + tracing::warn!(target: "pool::prompt", "turn hit max_turn_requests for {label} — session will be rotated"); } StopReason::Refusal => { tracing::warn!(target: "pool::prompt", "turn refused for {label}"); From 0e55e3d912199d03fc40bfde219cc6e468a7eb56 Mon Sep 17 00:00:00 2001 From: Tyler Longwell Date: Sun, 22 Mar 2026 09:36:38 -0400 Subject: [PATCH 2/2] fix(acp): consistent turn counter cleanup + SessionState extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix 4 post-prompt error paths that cleared sessions without clearing turn counters, causing stale counts to leak across session lifetimes. Extract SessionState from OwnedAgent so the session/turn-counter state machine is testable without spawning a real agent subprocess. All session mutation now goes through SessionState methods — no raw field access outside the impl. Methods: - invalidate(source) — clear one channel or heartbeat - invalidate_channel(id) -> bool — clear one channel, return whether it existed - invalidate_all() — clear everything (agent exit) Also fix: handle_prompt_result's channel-removal now uses invalidate_channel() instead of raw retain(). 8 unit tests cover all SessionState methods and edge cases. --- crates/sprout-acp/src/main.rs | 26 ++-- crates/sprout-acp/src/pool.rs | 216 ++++++++++++++++++++++++++-------- 2 files changed, 172 insertions(+), 70 deletions(-) diff --git a/crates/sprout-acp/src/main.rs b/crates/sprout-acp/src/main.rs index 9ec13782..228905de 100644 --- a/crates/sprout-acp/src/main.rs +++ b/crates/sprout-acp/src/main.rs @@ -18,7 +18,9 @@ use config::{Config, DedupMode, ModelsArgs, SubscribeMode}; use filter::SubscriptionRule; use futures_util::FutureExt; use nostr::ToBech32; -use pool::{AgentPool, OwnedAgent, PromptContext, PromptOutcome, PromptResult, PromptSource}; +use pool::{ + AgentPool, OwnedAgent, PromptContext, PromptOutcome, PromptResult, PromptSource, SessionState, +}; use queue::{EventQueue, QueuedEvent}; use relay::HarnessRelay; use sprout_core::kind::{ @@ -82,10 +84,7 @@ async fn main() -> Result<()> { agents.push(OwnedAgent { index: i, acp, - sessions: HashMap::new(), - heartbeat_session: None, - turn_counts: HashMap::new(), - heartbeat_turn_count: 0, + state: SessionState::default(), model_capabilities: None, desired_model: config.model.clone(), }); @@ -811,11 +810,8 @@ async fn handle_prompt_result( // Strip sessions for channels the agent was removed from while this // agent was checked out. This covers the gap where invalidate_channel_sessions // only touches idle agents. - if !removed_channels.is_empty() { - result - .agent - .sessions - .retain(|ch, _| !removed_channels.contains(ch)); + for ch in removed_channels { + result.agent.state.invalidate_channel(ch); } let outcome_label = match &result.outcome { @@ -902,10 +898,7 @@ async fn recover_panicked_agent( pool.agents_mut()[i] = Some(OwnedAgent { index: i, acp, - sessions: HashMap::new(), - heartbeat_session: None, - turn_counts: HashMap::new(), - heartbeat_turn_count: 0, + state: SessionState::default(), model_capabilities: None, desired_model: config.model.clone(), }); @@ -1016,10 +1009,7 @@ async fn respawn_agent_into(old_agent: OwnedAgent, config: &Config) -> Result, } -/// An agent with its session state, owned by the pool or a running task. -pub struct OwnedAgent { - pub index: usize, - pub acp: AcpClient, +/// Per-channel session IDs and turn counters. +/// +/// Separated from `OwnedAgent` so the state machine is testable without +/// spawning a real agent subprocess. +#[derive(Default)] +pub struct SessionState { /// channel_id → session_id pub sessions: HashMap, pub heartbeat_session: Option, @@ -73,19 +75,14 @@ pub struct OwnedAgent { pub turn_counts: HashMap, /// Turn counter for the heartbeat session. pub heartbeat_turn_count: u32, - /// Model catalog from first session/new. None until first session created. - pub model_capabilities: Option, - /// Desired model ID (from `Config.model`). Applied after every `session_new_full()`. - pub desired_model: Option, } -impl OwnedAgent { +impl SessionState { /// Invalidate the session (and turn counter) for a specific prompt source. - pub fn invalidate_session(&mut self, source: &PromptSource) { + pub fn invalidate(&mut self, source: &PromptSource) { match source { PromptSource::Channel(cid) => { - self.sessions.remove(cid); - self.turn_counts.remove(cid); + self.invalidate_channel(cid); } PromptSource::Heartbeat => { self.heartbeat_session = None; @@ -94,8 +91,15 @@ impl OwnedAgent { } } + /// Invalidate a single channel's session and turn counter. + /// Returns `true` if the channel had an active session. + pub fn invalidate_channel(&mut self, channel_id: &Uuid) -> bool { + self.turn_counts.remove(channel_id); + self.sessions.remove(channel_id).is_some() + } + /// Invalidate all sessions and turn counters (e.g. after agent exit). - pub fn invalidate_all_sessions(&mut self) { + pub fn invalidate_all(&mut self) { self.sessions.clear(); self.turn_counts.clear(); self.heartbeat_session = None; @@ -103,6 +107,17 @@ impl OwnedAgent { } } +/// An agent with its session state, owned by the pool or a running task. +pub struct OwnedAgent { + pub index: usize, + pub acp: AcpClient, + pub state: SessionState, + /// Model catalog from first session/new. None until first session created. + pub model_capabilities: Option, + /// Desired model ID (from `Config.model`). Applied after every `session_new_full()`. + pub desired_model: Option, +} + /// Pool of agents with take-and-return ownership semantics. /// /// Agents are either idle (sitting in `agents[i]`) or checked out @@ -193,7 +208,7 @@ impl AgentPool { if let Some(cid) = channel_id { let idx = self.agents.iter().position(|slot| { slot.as_ref() - .map(|a| a.sessions.contains_key(&cid)) + .map(|a| a.state.sessions.contains_key(&cid)) .unwrap_or(false) }); if let Some(i) = idx { @@ -226,7 +241,7 @@ impl AgentPool { pub fn has_session_for(&self, channel_id: Uuid) -> bool { self.agents.iter().any(|slot| { slot.as_ref() - .map(|a| a.sessions.contains_key(&channel_id)) + .map(|a| a.state.sessions.contains_key(&channel_id)) .unwrap_or(false) }) } @@ -279,8 +294,7 @@ impl AgentPool { let mut count = 0; for slot in &mut self.agents { if let Some(agent) = slot.as_mut() { - if agent.sessions.remove(&channel_id).is_some() { - agent.turn_counts.remove(&channel_id); + if agent.state.invalidate_channel(&channel_id) { count += 1; } } @@ -434,7 +448,7 @@ pub async fn run_prompt_task( let (session_id, is_new_session) = match &source { PromptSource::Channel(cid) => { - if let Some(sid) = agent.sessions.get(cid) { + if let Some(sid) = agent.state.sessions.get(cid) { (sid.clone(), false) } else { // Create new session with model application. @@ -444,11 +458,11 @@ pub async fn run_prompt_task( target: "pool::session", "created session {sid} for channel {cid}" ); - agent.sessions.insert(*cid, sid.clone()); + agent.state.sessions.insert(*cid, sid.clone()); (sid, true) } Err(AcpError::AgentExited) => { - agent.invalidate_all_sessions(); + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -470,7 +484,7 @@ pub async fn run_prompt_task( } } PromptSource::Heartbeat => { - if let Some(sid) = &agent.heartbeat_session { + if let Some(sid) = &agent.state.heartbeat_session { (sid.clone(), false) } else { match create_session_and_apply_model(&mut agent, &ctx).await { @@ -480,11 +494,11 @@ pub async fn run_prompt_task( "created heartbeat session {sid} for agent {}", agent.index ); - agent.heartbeat_session = Some(sid.clone()); + agent.state.heartbeat_session = Some(sid.clone()); (sid, true) } Err(AcpError::AgentExited) => { - agent.invalidate_all_sessions(); + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -530,7 +544,7 @@ pub async fn run_prompt_task( ); } Ok(Err(AcpError::AgentExited)) => { - agent.invalidate_all_sessions(); + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -544,7 +558,7 @@ pub async fn run_prompt_task( target: "pool::session", "initial_message failed for channel {cid}: {e} — invalidating session" ); - agent.invalidate_session(&source); + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -560,10 +574,10 @@ pub async fn run_prompt_task( ); match agent.acp.cancel_with_cleanup(&session_id).await { Ok(_) => { - agent.invalidate_session(&source); + agent.state.invalidate(&source); } Err(AcpError::AgentExited) => { - agent.invalidate_all_sessions(); + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -577,7 +591,7 @@ pub async fn run_prompt_task( target: "pool::session", "cancel_with_cleanup failed during initial_message timeout: {e}" ); - agent.invalidate_session(&source); + agent.state.invalidate(&source); } } let _ = result_tx.send(PromptResult { @@ -668,13 +682,13 @@ pub async fn run_prompt_task( if limit > 0 { match &source { PromptSource::Channel(cid) => { - let count = agent.turn_counts.entry(*cid).or_insert(0); + let count = agent.state.turn_counts.entry(*cid).or_insert(0); *count += 1; *count >= limit } PromptSource::Heartbeat => { - agent.heartbeat_turn_count += 1; - agent.heartbeat_turn_count >= limit + agent.state.heartbeat_turn_count += 1; + agent.state.heartbeat_turn_count >= limit } } } else { @@ -687,7 +701,7 @@ pub async fn run_prompt_task( target: "pool::session", "rotating session for {source:?} after {stop_reason:?}", ); - agent.invalidate_session(&source); + agent.state.invalidate(&source); } let _ = result_tx.send(PromptResult { @@ -699,8 +713,7 @@ pub async fn run_prompt_task( } Ok(Err(AcpError::AgentExited)) => { tracing::error!(target: "pool::prompt", "agent {} exited during prompt", agent.index); - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -711,14 +724,7 @@ pub async fn run_prompt_task( Ok(Err(e)) => { tracing::error!(target: "pool::prompt", "session_prompt error: {e}"); // Invalidate only the affected session. - match &source { - PromptSource::Channel(cid) => { - agent.sessions.remove(cid); - } - PromptSource::Heartbeat => { - agent.heartbeat_session = None; - } - } + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -749,8 +755,7 @@ pub async fn run_prompt_task( "agent {} exited during cancel_with_cleanup", agent.index ); - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -763,14 +768,7 @@ pub async fn run_prompt_task( target: "pool::prompt", "cancel_with_cleanup error: {e} — invalidating session" ); - match &source { - PromptSource::Channel(cid) => { - agent.sessions.remove(cid); - } - PromptSource::Heartbeat => { - agent.heartbeat_session = None; - } - } + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -1495,4 +1493,118 @@ mod tests { assert_eq!(pct_encode("+"), "%2B"); assert_eq!(pct_encode(" "), "%20"); } + + // ── SessionState tests ─────────────────────────────────────────────── + + fn make_state() -> (SessionState, Uuid, Uuid) { + let ch_a = Uuid::new_v4(); + let ch_b = Uuid::new_v4(); + let mut s = SessionState::default(); + s.sessions.insert(ch_a, "sess-a".into()); + s.sessions.insert(ch_b, "sess-b".into()); + s.turn_counts.insert(ch_a, 5); + s.turn_counts.insert(ch_b, 3); + s.heartbeat_session = Some("sess-hb".into()); + s.heartbeat_turn_count = 7; + (s, ch_a, ch_b) + } + + #[test] + fn test_invalidate_channel_clears_session_and_turn_count() { + let (mut s, ch_a, ch_b) = make_state(); + s.invalidate(&PromptSource::Channel(ch_a)); + + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + // ch_b untouched + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + // heartbeat untouched + assert_eq!(s.heartbeat_session.as_deref(), Some("sess-hb")); + assert_eq!(s.heartbeat_turn_count, 7); + } + + #[test] + fn test_invalidate_heartbeat_clears_session_and_turn_count() { + let (mut s, ch_a, ch_b) = make_state(); + s.invalidate(&PromptSource::Heartbeat); + + assert!(s.heartbeat_session.is_none()); + assert_eq!(s.heartbeat_turn_count, 0); + // channels untouched + assert_eq!(s.sessions.len(), 2); + assert_eq!(*s.turn_counts.get(&ch_a).unwrap(), 5); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } + + #[test] + fn test_invalidate_all_clears_everything() { + let (mut s, _ch_a, _ch_b) = make_state(); + s.invalidate_all(); + + assert!(s.sessions.is_empty()); + assert!(s.turn_counts.is_empty()); + assert!(s.heartbeat_session.is_none()); + assert_eq!(s.heartbeat_turn_count, 0); + } + + #[test] + fn test_invalidate_nonexistent_channel_is_noop() { + let (mut s, ch_a, ch_b) = make_state(); + let ghost = Uuid::new_v4(); + s.invalidate(&PromptSource::Channel(ghost)); + + // Everything still intact. + assert_eq!(s.sessions.len(), 2); + assert_eq!(s.turn_counts.len(), 2); + assert_eq!(*s.turn_counts.get(&ch_a).unwrap(), 5); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } + + #[test] + fn test_invalidate_all_on_empty_state_is_noop() { + let mut s = SessionState::default(); + s.invalidate_all(); // should not panic + assert!(s.sessions.is_empty()); + assert!(s.turn_counts.is_empty()); + } + + #[test] + fn test_invalidate_channel_returns_true_when_session_existed() { + let (mut s, ch_a, ch_b) = make_state(); + assert!(s.invalidate_channel(&ch_a)); + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + // ch_b untouched + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + // heartbeat untouched + assert_eq!(s.heartbeat_session.as_deref(), Some("sess-hb")); + assert_eq!(s.heartbeat_turn_count, 7); + } + + #[test] + fn test_invalidate_channel_returns_false_when_no_session() { + let (mut s, _ch_a, _ch_b) = make_state(); + let ghost = Uuid::new_v4(); + assert!(!s.invalidate_channel(&ghost)); + // Nothing changed. + assert_eq!(s.sessions.len(), 2); + assert_eq!(s.turn_counts.len(), 2); + } + + #[test] + fn test_removed_channels_cleaned_via_invalidate_channel() { + // Simulates handle_prompt_result: channels removed while agent + // was checked out should have both sessions and turn_counts stripped. + let (mut s, ch_a, ch_b) = make_state(); + let removed = vec![ch_a]; + for ch in &removed { + s.invalidate_channel(ch); + } + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } }