diff --git a/Cargo.lock b/Cargo.lock index 891f880..ec19606 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2951,6 +2951,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "aws-lc-rs", + "log", "once_cell", "ring", "rustls-pki-types", @@ -3677,6 +3678,7 @@ dependencies = [ "nostr", "rand 0.8.5", "reqwest", + "rustls", "serde", "serde_json", "sha2", diff --git a/crates/sprout-mcp/src/server.rs b/crates/sprout-mcp/src/server.rs index 7f27b49..f70078d 100644 --- a/crates/sprout-mcp/src/server.rs +++ b/crates/sprout-mcp/src/server.rs @@ -60,7 +60,7 @@ fn normalize_mention_pubkeys(mention_pubkeys: &[String], sender_pubkey: &str) -> normalized } -fn build_top_level_message_tags( +fn build_message_tags( channel_id: &str, sender_pubkey: &str, mention_pubkeys: Option<&[String]>, @@ -85,6 +85,32 @@ fn build_top_level_message_tags( Ok(tags) } +/// Extract the thread root event ID from a serialized Nostr tag array. +/// +/// Parses `"e"` tags with NIP-10 markers: +/// - If a `"root"` marker exists, returns that event ID. +/// - If only a `"reply"` marker exists, returns the reply target (it IS the root +/// for a direct reply — needed so nested replies can supply the correct root). +/// - If no thread markers exist, returns `None` (top-level message). +fn find_root_from_tags(tags: &serde_json::Value) -> Option { + let arr = tags.as_array()?; + let mut root = None; + let mut reply = None; + for tag in arr { + let Some(parts) = tag.as_array() else { + continue; + }; + if parts.len() >= 4 && parts[0].as_str() == Some("e") { + match parts[3].as_str() { + Some("root") => root = parts[1].as_str().map(|s| s.to_string()), + Some("reply") => reply = parts[1].as_str().map(|s| s.to_string()), + _ => {} + } + } + } + root.or(reply) +} + /// Maximum allowed content size for a single message (64 KiB). const MAX_CONTENT_BYTES: usize = 65_536; @@ -98,7 +124,7 @@ pub struct SendMessageParams { /// Nostr event kind. Defaults to KIND_STREAM_MESSAGE (NIP-29 group chat message). #[serde(default = "default_kind")] pub kind: Option, - /// Optional parent event ID. If provided, sends a reply via REST instead of WebSocket. + /// Optional parent event ID for threading. If provided, NIP-10 reply tags are added. #[serde(default)] pub parent_event_id: Option, /// If true and parent_event_id is set, surface the reply in the main channel timeline. @@ -679,6 +705,42 @@ impl SproutMcpServer { } } + /// Build NIP-10 reply tags for a threaded reply. + /// + /// Fetches the parent event via `GET /api/events/{id}` to determine thread + /// ancestry. This requires `MessagesRead` scope — acceptable because the MCP + /// server's read tools (get_messages, list_channels, search, etc.) already + /// require it, so any usable MCP token will have it. + /// + /// - Direct reply (parent is top-level): `["e", parent, "", "reply"]` + /// - Nested reply (parent is a reply): `["e", root, "", "root"]` + `["e", parent, "", "reply"]` + async fn build_reply_tags(&self, parent_event_id: &str) -> Result, String> { + let resp = self + .client + .get(&format!("/api/events/{}", parent_event_id)) + .await + .map_err(|e| format!("failed to fetch parent event: {e}"))?; + + let event_json: serde_json::Value = serde_json::from_str(&resp) + .map_err(|e| format!("failed to parse parent event: {e}"))?; + + let reply_tag = Tag::parse(&["e", parent_event_id, "", "reply"]) + .map_err(|e| format!("failed to build reply tag: {e}"))?; + + match find_root_from_tags(&event_json["tags"]) { + Some(root) if root != parent_event_id => { + // Nested reply — parent is itself a reply with a different root. + let root_tag = Tag::parse(&["e", &root, "", "root"]) + .map_err(|e| format!("failed to build root tag: {e}"))?; + Ok(vec![root_tag, reply_tag]) + } + _ => { + // Direct reply — parent is the root (or has no thread ancestry). + Ok(vec![reply_tag]) + } + } + } + /// Send a message to a Sprout channel. #[tool( name = "send_message", @@ -700,7 +762,6 @@ Default kind is 9 (stream message)." ); } - // Validate reply fields when present. if let Some(ref parent_id) = p.parent_event_id { if parent_id.len() != 64 || !parent_id.chars().all(|c| c.is_ascii_hexdigit()) { return format!( @@ -720,61 +781,49 @@ Default kind is 9 (stream message)." } } - // Use a user-signed WebSocket event for top-level messages so downstream - // clients see the agent pubkey directly rather than the relay pubkey. - // Threaded replies still go through REST because that path handles the - // reply ancestry tags and DB bookkeeping for us. - if p.parent_event_id.is_none() { - let kind = Kind::from( - p.kind - .unwrap_or(sprout_core::kind::KIND_STREAM_MESSAGE as u16), - ); - let sender_pubkey = self.client.keys().public_key().to_hex(); - let tags = match build_top_level_message_tags( - &p.channel_id, - &sender_pubkey, - p.mention_pubkeys.as_deref(), - ) { + let kind = Kind::from( + p.kind + .unwrap_or(sprout_core::kind::KIND_STREAM_MESSAGE as u16), + ); + let sender_pubkey = self.client.keys().public_key().to_hex(); + + // Base tags: channel + sender + mentions (same for all messages). + let mut tags = + match build_message_tags(&p.channel_id, &sender_pubkey, p.mention_pubkeys.as_deref()) { Ok(tags) => tags, Err(error) => return format!("Error: {error}"), }; - let event = - match EventBuilder::new(kind, p.content, tags).sign_with_keys(self.client.keys()) { - Ok(event) => event, - Err(e) => return format!("Error: failed to sign message event: {e}"), - }; - - return match self.client.send_event(event).await { - Ok(ok) => serde_json::json!({ - "event_id": ok.event_id, - "accepted": ok.accepted, - "message": ok.message, - }) - .to_string(), - Err(e) => format!("Error: {e}"), - }; - } - - let mut body = serde_json::json!({ - "content": p.content, - "broadcast_to_channel": p.broadcast_to_channel.unwrap_or(false), - }); + // Thread reply tags (NIP-10). if let Some(ref parent_id) = p.parent_event_id { - body["parent_event_id"] = serde_json::Value::String(parent_id.clone()); - } - if let Some(kind) = p.kind { - body["kind"] = serde_json::json!(kind); - } - if let Some(ref mentions) = p.mention_pubkeys { - body["mention_pubkeys"] = serde_json::json!(mentions); + match self.build_reply_tags(parent_id).await { + Ok(reply_tags) => tags.extend(reply_tags), + Err(e) => return format!("Error: {e}"), + } + + // Surface reply in channel timeline when requested. + if p.broadcast_to_channel.unwrap_or(false) { + match Tag::parse(&["broadcast", "1"]) { + Ok(tag) => tags.push(tag), + Err(e) => return format!("Error: failed to build broadcast tag: {e}"), + } + } } - match self - .client - .post(&format!("/api/channels/{}/messages", p.channel_id), &body) - .await - { - Ok(b) => b, + + // Sign with bot's own key and send via WebSocket. + let event = + match EventBuilder::new(kind, p.content, tags).sign_with_keys(self.client.keys()) { + Ok(event) => event, + Err(e) => return format!("Error: failed to sign message event: {e}"), + }; + + match self.client.send_event(event).await { + Ok(ok) => serde_json::json!({ + "event_id": ok.event_id, + "accepted": ok.accepted, + "message": ok.message, + }) + .to_string(), Err(e) => format!("Error: {e}"), } } @@ -1991,11 +2040,10 @@ mod tests { } #[test] - fn build_top_level_message_tags_lowercases_sender() { + fn build_message_tags_lowercases_sender() { let sender = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; - let tags = - build_top_level_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, None) - .expect("tags should build"); + let tags = build_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, None) + .expect("tags should build"); let tag_strings = tags .iter() .map(|tag| tag.clone().to_vec()) @@ -2017,11 +2065,11 @@ mod tests { } #[test] - fn build_top_level_message_tags_includes_mention_p_tags() { + fn build_message_tags_includes_mention_p_tags() { let sender = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; let mention = "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB".to_string(); - let tags = build_top_level_message_tags( + let tags = build_message_tags( "550e8400-e29b-41d4-a716-446655440000", sender, Some(&[mention]), @@ -2049,11 +2097,10 @@ mod tests { } #[test] - fn build_top_level_message_tags_no_mentions_when_none() { + fn build_message_tags_no_mentions_when_none() { let sender = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; - let tags = - build_top_level_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, None) - .expect("tags should build"); + let tags = build_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, None) + .expect("tags should build"); let tag_strings = tags .iter() .map(|tag| tag.clone().to_vec()) @@ -2072,11 +2119,10 @@ mod tests { } #[test] - fn build_top_level_message_tags_no_mentions_when_empty_slice() { + fn build_message_tags_no_mentions_when_empty_slice() { let sender = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; - let tags = - build_top_level_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, Some(&[])) - .expect("tags should build"); + let tags = build_message_tags("550e8400-e29b-41d4-a716-446655440000", sender, Some(&[])) + .expect("tags should build"); let tag_strings = tags .iter() .map(|tag| tag.clone().to_vec()) @@ -2137,6 +2183,146 @@ mod tests { assert_eq!(parsed.event_id, params.event_id); assert!(matches!(parsed.direction, VoteDirection::Up)); } + + // ── find_root_from_tags ─────────────────────────────────────────────────── + + #[test] + fn find_root_from_tags_no_thread_tags() { + let tags = serde_json::json!([["h", "channel-uuid"], ["p", "sender_hex"]]); + assert_eq!(find_root_from_tags(&tags), None); + } + + #[test] + fn find_root_from_tags_direct_reply() { + // Parent has only a "reply" marker — it's a direct reply to root. + // The reply target IS the root, so we return it (needed for nested replies). + let tags = serde_json::json!([ + ["h", "channel-uuid"], + ["p", "sender_hex"], + ["e", "parent123", "", "reply"] + ]); + assert_eq!(find_root_from_tags(&tags), Some("parent123".to_string())); + } + + #[test] + fn find_root_from_tags_nested_reply() { + // Parent has both "root" and "reply" markers. + let tags = serde_json::json!([ + ["h", "channel-uuid"], + ["e", "root_abc", "", "root"], + ["e", "parent_def", "", "reply"], + ["p", "sender_hex"] + ]); + assert_eq!(find_root_from_tags(&tags), Some("root_abc".to_string())); + } + + #[test] + fn find_root_from_tags_root_only() { + // Parent has only a "root" marker (unusual but valid). + let tags = serde_json::json!([["e", "root_abc", "", "root"]]); + assert_eq!(find_root_from_tags(&tags), Some("root_abc".to_string())); + } + + #[test] + fn find_root_from_tags_root_takes_priority_over_reply() { + // When both root and reply markers exist, root wins. + let tags = serde_json::json!([ + ["e", "the_root", "", "root"], + ["e", "the_parent", "", "reply"] + ]); + assert_eq!(find_root_from_tags(&tags), Some("the_root".to_string())); + } + + #[test] + fn find_root_from_tags_reply_only_returns_reply_target() { + // Parent is a direct reply (only "reply" marker) — the reply target + // IS the root. Needed so nested replies can supply the correct root tag. + let tags = serde_json::json!([["h", "channel-uuid"], ["e", "original_root", "", "reply"]]); + assert_eq!( + find_root_from_tags(&tags), + Some("original_root".to_string()) + ); + } + + #[test] + fn find_root_from_tags_ignores_non_e_tags() { + let tags = serde_json::json!([ + ["p", "some_pubkey"], + ["h", "channel-uuid"], + ["broadcast", "1"] + ]); + assert_eq!(find_root_from_tags(&tags), None); + } + + #[test] + fn find_root_from_tags_ignores_short_e_tags() { + // e-tag with only 3 elements (no marker) — ignored. + let tags = serde_json::json!([["e", "abc123", ""]]); + assert_eq!(find_root_from_tags(&tags), None); + } + + #[test] + fn find_root_from_tags_null_input() { + let tags = serde_json::json!(null); + assert_eq!(find_root_from_tags(&tags), None); + } + + #[test] + fn find_root_from_tags_skips_malformed_tags() { + // Non-array tags (number, string) should be skipped, not abort the scan. + let tags = serde_json::json!([42, "not-an-array", ["e", "root_abc", "", "root"]]); + assert_eq!(find_root_from_tags(&tags), Some("root_abc".to_string())); + } + + // ── build_message_tags (new tests) ──────────────────────────────────────── + + #[test] + fn build_message_tags_no_mentions() { + let tags = build_message_tags("chan-uuid", "aabb00", None).unwrap(); + assert_eq!(tags.len(), 2); + assert_eq!(tags[0].as_slice(), &["h", "chan-uuid"]); + assert_eq!(tags[1].as_slice(), &["p", "aabb00"]); + } + + #[test] + fn build_message_tags_with_mentions() { + let mentions = vec!["cc11".to_string(), "dd22".to_string()]; + let tags = build_message_tags("chan-uuid", "aabb00", Some(&mentions)).unwrap(); + assert_eq!(tags.len(), 4); + assert_eq!(tags[2].as_slice(), &["p", "cc11"]); + assert_eq!(tags[3].as_slice(), &["p", "dd22"]); + } + + #[test] + fn build_message_tags_self_mention_excluded() { + // Sender in mention list should be filtered out. + let mentions = vec!["AABB00".to_string(), "cc11".to_string()]; + let tags = build_message_tags("chan-uuid", "aabb00", Some(&mentions)).unwrap(); + assert_eq!(tags.len(), 3); // h + p(sender) + p(cc11), NOT p(AABB00) + assert_eq!(tags[2].as_slice(), &["p", "cc11"]); + } + + #[test] + fn build_message_tags_deduplicates_mentions() { + let mentions = vec!["cc11".to_string(), "CC11".to_string()]; + let tags = build_message_tags("chan-uuid", "aabb00", Some(&mentions)).unwrap(); + assert_eq!(tags.len(), 3); // h + p(sender) + p(cc11) — deduped + } + + // ── normalize_mention_pubkeys (new tests) ───────────────────────────────── + + #[test] + fn normalize_mention_pubkeys_empty() { + let result = normalize_mention_pubkeys(&[], "sender"); + assert!(result.is_empty()); + } + + #[test] + fn normalize_mention_pubkeys_case_insensitive_sender_filter() { + let mentions = vec!["SENDER".to_string(), "other".to_string()]; + let result = normalize_mention_pubkeys(&mentions, "sender"); + assert_eq!(result, vec!["other"]); + } } #[cfg(test)] diff --git a/crates/sprout-test-client/Cargo.toml b/crates/sprout-test-client/Cargo.toml index 3de947f..3c0ce55 100644 --- a/crates/sprout-test-client/Cargo.toml +++ b/crates/sprout-test-client/Cargo.toml @@ -22,6 +22,7 @@ tracing-subscriber = { workspace = true } thiserror = { workspace = true } uuid = { workspace = true } url = { workspace = true } +rustls = "0.23" [dev-dependencies] tracing-subscriber = { workspace = true } diff --git a/crates/sprout-test-client/src/bin/mention.rs b/crates/sprout-test-client/src/bin/mention.rs index f57bced..26f2c47 100644 --- a/crates/sprout-test-client/src/bin/mention.rs +++ b/crates/sprout-test-client/src/bin/mention.rs @@ -6,6 +6,8 @@ use sprout_test_client::SproutTestClient; #[tokio::main] async fn main() -> anyhow::Result<()> { + // rustls needs a CryptoProvider even for plain ws:// connections. + let _ = rustls::crypto::ring::default_provider().install_default(); let args: Vec = std::env::args().collect(); if args.len() < 4 { eprintln!("Usage: mention "); diff --git a/crates/sprout-test-client/tests/e2e_mcp.rs b/crates/sprout-test-client/tests/e2e_mcp.rs index fb6e641..86f96ea 100644 --- a/crates/sprout-test-client/tests/e2e_mcp.rs +++ b/crates/sprout-test-client/tests/e2e_mcp.rs @@ -102,6 +102,11 @@ fn spawn_mcp_server(keys: &Keys) -> Child { ]) .env("SPROUT_RELAY_URL", relay_ws_url()) .env("SPROUT_PRIVATE_KEY", &nsec) + // Tests exercise all 43 tools — enable every toolset. + .env("SPROUT_TOOLSETS", "all") + // Prevent a stale SPROUT_API_TOKEN from the host .env leaking into + // the subprocess and causing NIP-42 auth failures against a fresh DB. + .env_remove("SPROUT_API_TOKEN") // Suppress verbose startup logs so they don't pollute stderr output. .env("RUST_LOG", "error") .stdin(Stdio::piped()) @@ -295,8 +300,8 @@ async fn test_mcp_initialize_and_list_tools() { assert_eq!( tools.len(), - 43, - "expected exactly 43 tools, got {}. Tools: {:?}", + 42, + "expected exactly 42 tools, got {}. Tools: {:?}", tools.len(), tools .iter()