diff --git a/crates/sprout-acp/src/config.rs b/crates/sprout-acp/src/config.rs index e088dc1..f05aef8 100644 --- a/crates/sprout-acp/src/config.rs +++ b/crates/sprout-acp/src/config.rs @@ -185,6 +185,12 @@ pub struct CliArgs { value_parser = clap::value_parser!(u32).range(0..=100))] pub context_message_limit: u32, + /// Maximum turns per session before proactive rotation. 0 = disabled + /// (rotate only on MaxTokens / MaxTurnRequests). + #[arg(long, env = "SPROUT_ACP_MAX_TURNS_PER_SESSION", default_value_t = 0, + value_parser = clap::value_parser!(u32))] + pub max_turns_per_session: u32, + /// Disable automatic presence (online/offline) status. #[arg(long, env = "SPROUT_ACP_NO_PRESENCE")] pub no_presence: bool, @@ -234,6 +240,8 @@ pub struct Config { pub no_mention_filter: bool, pub config_path: PathBuf, pub context_message_limit: u32, + /// Maximum turns per session before proactive rotation. 0 = disabled. + pub max_turns_per_session: u32, pub presence_enabled: bool, pub typing_enabled: bool, /// Desired LLM model ID. Applied after every `session_new_full()`. @@ -373,6 +381,7 @@ impl Config { no_mention_filter: args.no_mention_filter, config_path: args.config, context_message_limit: args.context_message_limit, + max_turns_per_session: args.max_turns_per_session, presence_enabled: !args.no_presence, typing_enabled: !args.no_typing, model: args.model, @@ -382,7 +391,7 @@ impl Config { /// Human-readable summary (no secrets). pub fn summary(&self) -> String { format!( - "relay={} pubkey={} agent_cmd={} {} mcp_cmd={} timeout={}s agents={} heartbeat={}s subscribe={:?} dedup={:?} ignore_self={} context_limit={} presence={} typing={} model={}", + "relay={} pubkey={} agent_cmd={} {} mcp_cmd={} timeout={}s agents={} heartbeat={}s subscribe={:?} dedup={:?} ignore_self={} context_limit={} max_turns_per_session={} presence={} typing={} model={}", self.relay_url, self.keys.public_key().to_hex(), self.agent_command, @@ -395,6 +404,7 @@ impl Config { self.dedup_mode, self.ignore_self, self.context_message_limit, + self.max_turns_per_session, self.presence_enabled, self.typing_enabled, self.model.as_deref().unwrap_or("(agent default)"), @@ -692,6 +702,7 @@ mod tests { no_mention_filter: false, config_path: PathBuf::from("./sprout-acp.toml"), context_message_limit: 12, + max_turns_per_session: 0, presence_enabled: true, typing_enabled: true, model: None, diff --git a/crates/sprout-acp/src/main.rs b/crates/sprout-acp/src/main.rs index a4b5afa..228905d 100644 --- a/crates/sprout-acp/src/main.rs +++ b/crates/sprout-acp/src/main.rs @@ -18,7 +18,9 @@ use config::{Config, DedupMode, ModelsArgs, SubscribeMode}; use filter::SubscriptionRule; use futures_util::FutureExt; use nostr::ToBech32; -use pool::{AgentPool, OwnedAgent, PromptContext, PromptOutcome, PromptResult, PromptSource}; +use pool::{ + AgentPool, OwnedAgent, PromptContext, PromptOutcome, PromptResult, PromptSource, SessionState, +}; use queue::{EventQueue, QueuedEvent}; use relay::HarnessRelay; use sprout_core::kind::{ @@ -82,8 +84,7 @@ async fn main() -> Result<()> { agents.push(OwnedAgent { index: i, acp, - sessions: HashMap::new(), - heartbeat_session: None, + state: SessionState::default(), model_capabilities: None, desired_model: config.model.clone(), }); @@ -228,6 +229,7 @@ async fn main() -> Result<()> { rest_client: relay.rest_client(), channel_info: channel_info_map, context_message_limit: config.context_message_limit, + max_turns_per_session: config.max_turns_per_session, }); // ── Step 6: Heartbeat timer ─────────────────────────────────────────────── @@ -808,11 +810,8 @@ async fn handle_prompt_result( // Strip sessions for channels the agent was removed from while this // agent was checked out. This covers the gap where invalidate_channel_sessions // only touches idle agents. - if !removed_channels.is_empty() { - result - .agent - .sessions - .retain(|ch, _| !removed_channels.contains(ch)); + for ch in removed_channels { + result.agent.state.invalidate_channel(ch); } let outcome_label = match &result.outcome { @@ -899,8 +898,7 @@ async fn recover_panicked_agent( pool.agents_mut()[i] = Some(OwnedAgent { index: i, acp, - sessions: HashMap::new(), - heartbeat_session: None, + state: SessionState::default(), model_capabilities: None, desired_model: config.model.clone(), }); @@ -1011,8 +1009,7 @@ async fn respawn_agent_into(old_agent: OwnedAgent, config: &Config) -> Result, } +/// Per-channel session IDs and turn counters. +/// +/// Separated from `OwnedAgent` so the state machine is testable without +/// spawning a real agent subprocess. +#[derive(Default)] +pub struct SessionState { + /// channel_id → session_id + pub sessions: HashMap, + pub heartbeat_session: Option, + /// Per-channel turn counters for proactive session rotation. + /// Incremented on each successful prompt; reset when the session is rotated. + pub turn_counts: HashMap, + /// Turn counter for the heartbeat session. + pub heartbeat_turn_count: u32, +} + +impl SessionState { + /// Invalidate the session (and turn counter) for a specific prompt source. + pub fn invalidate(&mut self, source: &PromptSource) { + match source { + PromptSource::Channel(cid) => { + self.invalidate_channel(cid); + } + PromptSource::Heartbeat => { + self.heartbeat_session = None; + self.heartbeat_turn_count = 0; + } + } + } + + /// Invalidate a single channel's session and turn counter. + /// Returns `true` if the channel had an active session. + pub fn invalidate_channel(&mut self, channel_id: &Uuid) -> bool { + self.turn_counts.remove(channel_id); + self.sessions.remove(channel_id).is_some() + } + + /// Invalidate all sessions and turn counters (e.g. after agent exit). + pub fn invalidate_all(&mut self) { + self.sessions.clear(); + self.turn_counts.clear(); + self.heartbeat_session = None; + self.heartbeat_turn_count = 0; + } +} + /// An agent with its session state, owned by the pool or a running task. pub struct OwnedAgent { pub index: usize, pub acp: AcpClient, - /// channel_id → session_id - pub sessions: HashMap, - pub heartbeat_session: Option, + pub state: SessionState, /// Model catalog from first session/new. None until first session created. pub model_capabilities: Option, /// Desired model ID (from `Config.model`). Applied after every `session_new_full()`. @@ -97,6 +141,7 @@ pub struct PromptResult { } /// Whether the prompt came from a channel event or a heartbeat. +#[derive(Debug)] pub enum PromptSource { Channel(Uuid), Heartbeat, @@ -129,6 +174,8 @@ pub struct PromptContext { pub channel_info: std::collections::HashMap, /// Max messages to include in thread/DM context. 0 = disabled. pub context_message_limit: u32, + /// Max turns per session before proactive rotation. 0 = disabled. + pub max_turns_per_session: u32, } // ── AgentPool impl ──────────────────────────────────────────────────────────── @@ -161,7 +208,7 @@ impl AgentPool { if let Some(cid) = channel_id { let idx = self.agents.iter().position(|slot| { slot.as_ref() - .map(|a| a.sessions.contains_key(&cid)) + .map(|a| a.state.sessions.contains_key(&cid)) .unwrap_or(false) }); if let Some(i) = idx { @@ -194,7 +241,7 @@ impl AgentPool { pub fn has_session_for(&self, channel_id: Uuid) -> bool { self.agents.iter().any(|slot| { slot.as_ref() - .map(|a| a.sessions.contains_key(&channel_id)) + .map(|a| a.state.sessions.contains_key(&channel_id)) .unwrap_or(false) }) } @@ -247,7 +294,7 @@ impl AgentPool { let mut count = 0; for slot in &mut self.agents { if let Some(agent) = slot.as_mut() { - if agent.sessions.remove(&channel_id).is_some() { + if agent.state.invalidate_channel(&channel_id) { count += 1; } } @@ -401,7 +448,7 @@ pub async fn run_prompt_task( let (session_id, is_new_session) = match &source { PromptSource::Channel(cid) => { - if let Some(sid) = agent.sessions.get(cid) { + if let Some(sid) = agent.state.sessions.get(cid) { (sid.clone(), false) } else { // Create new session with model application. @@ -411,12 +458,11 @@ pub async fn run_prompt_task( target: "pool::session", "created session {sid} for channel {cid}" ); - agent.sessions.insert(*cid, sid.clone()); + agent.state.sessions.insert(*cid, sid.clone()); (sid, true) } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -438,7 +484,7 @@ pub async fn run_prompt_task( } } PromptSource::Heartbeat => { - if let Some(sid) = &agent.heartbeat_session { + if let Some(sid) = &agent.state.heartbeat_session { (sid.clone(), false) } else { match create_session_and_apply_model(&mut agent, &ctx).await { @@ -448,12 +494,11 @@ pub async fn run_prompt_task( "created heartbeat session {sid} for agent {}", agent.index ); - agent.heartbeat_session = Some(sid.clone()); + agent.state.heartbeat_session = Some(sid.clone()); (sid, true) } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -499,8 +544,7 @@ pub async fn run_prompt_task( ); } Ok(Err(AcpError::AgentExited)) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -514,7 +558,7 @@ pub async fn run_prompt_task( target: "pool::session", "initial_message failed for channel {cid}: {e} — invalidating session" ); - agent.sessions.remove(cid); + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -530,11 +574,10 @@ pub async fn run_prompt_task( ); match agent.acp.cancel_with_cleanup(&session_id).await { Ok(_) => { - agent.sessions.remove(cid); + agent.state.invalidate(&source); } Err(AcpError::AgentExited) => { - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -548,7 +591,7 @@ pub async fn run_prompt_task( target: "pool::session", "cancel_with_cleanup failed during initial_message timeout: {e}" ); - agent.sessions.remove(cid); + agent.state.invalidate(&source); } } let _ = result_tx.send(PromptResult { @@ -626,6 +669,41 @@ pub async fn run_prompt_task( match prompt_result { Ok(Ok(stop_reason)) => { log_stop_reason(&source, &stop_reason); + + // ── Session rotation on context exhaustion ──────────────── + let should_rotate = matches!( + stop_reason, + StopReason::MaxTokens | StopReason::MaxTurnRequests + ); + + // ── Proactive turn-based rotation ───────────────────────── + let should_rotate = should_rotate || { + let limit = ctx.max_turns_per_session; + if limit > 0 { + match &source { + PromptSource::Channel(cid) => { + let count = agent.state.turn_counts.entry(*cid).or_insert(0); + *count += 1; + *count >= limit + } + PromptSource::Heartbeat => { + agent.state.heartbeat_turn_count += 1; + agent.state.heartbeat_turn_count >= limit + } + } + } else { + false + } + }; + + if should_rotate { + tracing::info!( + target: "pool::session", + "rotating session for {source:?} after {stop_reason:?}", + ); + agent.state.invalidate(&source); + } + let _ = result_tx.send(PromptResult { agent, source, @@ -635,8 +713,7 @@ pub async fn run_prompt_task( } Ok(Err(AcpError::AgentExited)) => { tracing::error!(target: "pool::prompt", "agent {} exited during prompt", agent.index); - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -647,14 +724,7 @@ pub async fn run_prompt_task( Ok(Err(e)) => { tracing::error!(target: "pool::prompt", "session_prompt error: {e}"); // Invalidate only the affected session. - match &source { - PromptSource::Channel(cid) => { - agent.sessions.remove(cid); - } - PromptSource::Heartbeat => { - agent.heartbeat_session = None; - } - } + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -685,8 +755,7 @@ pub async fn run_prompt_task( "agent {} exited during cancel_with_cleanup", agent.index ); - agent.sessions.clear(); - agent.heartbeat_session = None; + agent.state.invalidate_all(); let _ = result_tx.send(PromptResult { agent, source, @@ -699,14 +768,7 @@ pub async fn run_prompt_task( target: "pool::prompt", "cancel_with_cleanup error: {e} — invalidating session" ); - match &source { - PromptSource::Channel(cid) => { - agent.sessions.remove(cid); - } - PromptSource::Heartbeat => { - agent.heartbeat_session = None; - } - } + agent.state.invalidate(&source); let _ = result_tx.send(PromptResult { agent, source, @@ -1005,10 +1067,10 @@ fn log_stop_reason(source: &PromptSource, stop_reason: &StopReason) { tracing::warn!(target: "pool::prompt", "turn cancelled for {label}"); } StopReason::MaxTokens => { - tracing::warn!(target: "pool::prompt", "turn hit max_tokens for {label}"); + tracing::warn!(target: "pool::prompt", "turn hit max_tokens for {label} — session will be rotated"); } StopReason::MaxTurnRequests => { - tracing::warn!(target: "pool::prompt", "turn hit max_turn_requests for {label}"); + tracing::warn!(target: "pool::prompt", "turn hit max_turn_requests for {label} — session will be rotated"); } StopReason::Refusal => { tracing::warn!(target: "pool::prompt", "turn refused for {label}"); @@ -1431,4 +1493,118 @@ mod tests { assert_eq!(pct_encode("+"), "%2B"); assert_eq!(pct_encode(" "), "%20"); } + + // ── SessionState tests ─────────────────────────────────────────────── + + fn make_state() -> (SessionState, Uuid, Uuid) { + let ch_a = Uuid::new_v4(); + let ch_b = Uuid::new_v4(); + let mut s = SessionState::default(); + s.sessions.insert(ch_a, "sess-a".into()); + s.sessions.insert(ch_b, "sess-b".into()); + s.turn_counts.insert(ch_a, 5); + s.turn_counts.insert(ch_b, 3); + s.heartbeat_session = Some("sess-hb".into()); + s.heartbeat_turn_count = 7; + (s, ch_a, ch_b) + } + + #[test] + fn test_invalidate_channel_clears_session_and_turn_count() { + let (mut s, ch_a, ch_b) = make_state(); + s.invalidate(&PromptSource::Channel(ch_a)); + + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + // ch_b untouched + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + // heartbeat untouched + assert_eq!(s.heartbeat_session.as_deref(), Some("sess-hb")); + assert_eq!(s.heartbeat_turn_count, 7); + } + + #[test] + fn test_invalidate_heartbeat_clears_session_and_turn_count() { + let (mut s, ch_a, ch_b) = make_state(); + s.invalidate(&PromptSource::Heartbeat); + + assert!(s.heartbeat_session.is_none()); + assert_eq!(s.heartbeat_turn_count, 0); + // channels untouched + assert_eq!(s.sessions.len(), 2); + assert_eq!(*s.turn_counts.get(&ch_a).unwrap(), 5); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } + + #[test] + fn test_invalidate_all_clears_everything() { + let (mut s, _ch_a, _ch_b) = make_state(); + s.invalidate_all(); + + assert!(s.sessions.is_empty()); + assert!(s.turn_counts.is_empty()); + assert!(s.heartbeat_session.is_none()); + assert_eq!(s.heartbeat_turn_count, 0); + } + + #[test] + fn test_invalidate_nonexistent_channel_is_noop() { + let (mut s, ch_a, ch_b) = make_state(); + let ghost = Uuid::new_v4(); + s.invalidate(&PromptSource::Channel(ghost)); + + // Everything still intact. + assert_eq!(s.sessions.len(), 2); + assert_eq!(s.turn_counts.len(), 2); + assert_eq!(*s.turn_counts.get(&ch_a).unwrap(), 5); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } + + #[test] + fn test_invalidate_all_on_empty_state_is_noop() { + let mut s = SessionState::default(); + s.invalidate_all(); // should not panic + assert!(s.sessions.is_empty()); + assert!(s.turn_counts.is_empty()); + } + + #[test] + fn test_invalidate_channel_returns_true_when_session_existed() { + let (mut s, ch_a, ch_b) = make_state(); + assert!(s.invalidate_channel(&ch_a)); + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + // ch_b untouched + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + // heartbeat untouched + assert_eq!(s.heartbeat_session.as_deref(), Some("sess-hb")); + assert_eq!(s.heartbeat_turn_count, 7); + } + + #[test] + fn test_invalidate_channel_returns_false_when_no_session() { + let (mut s, _ch_a, _ch_b) = make_state(); + let ghost = Uuid::new_v4(); + assert!(!s.invalidate_channel(&ghost)); + // Nothing changed. + assert_eq!(s.sessions.len(), 2); + assert_eq!(s.turn_counts.len(), 2); + } + + #[test] + fn test_removed_channels_cleaned_via_invalidate_channel() { + // Simulates handle_prompt_result: channels removed while agent + // was checked out should have both sessions and turn_counts stripped. + let (mut s, ch_a, ch_b) = make_state(); + let removed = vec![ch_a]; + for ch in &removed { + s.invalidate_channel(ch); + } + assert!(!s.sessions.contains_key(&ch_a)); + assert!(!s.turn_counts.contains_key(&ch_a)); + assert_eq!(s.sessions.get(&ch_b).unwrap(), "sess-b"); + assert_eq!(*s.turn_counts.get(&ch_b).unwrap(), 3); + } }