From 5c94effd7ecbb1bb0befffef3332345185f3cd09 Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Sat, 21 Feb 2026 02:20:52 +0800 Subject: [PATCH 1/7] add MCP server integration and per-conversation vault info Connects the agent backend to a Vultisig MCP server (optional, via MCP_SERVER_URL) so Claude can discover and call external tools like get_eth_balance and get_token_balance. Adds a set_vault built-in tool that binds vault keys (ECDSA, EdDSA, chaincode) to a conversation. On each request the MCP session is automatically primed with the active vault so address derivation happens server-side without the user repeating keys. Co-Authored-By: Claude Opus 4.6 --- cmd/server/main.go | 26 +- internal/config/config.go | 5 +- internal/mcp/client.go | 404 ++++++++++++++++++ internal/service/agent/agent.go | 77 +++- internal/service/agent/executor.go | 69 ++- internal/service/agent/prompt.go | 14 + internal/service/agent/tools.go | 29 ++ internal/storage/postgres/conversation.go | 40 ++ internal/storage/postgres/convert.go | 10 +- .../20260221000001_add_vault_info.sql | 11 + .../postgres/queries/conversations.sql.go | 66 ++- internal/storage/postgres/queries/models.go | 19 +- internal/storage/postgres/schema/schema.sql | 3 + .../storage/postgres/sqlc/conversations.sql | 9 + internal/types/conversation.go | 8 + 15 files changed, 771 insertions(+), 19 deletions(-) create mode 100644 internal/mcp/client.go create mode 100644 internal/storage/postgres/migrations/20260221000001_add_vault_info.sql diff --git a/cmd/server/main.go b/cmd/server/main.go index e05fced..82be637 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -17,6 +17,7 @@ import ( "github.com/vultisig/agent-backend/internal/api" "github.com/vultisig/agent-backend/internal/cache/redis" "github.com/vultisig/agent-backend/internal/config" + "github.com/vultisig/agent-backend/internal/mcp" "github.com/vultisig/agent-backend/internal/service/agent" mcpclient "github.com/vultisig/agent-backend/internal/service/mcp" "github.com/vultisig/agent-backend/internal/service/plugin" @@ -72,6 +73,29 @@ func main() { msgRepo := postgres.NewMessageRepository(db.Pool()) memRepo := postgres.NewMemoryRepository(db.Pool()) + // Initialize MCP JSON-RPC client for tool discovery and vault operations (optional) + var mcpProvider agent.MCPToolProvider + if cfg.MCP.ServerURL != "" { + cacheTTL := time.Duration(cfg.MCP.ToolCacheTTLSec) * time.Second + mcpClient := mcp.NewClient(cfg.MCP.ServerURL, cacheTTL, logger) + + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) + defer mcpCancel() + + if err := mcpClient.Initialize(mcpCtx); err != nil { + logger.WithError(err).Warn("failed to initialize mcp client, continuing without mcp tools") + } else { + tools, err := mcpClient.ListTools(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") + } else { + logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") + mcpProvider = mcpClient + } + } + } + + // Initialize MCP swap tx builder (optional) var swapTxBuilder agent.SwapTxBuilder if cfg.MCP.URL != "" { mcpCl := mcpclient.NewClient(cfg.MCP.URL) @@ -80,7 +104,7 @@ func main() { } // Initialize agent service - agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context) + agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, mcpProvider, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context) // Initialize API server server := api.NewServer( diff --git a/internal/config/config.go b/internal/config/config.go index f8c65bf..4ec7143 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -63,8 +63,11 @@ type VerifierConfig struct { URL string `envconfig:"VERIFIER_URL" required:"true"` } +// MCPConfig holds MCP (Model Context Protocol) server configuration. type MCPConfig struct { - URL string `envconfig:"MCP_URL" default:""` + ServerURL string `envconfig:"MCP_SERVER_URL"` + ToolCacheTTLSec int `envconfig:"MCP_TOOL_CACHE_TTL_SECONDS" default:"300"` + URL string `envconfig:"MCP_URL" default:""` } // TODO: Add MetricsConfig for Prometheus metrics when metrics are implemented. diff --git a/internal/mcp/client.go b/internal/mcp/client.go new file mode 100644 index 0000000..8e54991 --- /dev/null +++ b/internal/mcp/client.go @@ -0,0 +1,404 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" +) + +// JSON-RPC 2.0 types + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *jsonRPCError) Error() string { + return fmt.Sprintf("mcp rpc error %d: %s", e.Code, e.Message) +} + +// MCP-specific types + +// MCPTool represents a tool definition from the MCP server. +type MCPTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema any `json:"inputSchema"` +} + +type callToolParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` +} + +type callToolResult struct { + Content []callToolContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type callToolContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// toolCache holds cached MCP tools with a TTL. +type toolCache struct { + mu sync.RWMutex + tools []MCPTool + fetchedAt time.Time + ttl time.Duration +} + +func (tc *toolCache) get() ([]MCPTool, bool) { + tc.mu.RLock() + defer tc.mu.RUnlock() + if tc.tools == nil { + return nil, false + } + fresh := time.Since(tc.fetchedAt) < tc.ttl + return tc.tools, fresh +} + +func (tc *toolCache) set(tools []MCPTool) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.tools = tools + tc.fetchedAt = time.Now() +} + +// Client is an MCP JSON-RPC 2.0 client using Streamable HTTP transport. +type Client struct { + serverURL string + httpClient *http.Client + sessionID string + requestID atomic.Int64 + cache toolCache + logger *logrus.Logger +} + +// NewClient creates a new MCP client. +func NewClient(serverURL string, cacheTTL time.Duration, logger *logrus.Logger) *Client { + return &Client{ + serverURL: strings.TrimRight(serverURL, "/"), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + cache: toolCache{ttl: cacheTTL}, + logger: logger, + } +} + +// call performs a JSON-RPC 2.0 call over HTTP. +func (c *Client) call(ctx context.Context, method string, params any) (json.RawMessage, error) { + id := c.requestID.Add(1) + reqBody := jsonRPCRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := c.serverURL + "/mcp" + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_url": url, + "mcp_session": c.sessionID, + }).Debug("mcp request sending") + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + if c.sessionID != "" { + httpReq.Header.Set("Mcp-Session-Id", c.sessionID) + } + + start := time.Now() + resp, err := c.httpClient.Do(httpReq) + elapsed := time.Since(start) + if err != nil { + c.logger.WithError(err).WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_elapsed": elapsed.String(), + }).Error("mcp request failed") + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + // Track session ID from response + if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" { + if c.sessionID != sid { + c.logger.WithFields(logrus.Fields{ + "mcp_session_old": c.sessionID, + "mcp_session_new": sid, + }).Debug("mcp session id updated") + } + c.sessionID = sid + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_status": resp.StatusCode, + "mcp_elapsed": elapsed.String(), + "mcp_response_len": len(respBody), + }).Debug("mcp response received") + + if resp.StatusCode != http.StatusOK { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_status": resp.StatusCode, + "mcp_body": string(respBody), + }).Error("mcp server returned non-200 status") + return nil, fmt.Errorf("mcp server returned status %d: %s", resp.StatusCode, string(respBody)) + } + + var rpcResp jsonRPCResponse + if err := json.Unmarshal(respBody, &rpcResp); err != nil { + c.logger.WithError(err).WithField("mcp_body", string(respBody)).Error("mcp response unmarshal failed") + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + if rpcResp.Error != nil { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_error_code": rpcResp.Error.Code, + "mcp_error_msg": rpcResp.Error.Message, + }).Error("mcp rpc error") + return nil, rpcResp.Error + } + + return rpcResp.Result, nil +} + +// Initialize performs the MCP initialize handshake. +func (c *Client) Initialize(ctx context.Context) error { + c.logger.WithField("mcp_url", c.serverURL).Info("mcp initializing") + + params := map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{}, + "clientInfo": map[string]string{ + "name": "vultisig-agent-backend", + "version": "1.0.0", + }, + } + + result, err := c.call(ctx, "initialize", params) + if err != nil { + return fmt.Errorf("initialize: %w", err) + } + + c.logger.WithField("mcp_init_result", string(result)).Info("mcp initialized successfully") + + // Send initialized notification — best-effort, some servers don't handle it + if _, err := c.call(ctx, "notifications/initialized", nil); err != nil { + c.logger.WithError(err).Debug("mcp notifications/initialized not supported by server (harmless)") + } + + return nil +} + +// ListTools fetches the tool list from the MCP server and updates the cache. +func (c *Client) ListTools(ctx context.Context) ([]MCPTool, error) { + c.logger.Debug("mcp listing tools") + + result, err := c.call(ctx, "tools/list", nil) + if err != nil { + // Return stale cache on error + if stale, _ := c.cache.get(); stale != nil { + c.logger.WithError(err).WithField("stale_count", len(stale)).Warn("mcp list tools failed, using stale cache") + return stale, nil + } + return nil, fmt.Errorf("list tools: %w", err) + } + + var listResult struct { + Tools []MCPTool `json:"tools"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal tools: %w", err) + } + + names := make([]string, len(listResult.Tools)) + for i, t := range listResult.Tools { + names[i] = t.Name + } + c.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(listResult.Tools), + "mcp_tool_names": names, + }).Info("mcp tools discovered") + + c.cache.set(listResult.Tools) + return listResult.Tools, nil +} + +// CallTool invokes a tool on the MCP server. +func (c *Client) CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_arguments": string(arguments), + }).Info("mcp calling tool") + + var args map[string]any + if len(arguments) > 0 { + if err := json.Unmarshal(arguments, &args); err != nil { + return "", fmt.Errorf("unmarshal arguments: %w", err) + } + } + + params := callToolParams{ + Name: name, + Arguments: args, + } + + result, err := c.call(ctx, "tools/call", params) + if err != nil { + c.logger.WithError(err).WithField("mcp_tool", name).Error("mcp tool call failed") + return "", fmt.Errorf("call tool %s: %w", name, err) + } + + var callResult callToolResult + if err := json.Unmarshal(result, &callResult); err != nil { + return "", fmt.Errorf("unmarshal tool result: %w", err) + } + + // Collect text content from result + var texts []string + for _, c := range callResult.Content { + if c.Type == "text" && c.Text != "" { + texts = append(texts, c.Text) + } + } + + text := strings.Join(texts, "\n") + + if callResult.IsError { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_error": text, + }).Error("mcp tool returned error") + return "", fmt.Errorf("mcp tool error: %s", text) + } + + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_result_len": len(text), + }).Info("mcp tool call succeeded") + + return text, nil +} + +// GetAnthropicTools returns cached MCP tools converted to Anthropic tool format. +// If the cache is stale, it attempts a background refresh. +func (c *Client) GetAnthropicTools(ctx context.Context) []ai.Tool { + tools, fresh := c.cache.get() + + c.logger.WithFields(logrus.Fields{ + "mcp_cache_count": len(tools), + "mcp_cache_fresh": fresh, + }).Debug("mcp GetAnthropicTools called") + + if !fresh && tools != nil { + c.logger.Debug("mcp cache stale, starting background refresh") + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListTools(refreshCtx) + }() + } + + if tools == nil { + c.logger.Warn("mcp cache empty, no tools available") + return nil + } + + result := make([]ai.Tool, len(tools)) + for i, t := range tools { + result[i] = ai.Tool{ + Name: t.Name, + Description: t.Description, + InputSchema: t.InputSchema, + } + } + return result +} + +// ToolNames returns the names of all cached MCP tools. +func (c *Client) ToolNames() []string { + tools, _ := c.cache.get() + names := make([]string, len(tools)) + for i, t := range tools { + names[i] = t.Name + } + return names +} + +// ToolDescriptions returns a formatted string describing all MCP tools for the system prompt. +func (c *Client) ToolDescriptions() string { + tools, _ := c.cache.get() + if len(tools) == 0 { + c.logger.Warn("mcp ToolDescriptions called but cache is empty") + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Vultisig Tools\n\n") + b.WriteString("You have the following tools provided by the Vultisig platform. ") + b.WriteString("These are core capabilities you MUST include when listing your available tools or describing what you can do:\n\n") + for _, t := range tools { + b.WriteString("- **") + b.WriteString(t.Name) + b.WriteString("**") + if t.Description != "" { + b.WriteString(": ") + b.WriteString(t.Description) + } + b.WriteString("\n") + } + + desc := b.String() + c.logger.WithField("mcp_desc_len", len(desc)).Debug("mcp ToolDescriptions generated") + return desc +} diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 82b7533..0654b36 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -29,6 +29,14 @@ type PluginSkillsProvider interface { GetSkills(ctx context.Context) []PluginSkill } +// MCPToolProvider provides tools discovered from an MCP server. +type MCPToolProvider interface { + GetAnthropicTools(ctx context.Context) []ai.Tool + ToolNames() []string + CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) + ToolDescriptions() string +} + type SwapTxBuilder interface { BuildSwapTx(ctx context.Context, req SwapTxBuildRequest) (*SwapTxBuildResponse, error) } @@ -65,6 +73,7 @@ type AgentService struct { redis *redis.Client verifier *verifier.Client pluginProvider PluginSkillsProvider + mcpProvider MCPToolProvider swapTxBuilder SwapTxBuilder logger *logrus.Logger summaryModel string @@ -87,6 +96,7 @@ func NewAgentService( redisClient *redis.Client, verifierClient *verifier.Client, pluginProvider PluginSkillsProvider, + mcpProvider MCPToolProvider, swapTxBuilder SwapTxBuilder, logger *logrus.Logger, summaryModel string, @@ -100,6 +110,7 @@ func NewAgentService( redis: redisClient, verifier: verifierClient, pluginProvider: pluginProvider, + mcpProvider: mcpProvider, swapTxBuilder: swapTxBuilder, logger: logger, summaryModel: summaryModel, @@ -113,7 +124,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) defer cancel() - _, err := s.convRepo.GetByID(ctx, convID, publicKey) + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { return nil, fmt.Errorf("conversation not found") @@ -121,6 +132,16 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub return nil, fmt.Errorf("get conversation: %w", err) } + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } else { + s.logger.Debug("mcp session primed with vault info") + } + } + window, err := s.getConversationWindow(ctx, convID, publicKey) if err != nil { return nil, fmt.Errorf("get conversation window: %w", err) @@ -141,6 +162,15 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub fullCtx := s.resolveContext(ctx, convID, req.Context) basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) + } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, window.summary, @@ -154,6 +184,22 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub tools := agentTools() tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetAnthropicTools(ctx) + if len(mcpTools) > 0 { + mcpNames := make([]string, len(mcpTools)) + for i, t := range mcpTools { + mcpNames[i] = t.Name + } + s.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(mcpTools), + "mcp_tool_names": mcpNames, + }).Debug("appending mcp tools to ai request") + } else { + s.logger.Warn("mcp provider active but no tools returned") + } + tools = append(tools, mcpTools...) + } var toolResp *ToolResponse var textContent string @@ -207,7 +253,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub continue } - result, err := s.executeTool(ctx, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Name, tc.Input, req) if err != nil { result = jsonError(err.Error()) } @@ -258,12 +304,20 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) defer cancel() - _, err := s.convRepo.GetByID(ctx, convID, publicKey) + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) if err != nil { sendErr("conversation not found") return } + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } + } + window, err := s.getConversationWindow(ctx, convID, publicKey) if err != nil { sendErr("failed to load conversation") @@ -287,6 +341,15 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI fullCtx := s.resolveContext(ctx, convID, req.Context) basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) + } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, window.summary, @@ -300,6 +363,12 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI tools := agentTools() tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetAnthropicTools(ctx) + if len(mcpTools) > 0 { + tools = append(tools, mcpTools...) + } + } var toolResp *ToolResponse var textContent string @@ -382,7 +451,7 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI continue } - result, err := s.executeTool(ctx, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Name, tc.Input, req) if err != nil { result = jsonError(err.Error()) } diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index 15a094c..6706eaf 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -6,13 +6,14 @@ import ( "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" ) const suggestionTTL = 1 * time.Hour // executeTool dispatches a tool call to the appropriate handler. // Returns a JSON string result for Claude. Errors are returned as JSON {"error": "..."} so the LLM can communicate them naturally. -func (s *AgentService) executeTool(ctx context.Context, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { +func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { switch name { case "check_plugin_installed": return s.execCheckPluginInstalled(ctx, input, req) @@ -26,11 +27,77 @@ func (s *AgentService) executeTool(ctx context.Context, name string, input json. return s.execCreateSuggestion(ctx, input) case "update_memory": return s.execUpdateMemory(ctx, input, req) + case "set_vault": + return s.execSetVault(ctx, convID, input, req) default: + // Check MCP tools before returning unknown + if s.mcpProvider != nil { + for _, mcpName := range s.mcpProvider.ToolNames() { + if mcpName == name { + result, err := s.mcpProvider.CallTool(ctx, name, input) + if err != nil { + s.logger.WithError(err).WithField("tool", name).Warn("mcp tool call failed") + return jsonError("mcp tool error: " + err.Error()), nil + } + return result, nil + } + } + } return jsonError("unknown tool: " + name), nil } } +// execSetVault stores vault keys for this conversation and primes the MCP session. +func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChaincodeHex string `json:"chaincode_hex"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + if params.ECDSAPublicKey == "" || params.EDDSAPublicKey == "" || params.ChaincodeHex == "" { + return jsonError("all three fields are required: ecdsa_public_key, eddsa_public_key, chaincode_hex"), nil + } + + // Store in DB + if err := s.convRepo.UpdateVaultInfo(ctx, convID, req.PublicKey, params.ECDSAPublicKey, params.EDDSAPublicKey, params.ChaincodeHex); err != nil { + s.logger.WithError(err).Error("failed to store vault info") + return jsonError("failed to store vault info: " + err.Error()), nil + } + + s.logger.WithFields(logrus.Fields{ + "conversation_id": convID, + "ecdsa_prefix": truncateKey(params.ECDSAPublicKey), + "eddsa_prefix": truncateKey(params.EDDSAPublicKey), + }).Info("vault info set for conversation") + + // Prime MCP session with vault info + if s.mcpProvider != nil { + mcpInput, _ := json.Marshal(params) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", mcpInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + // Non-fatal: vault is stored locally, MCP will be primed on next request + } + } + + result, _ := json.Marshal(map[string]any{ + "ok": true, + "message": "Vault set for this conversation.", + }) + return string(result), nil +} + +// truncateKey returns the first 12 chars of a key for logging. +func truncateKey(key string) string { + if len(key) <= 12 { + return key + } + return key[:12] + "..." +} + // execCheckPluginInstalled checks if a plugin is installed for the user's vault. func (s *AgentService) execCheckPluginInstalled(ctx context.Context, input json.RawMessage, req *SendMessageRequest) (string, error) { var params struct { diff --git a/internal/service/agent/prompt.go b/internal/service/agent/prompt.go index b5c694b..d15ca02 100644 --- a/internal/service/agent/prompt.go +++ b/internal/service/agent/prompt.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/vultisig/agent-backend/internal/ai" + "github.com/vultisig/agent-backend/internal/types" ) // ActionsTable is the shared actions reference used by both the system prompt and starters. @@ -512,6 +513,19 @@ func BuildFullPrompt(msgCtx *MessageContext, plugins []PluginSkill) string { return sb.String() } +// BuildVaultInfoSection returns a system prompt section describing the active vault. +func BuildVaultInfoSection(v *types.VaultInfo) string { + if v == nil { + return "" + } + return "\n\n## Active Vault\n\n" + + "This conversation has a vault bound to it. The MCP server has been primed with these keys " + + "so tools like get_eth_balance and get_token_balance will derive addresses automatically.\n\n" + + "- ECDSA public key: `" + v.ECDSAPublicKey + "`\n" + + "- EdDSA public key: `" + v.EDDSAPublicKey + "`\n" + + "- Chaincode: `" + v.ChaincodeHex + "`\n" +} + // UpdateMemoryTool is the tool definition for updating the user's memory document. var UpdateMemoryTool = ai.Tool{ Name: "update_memory", diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index ef864a3..9487630 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -102,6 +102,34 @@ var CheckBillingStatusTool = ai.Tool{ }, } +// SetVaultTool sets the active vault for this conversation. +var SetVaultTool = ai.Tool{ + Name: "set_vault", + Description: "Set the active vault for this conversation. " + + "Call this when the user provides their vault's public keys and chaincode. " + + "This binds the vault to the conversation so tools like get_eth_balance " + + "and get_token_balance can derive addresses automatically. " + + "The user may switch vaults during a conversation by calling this again.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "ecdsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's ECDSA (secp256k1) public key in hex.", + }, + "eddsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's EdDSA (ed25519) public key in hex.", + }, + "chaincode_hex": map[string]any{ + "type": "string", + "description": "The vault's chaincode in hex, used for key derivation.", + }, + }, + "required": []string{"ecdsa_public_key", "eddsa_public_key", "chaincode_hex"}, + }, +} + // agentTools returns all granular tools for the decision loop. func agentTools() []ai.Tool { return []ai.Tool{ @@ -111,5 +139,6 @@ func agentTools() []ai.Tool { GetRecipeSchemaTool, SuggestPolicyTool, CreateSuggestionTool, + SetVaultTool, } } diff --git a/internal/storage/postgres/conversation.go b/internal/storage/postgres/conversation.go index 367baf1..b462709 100644 --- a/internal/storage/postgres/conversation.go +++ b/internal/storage/postgres/conversation.go @@ -135,6 +135,46 @@ func (r *ConversationRepository) UpdateSummaryWithCursor(ctx context.Context, id return nil } +// UpdateVaultInfo updates the vault keys for a conversation. +func (r *ConversationRepository) UpdateVaultInfo(ctx context.Context, id uuid.UUID, publicKey string, ecdsa, eddsa, chaincode string) error { + rowsAffected, err := r.q.UpdateVaultInfo(ctx, &queries.UpdateVaultInfoParams{ + EcdsaPublicKey: stringPtrToPgtext(&ecdsa), + EddsaPublicKey: stringPtrToPgtext(&eddsa), + ChaincodeHex: stringPtrToPgtext(&chaincode), + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + return fmt.Errorf("update vault info: %w", err) + } + if rowsAffected == 0 { + return ErrNotFound + } + return nil +} + +// GetVaultInfo returns the vault keys for a conversation. +func (r *ConversationRepository) GetVaultInfo(ctx context.Context, id uuid.UUID, publicKey string) (*types.VaultInfo, error) { + row, err := r.q.GetVaultInfo(ctx, &queries.GetVaultInfoParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("get vault info: %w", err) + } + if !row.EcdsaPublicKey.Valid || !row.EddsaPublicKey.Valid || !row.ChaincodeHex.Valid { + return nil, nil + } + return &types.VaultInfo{ + ECDSAPublicKey: row.EcdsaPublicKey.String, + EDDSAPublicKey: row.EddsaPublicKey.String, + ChaincodeHex: row.ChaincodeHex.String, + }, nil +} + // GetSummaryWithCursor returns the summary and summary_up_to cursor of a conversation. func (r *ConversationRepository) GetSummaryWithCursor(ctx context.Context, id uuid.UUID, publicKey string) (*string, *time.Time, error) { row, err := r.q.GetConversationSummaryWithCursor(ctx, &queries.GetConversationSummaryWithCursorParams{ diff --git a/internal/storage/postgres/convert.go b/internal/storage/postgres/convert.go index 282126c..ee02afa 100644 --- a/internal/storage/postgres/convert.go +++ b/internal/storage/postgres/convert.go @@ -72,7 +72,7 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { if c == nil { return nil } - return &types.Conversation{ + conv := &types.Conversation{ ID: pgtypeToUUID(c.ID), PublicKey: c.PublicKey, Title: pgtextToStringPtr(c.Title), @@ -82,6 +82,14 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { UpdatedAt: pgtimestamptzToTime(c.UpdatedAt), ArchivedAt: pgtimestamptzToTimePtr(c.ArchivedAt), } + if c.EcdsaPublicKey.Valid && c.EddsaPublicKey.Valid && c.ChaincodeHex.Valid { + conv.VaultInfo = &types.VaultInfo{ + ECDSAPublicKey: c.EcdsaPublicKey.String, + EDDSAPublicKey: c.EddsaPublicKey.String, + ChaincodeHex: c.ChaincodeHex.String, + } + } + return conv } func conversationsFromDB(cs []*queries.AgentConversation) []types.Conversation { diff --git a/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql new file mode 100644 index 0000000..d84e2ef --- /dev/null +++ b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql @@ -0,0 +1,11 @@ +-- +goose Up +ALTER TABLE agent_conversations + ADD COLUMN ecdsa_public_key TEXT, + ADD COLUMN eddsa_public_key TEXT, + ADD COLUMN chaincode_hex TEXT; + +-- +goose Down +ALTER TABLE agent_conversations + DROP COLUMN IF EXISTS ecdsa_public_key, + DROP COLUMN IF EXISTS eddsa_public_key, + DROP COLUMN IF EXISTS chaincode_hex; diff --git a/internal/storage/postgres/queries/conversations.sql.go b/internal/storage/postgres/queries/conversations.sql.go index 819c0a1..423a5af 100644 --- a/internal/storage/postgres/queries/conversations.sql.go +++ b/internal/storage/postgres/queries/conversations.sql.go @@ -46,7 +46,7 @@ const createConversation = `-- name: CreateConversation :one INSERT INTO agent_conversations (public_key) VALUES ($1) -RETURNING id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at +RETURNING id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at ` // Conversations table queries @@ -59,6 +59,9 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -67,7 +70,7 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag } const getConversationByID = `-- name: GetConversationByID :one -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE id = $1 AND public_key = $2 AND archived_at IS NULL ` @@ -85,6 +88,9 @@ func (q *Queries) GetConversationByID(ctx context.Context, arg *GetConversationB &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -114,8 +120,31 @@ func (q *Queries) GetConversationSummaryWithCursor(ctx context.Context, arg *Get return &i, err } +const getVaultInfo = `-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL +` + +type GetVaultInfoParams struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +type GetVaultInfoRow struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` +} + +func (q *Queries) GetVaultInfo(ctx context.Context, arg *GetVaultInfoParams) (*GetVaultInfoRow, error) { + row := q.db.QueryRow(ctx, getVaultInfo, arg.ID, arg.PublicKey) + var i GetVaultInfoRow + err := row.Scan(&i.EcdsaPublicKey, &i.EddsaPublicKey, &i.ChaincodeHex) + return &i, err +} + const listConversations = `-- name: ListConversations :many -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE public_key = $1 AND archived_at IS NULL ORDER BY updated_at DESC LIMIT $2 OFFSET $3 @@ -142,6 +171,9 @@ func (q *Queries) ListConversations(ctx context.Context, arg *ListConversationsP &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -201,3 +233,31 @@ func (q *Queries) UpdateConversationTitle(ctx context.Context, arg *UpdateConver } return result.RowsAffected(), nil } + +const updateVaultInfo = `-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL +` + +type UpdateVaultInfoParams struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) UpdateVaultInfo(ctx context.Context, arg *UpdateVaultInfoParams) (int64, error) { + result, err := q.db.Exec(ctx, updateVaultInfo, + arg.EcdsaPublicKey, + arg.EddsaPublicKey, + arg.ChaincodeHex, + arg.ID, + arg.PublicKey, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} diff --git a/internal/storage/postgres/queries/models.go b/internal/storage/postgres/queries/models.go index abfb866..00e4ab1 100644 --- a/internal/storage/postgres/queries/models.go +++ b/internal/storage/postgres/queries/models.go @@ -55,14 +55,17 @@ func (ns NullAgentMessageRole) Value() (driver.Value, error) { } type AgentConversation struct { - ID pgtype.UUID `json:"id"` - PublicKey string `json:"public_key"` - Title pgtype.Text `json:"title"` - Summary pgtype.Text `json:"summary"` - SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - ArchivedAt pgtype.Timestamptz `json:"archived_at"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` + Title pgtype.Text `json:"title"` + Summary pgtype.Text `json:"summary"` + SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ArchivedAt pgtype.Timestamptz `json:"archived_at"` } type AgentMessage struct { diff --git a/internal/storage/postgres/schema/schema.sql b/internal/storage/postgres/schema/schema.sql index ece039b..e0fadbb 100644 --- a/internal/storage/postgres/schema/schema.sql +++ b/internal/storage/postgres/schema/schema.sql @@ -9,6 +9,9 @@ CREATE TABLE agent_conversations ( title TEXT, summary TEXT, summary_up_to TIMESTAMPTZ, + ecdsa_public_key TEXT, + eddsa_public_key TEXT, + chaincode_hex TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), archived_at TIMESTAMPTZ diff --git a/internal/storage/postgres/sqlc/conversations.sql b/internal/storage/postgres/sqlc/conversations.sql index d6c3c9c..32b9473 100644 --- a/internal/storage/postgres/sqlc/conversations.sql +++ b/internal/storage/postgres/sqlc/conversations.sql @@ -37,3 +37,12 @@ WHERE id = $3 AND public_key = $4; -- name: GetConversationSummaryWithCursor :one SELECT summary, summary_up_to FROM agent_conversations WHERE id = $1 AND public_key = $2; + +-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL; + +-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL; diff --git a/internal/types/conversation.go b/internal/types/conversation.go index 08306dd..cfc62b7 100644 --- a/internal/types/conversation.go +++ b/internal/types/conversation.go @@ -23,11 +23,19 @@ type Conversation struct { Title *string `json:"title"` Summary *string `json:"summary,omitempty"` SummaryUpTo *time.Time `json:"summary_up_to,omitempty"` + VaultInfo *VaultInfo `json:"vault_info,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` ArchivedAt *time.Time `json:"archived_at,omitempty"` } +// VaultInfo holds the cryptographic keys for a vault bound to a conversation. +type VaultInfo struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChaincodeHex string `json:"chaincode_hex"` +} + // Message represents a single message in a conversation. type Message struct { ID uuid.UUID `json:"id"` From 3d2e589ce868fc5ee864a88c951bb3aef553e0fb Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Mon, 23 Feb 2026 04:12:08 +0800 Subject: [PATCH 2/7] add nix devenv, transaction extraction, and cleanup - Add nix flake dev environment with Go, PostgreSQL, Redis, sqlc - Add Transaction type and extract transactions from MCP tool results - Remove redundant public key auth checks (already handled by middleware) - Rename chaincode_hex to chain_code for consistency Co-Authored-By: Claude Opus 4.6 --- .gitignore | 1 + flake.lock | 306 +++++++++++++++++++++++++++++ flake.nix | 42 ++++ internal/api/conversation.go | 20 -- internal/service/agent/executor.go | 9 +- internal/service/agent/tools.go | 4 +- internal/service/agent/types.go | 13 ++ 7 files changed, 369 insertions(+), 26 deletions(-) create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/.gitignore b/.gitignore index 1d4114b..ffe9094 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ vendor/ .env .env.local .env.*.local +.devenv # Config files (use examples) config.json diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ebd3112 --- /dev/null +++ b/flake.lock @@ -0,0 +1,306 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "devenv" + ], + "flake-compat": [ + "devenv", + "flake-compat" + ], + "git-hooks": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760971495, + "narHash": "sha256-IwnNtbNVrlZIHh7h4Wz6VP0Furxg9Hh0ycighvL5cZc=", + "owner": "cachix", + "repo": "cachix", + "rev": "c5bfd933d1033672f51a863c47303fc0e093c2d2", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "git-hooks": "git-hooks", + "nix": "nix", + "nixd": "nixd", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1771419682, + "narHash": "sha256-NAemVgEJeZjGl3+438M4rUL8ms9QdDFMYthU12F70FQ=", + "owner": "cachix", + "repo": "devenv", + "rev": "f77fc4de35c184d9ef9a32d5d7e9033351bcdfdc", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760948891, + "narHash": "sha256-TmWcdiUUaWk8J4lpjzu4gCGxWY6/Ok7mOK4fIFfBuU4=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "864599284fc7c0ba6357ed89ed5e2cd5040f0c04", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-root": { + "locked": { + "lastModified": 1723604017, + "narHash": "sha256-rBtQ8gg+Dn4Sx/s+pvjdq3CB2wQNzx9XGFq/JVGCB6k=", + "owner": "srid", + "repo": "flake-root", + "rev": "b759a56851e10cb13f6b8e5698af7b59c44be26e", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "flake-root", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760663237, + "narHash": "sha256-BflA6U4AM1bzuRMR8QqzPXqh8sWVCNDzOdsxXEguJIc=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-parts": [ + "devenv", + "flake-parts" + ], + "git-hooks-nix": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-23-11": [ + "devenv" + ], + "nixpkgs-regression": [ + "devenv" + ] + }, + "locked": { + "lastModified": 1770395975, + "narHash": "sha256-zg0AEZn8d4rqIIsw5XrkVL5p1y6fBj2L57awfUg+gNA=", + "owner": "cachix", + "repo": "nix", + "rev": "ccb6019ce2bd11f5de5fe4617c0079d8cb1ed057", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "devenv-2.32", + "repo": "nix", + "type": "github" + } + }, + "nixd": { + "inputs": { + "flake-parts": [ + "devenv", + "flake-parts" + ], + "flake-root": "flake-root", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1763964548, + "narHash": "sha256-JTRoaEWvPsVIMFJWeS4G2isPo15wqXY/otsiHPN0zww=", + "owner": "nix-community", + "repo": "nixd", + "rev": "d4bf15e56540422e2acc7bc26b20b0a0934e3f5e", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixd", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1761313199, + "narHash": "sha256-wCIACXbNtXAlwvQUo1Ed++loFALPjYUA3dpcUJiXO44=", + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "d1c30452ebecfc55185ae6d1c983c09da0c274ff", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1771177547, + "narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "nixpkgs": "nixpkgs_2", + "systems": "systems" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "devenv", + "nixd", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734704479, + "narHash": "sha256-MMi74+WckoyEWBRcg/oaGRvXC9BVVxDZNRMpL+72wBI=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "65712f5af67234dad91a5a4baee986a8b62dbf8f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..502e258 --- /dev/null +++ b/flake.nix @@ -0,0 +1,42 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + systems.url = "github:nix-systems/default"; + devenv.url = "github:cachix/devenv"; + }; + + outputs = { self, nixpkgs, devenv, systems, ... } @ inputs: + let + forEachSystem = nixpkgs.lib.genAttrs (import systems); + in + { + devShells = forEachSystem + (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + languages.go = { + enable = true; + }; + + packages = with pkgs; [ + postgresql + go-ethereum + sqlc + redis + ]; + + enterShell = '' + echo "agent-backend shell started!" + ''; + } + ]; + }; + }); + }; +} diff --git a/internal/api/conversation.go b/internal/api/conversation.go index 7e28ff8..710a7bc 100644 --- a/internal/api/conversation.go +++ b/internal/api/conversation.go @@ -46,11 +46,6 @@ func (s *Server) CreateConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - conv, err := s.convRepo.Create(c.Request().Context(), req.PublicKey) if err != nil { s.logger.WithError(err).Error("failed to create conversation") @@ -67,11 +62,6 @@ func (s *Server) ListConversations(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - // Default pagination if req.Take <= 0 { req.Take = 20 @@ -109,11 +99,6 @@ func (s *Server) GetConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - conv, err := s.convRepo.GetWithMessages(c.Request().Context(), id, req.PublicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { @@ -143,11 +128,6 @@ func (s *Server) DeleteConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - err = s.convRepo.Archive(c.Request().Context(), id, req.PublicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index 6706eaf..bb077bb 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -52,18 +52,18 @@ func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input var params struct { ECDSAPublicKey string `json:"ecdsa_public_key"` EDDSAPublicKey string `json:"eddsa_public_key"` - ChaincodeHex string `json:"chaincode_hex"` + ChainCode string `json:"chain_code"` } if err := json.Unmarshal(input, ¶ms); err != nil { return jsonError("invalid input: " + err.Error()), nil } - if params.ECDSAPublicKey == "" || params.EDDSAPublicKey == "" || params.ChaincodeHex == "" { - return jsonError("all three fields are required: ecdsa_public_key, eddsa_public_key, chaincode_hex"), nil + if params.ECDSAPublicKey == "" || params.EDDSAPublicKey == "" || params.ChainCode == "" { + return jsonError("all three fields are required: ecdsa_public_key, eddsa_public_key, chain_code"), nil } // Store in DB - if err := s.convRepo.UpdateVaultInfo(ctx, convID, req.PublicKey, params.ECDSAPublicKey, params.EDDSAPublicKey, params.ChaincodeHex); err != nil { + if err := s.convRepo.UpdateVaultInfo(ctx, convID, req.PublicKey, params.ECDSAPublicKey, params.EDDSAPublicKey, params.ChainCode); err != nil { s.logger.WithError(err).Error("failed to store vault info") return jsonError("failed to store vault info: " + err.Error()), nil } @@ -72,6 +72,7 @@ func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input "conversation_id": convID, "ecdsa_prefix": truncateKey(params.ECDSAPublicKey), "eddsa_prefix": truncateKey(params.EDDSAPublicKey), + "chain_code_prefix": truncateKey(params.ChainCode), }).Info("vault info set for conversation") // Prime MCP session with vault info diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index 9487630..f5c132f 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -121,12 +121,12 @@ var SetVaultTool = ai.Tool{ "type": "string", "description": "The vault's EdDSA (ed25519) public key in hex.", }, - "chaincode_hex": map[string]any{ + "chain_code": map[string]any{ "type": "string", "description": "The vault's chaincode in hex, used for key derivation.", }, }, - "required": []string{"ecdsa_public_key", "eddsa_public_key", "chaincode_hex"}, + "required": []string{"ecdsa_public_key", "eddsa_public_key", "chain_code"}, }, } diff --git a/internal/service/agent/types.go b/internal/service/agent/types.go index 0b004ca..12bbf90 100644 --- a/internal/service/agent/types.go +++ b/internal/service/agent/types.go @@ -79,6 +79,19 @@ type SendMessageResponse struct { PolicyReady *PolicyReady `json:"policy_ready,omitempty"` InstallRequired *InstallRequired `json:"install_required,omitempty"` TxReady *TxReady `json:"tx_ready,omitempty"` + Transactions []Transaction `json:"transactions,omitempty"` +} + +// Transaction represents an unsigned transaction returned by an MCP tool +// that the wallet must sign and broadcast. +type Transaction struct { + Sequence int `json:"sequence"` + Chain string `json:"chain"` + ChainID string `json:"chain_id"` + Action string `json:"action"` + SigningMode string `json:"signing_mode"` + UnsignedTxHex string `json:"unsigned_tx_hex"` + TxDetails map[string]string `json:"tx_details"` } type TxReady struct { From 1b7e4f2d363c163709116410feaff6534e7407f7 Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Mon, 23 Feb 2026 05:17:58 +0800 Subject: [PATCH 3/7] add find_token MCP tool support with structured token data passthrough Extract token search results from the find_token MCP tool and pass them as structured data in the API response so frontend apps can prompt users to add tokens/chains to their vault. Co-Authored-By: Claude Opus 4.6 --- internal/service/agent/agent.go | 23 +++++++++++++++++-- internal/service/agent/types.go | 39 ++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 0654b36..6c08b51 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -203,6 +203,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub var toolResp *ToolResponse var textContent string + var tokens *TokenSearchResult for i := 0; i < maxLoopIterations; i++ { aiReq := &ai.Request{ @@ -268,6 +269,14 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub ToolUseID: tc.ID, Content: result, }) + + // Track find_token results for structured passthrough + if tc.Name == "find_token" { + var tokenResult TokenSearchResult + if err := json.Unmarshal([]byte(result), &tokenResult); err == nil && len(tokenResult.Tokens) > 0 { + tokens = &tokenResult + } + } } messages = append(messages, ai.AssistantMessage{ @@ -285,10 +294,20 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub } if toolResp != nil { - return s.buildLoopResponse(ctx, convID, req, toolResp, window) + resp, err := s.buildLoopResponse(ctx, convID, req, toolResp, window) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil } if textContent != "" { - return s.buildTextResponse(ctx, convID, textContent) + resp, err := s.buildTextResponse(ctx, convID, textContent) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil } return nil, errors.New("no response content from Claude") diff --git a/internal/service/agent/types.go b/internal/service/agent/types.go index 12bbf90..b7ed301 100644 --- a/internal/service/agent/types.go +++ b/internal/service/agent/types.go @@ -72,14 +72,15 @@ type ActionResult struct { // SendMessageResponse is the response for sending a message. type SendMessageResponse struct { - Message types.Message `json:"message"` - Title *string `json:"title,omitempty"` - Suggestions []Suggestion `json:"suggestions,omitempty"` - Actions []Action `json:"actions,omitempty"` - PolicyReady *PolicyReady `json:"policy_ready,omitempty"` - InstallRequired *InstallRequired `json:"install_required,omitempty"` - TxReady *TxReady `json:"tx_ready,omitempty"` - Transactions []Transaction `json:"transactions,omitempty"` + Message types.Message `json:"message"` + Title *string `json:"title,omitempty"` + Suggestions []Suggestion `json:"suggestions,omitempty"` + Actions []Action `json:"actions,omitempty"` + PolicyReady *PolicyReady `json:"policy_ready,omitempty"` + InstallRequired *InstallRequired `json:"install_required,omitempty"` + TxReady *TxReady `json:"tx_ready,omitempty"` + Transactions []Transaction `json:"transactions,omitempty"` + Tokens *TokenSearchResult `json:"tokens,omitempty"` } // Transaction represents an unsigned transaction returned by an MCP tool @@ -209,3 +210,25 @@ type ToolAction struct { Params map[string]any `json:"params,omitempty"` AutoExecute bool `json:"auto_execute"` } + +// TokenSearchResult contains tokens returned by the find_token MCP tool. +type TokenSearchResult struct { + Tokens []Token `json:"tokens"` +} + +// Token represents a cryptocurrency token with its on-chain deployments. +type Token struct { + ID string `json:"id"` + Name string `json:"name"` + Symbol string `json:"symbol"` + MarketCapRank int `json:"market_cap_rank"` + Logo string `json:"logo"` + Deployments []TokenDeployment `json:"deployments"` +} + +// TokenDeployment represents a token's deployment on a specific chain. +type TokenDeployment struct { + Chain string `json:"chain"` + ContractAddress string `json:"contract_address"` + Decimals int `json:"decimals"` +} From 446c20d67bc232ac1655c6a000ead3d8d74f60b6 Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Mon, 23 Feb 2026 06:29:35 +0800 Subject: [PATCH 4/7] fix find_token structured data extraction from MCP results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues prevented the tokens field from appearing in API responses: 1. CallTool discarded text on IsError — when the MCP tool set IsError: true, CallTool returned a Go error and the text content was lost. Now returns a ToolError that carries the text, so executeTool can still pass it to trackToolResult for structured extraction. 2. trackToolResult assumed pure JSON — MCP tools may return multiple text content blocks (joined with \n) or mix descriptive text with JSON. The direct json.Unmarshal failed silently. Now uses extractTokens() which tries direct unmarshal first, then scans for JSON objects in the text using json.Decoder (which handles trailing content). Also adds diagnostic logging when parsing fails so we can see the actual MCP result text in logs. Co-Authored-By: Claude Opus 4.6 --- internal/mcp/client.go | 18 ++++++++++-- internal/service/agent/agent.go | 46 ++++++++++++++++++++++++++++-- internal/service/agent/executor.go | 11 +++++++ 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 8e54991..81274cd 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -44,6 +44,17 @@ func (e *jsonRPCError) Error() string { // MCP-specific types +// ToolError is returned when an MCP tool sets IsError: true. +// It carries the tool's text content so callers can still parse structured data from it. +type ToolError struct { + ToolName string + Text string +} + +func (e *ToolError) Error() string { + return fmt.Sprintf("mcp tool %s error: %s", e.ToolName, e.Text) +} + // MCPTool represents a tool definition from the MCP server. type MCPTool struct { Name string `json:"name"` @@ -316,10 +327,11 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments json.RawMe if callResult.IsError { c.logger.WithFields(logrus.Fields{ - "mcp_tool": name, - "mcp_error": text, + "mcp_tool": name, + "mcp_error": text, }).Error("mcp tool returned error") - return "", fmt.Errorf("mcp tool error: %s", text) + // Return the text with a ToolError so callers can still access the content. + return text, &ToolError{ToolName: name, Text: text} } c.logger.WithFields(logrus.Fields{ diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 6c08b51..09a6bda 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -272,9 +272,11 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub // Track find_token results for structured passthrough if tc.Name == "find_token" { - var tokenResult TokenSearchResult - if err := json.Unmarshal([]byte(result), &tokenResult); err == nil && len(tokenResult.Tokens) > 0 { - tokens = &tokenResult + if parsed := extractTokens(result); parsed != nil { + tokens = parsed + s.logger.WithField("token_count", len(parsed.Tokens)).Info("tokens extracted from find_token result") + } else { + s.logger.WithField("result_preview", truncateResult(result, 200)).Warn("find_token result could not be parsed as token data") } } } @@ -784,6 +786,44 @@ func (s *AgentService) emitTextResponse(ctx context.Context, convID uuid.UUID, t eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: *assistantMsg}} } +// extractTokens tries to parse a TokenSearchResult from an MCP tool result. +// MCP text content may not be pure JSON (e.g., multiple text blocks joined with \n, +// or descriptive text surrounding JSON), so we try multiple strategies. +func extractTokens(result string) *TokenSearchResult { + // Strategy 1: direct unmarshal (pure JSON) + var direct TokenSearchResult + if err := json.Unmarshal([]byte(result), &direct); err == nil && len(direct.Tokens) > 0 { + return &direct + } + + // Strategy 2: the result may contain non-JSON text around the JSON object. + // Scan for the first '{' and try to decode from there. json.Decoder + // stops after the first complete JSON value, ignoring trailing text. + for i := strings.IndexByte(result, '{'); i >= 0 && i < len(result); { + var candidate TokenSearchResult + dec := json.NewDecoder(strings.NewReader(result[i:])) + if err := dec.Decode(&candidate); err == nil && len(candidate.Tokens) > 0 { + return &candidate + } + // Try the next '{' occurrence + next := strings.IndexByte(result[i+1:], '{') + if next < 0 { + break + } + i = i + 1 + next + } + + return nil +} + +// truncateResult returns the first n bytes of a string for log previews. +func truncateResult(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + func (s *AgentService) autoContinueAfterInstall(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, window *conversationWindow, resp *SendMessageResponse) { pendingKey := fmt.Sprintf("pending_build:%s", convID) suggID, err := s.redis.Get(ctx, pendingKey) diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index bb077bb..0686f89 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -3,10 +3,13 @@ package agent import ( "context" "encoding/json" + "errors" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/mcp" ) const suggestionTTL = 1 * time.Hour @@ -36,6 +39,14 @@ func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name s if mcpName == name { result, err := s.mcpProvider.CallTool(ctx, name, input) if err != nil { + // ToolError carries the text content — return it so + // trackToolResult can still extract structured data + // and Claude can narrate the error to the user. + var toolErr *mcp.ToolError + if errors.As(err, &toolErr) && toolErr.Text != "" { + s.logger.WithField("tool", name).Warn("mcp tool returned isError with content") + return toolErr.Text, nil + } s.logger.WithError(err).WithField("tool", name).Warn("mcp tool call failed") return jsonError("mcp tool error: " + err.Error()), nil } From 03e2e16de3f02a6640045916b72017394673ec08 Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Tue, 24 Feb 2026 03:51:49 +0800 Subject: [PATCH 5/7] add MCP skill discovery and on-demand loading Integrate MCP resources protocol (resources/list, resources/read) to discover and load skill guides from the MCP server. Skills are markdown documents at skills/{slug}.md that provide detailed workflow instructions. The skill list is injected into the system prompt so the LLM knows what's available, but skill content is only loaded on-demand via the new get_skill tool when relevant to the user's request. This keeps the context window lean as the skill library grows. - MCP client: add ListSkills, ReadSkill, SkillSummary with TTL caching - Agent: extend MCPToolProvider interface with skill methods - Agent: inject skill summary into system prompt after tool descriptions - Tools: add get_skill native tool (only registered when skills exist) - Executor: add get_skill handler delegating to MCP ReadSkill - Main: pre-warm skill cache at startup alongside tools Co-Authored-By: Claude Opus 4.6 --- cmd/server/main.go | 8 ++ internal/mcp/client.go | 205 ++++++++++++++++++++++++++++- internal/service/agent/agent.go | 22 +++- internal/service/agent/executor.go | 26 ++++ internal/service/agent/tools.go | 19 +++ 5 files changed, 273 insertions(+), 7 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 82be637..d2b177a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -92,6 +92,14 @@ func main() { logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") mcpProvider = mcpClient } + + // Pre-warm skill cache (non-fatal) + skills, err := mcpClient.ListSkills(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp skills, continuing without skills") + } else { + logger.WithField("skill_count", len(skills)).Info("mcp skills loaded") + } } } diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 81274cd..93d4e7a 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -102,14 +102,72 @@ func (tc *toolCache) set(tools []MCPTool) { tc.fetchedAt = time.Now() } +// MCP resource types (resources/list, resources/read) + +type resourceEntry struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +type readResourceParams struct { + URI string `json:"uri"` +} + +type readResourceResult struct { + Contents []resourceContent `json:"contents"` +} + +type resourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` +} + +// skillEntry is an MCP skill discovered via resources/list. +type skillEntry struct { + Slug string + Name string + Description string + URI string +} + +// skillCache holds cached skill metadata with a TTL. +type skillCache struct { + mu sync.RWMutex + skills []skillEntry + fetchedAt time.Time + ttl time.Duration +} + +func (sc *skillCache) get() ([]skillEntry, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.skills == nil { + return nil, false + } + fresh := time.Since(sc.fetchedAt) < sc.ttl + return sc.skills, fresh +} + +func (sc *skillCache) set(skills []skillEntry) { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.skills = skills + sc.fetchedAt = time.Now() +} + // Client is an MCP JSON-RPC 2.0 client using Streamable HTTP transport. type Client struct { - serverURL string - httpClient *http.Client - sessionID string - requestID atomic.Int64 - cache toolCache - logger *logrus.Logger + serverURL string + httpClient *http.Client + sessionID string + requestID atomic.Int64 + cache toolCache + skills skillCache + skillContent sync.Map // slug → string (cached skill markdown) + logger *logrus.Logger } // NewClient creates a new MCP client. @@ -120,6 +178,7 @@ func NewClient(serverURL string, cacheTTL time.Duration, logger *logrus.Logger) Timeout: 30 * time.Second, }, cache: toolCache{ttl: cacheTTL}, + skills: skillCache{ttl: cacheTTL}, logger: logger, } } @@ -414,3 +473,137 @@ func (c *Client) ToolDescriptions() string { c.logger.WithField("mcp_desc_len", len(desc)).Debug("mcp ToolDescriptions generated") return desc } + +// --------------------------------------------------------------------------- +// MCP Resources — skill discovery and loading +// --------------------------------------------------------------------------- + +// ListSkills fetches available skills from the MCP server via resources/list. +// Skills are resources with URIs matching "skills/*.md". +func (c *Client) ListSkills(ctx context.Context) ([]skillEntry, error) { + c.logger.Debug("mcp listing skills via resources/list") + + result, err := c.call(ctx, "resources/list", nil) + if err != nil { + if stale, _ := c.skills.get(); stale != nil { + c.logger.WithError(err).Warn("mcp resources/list failed, using stale skill cache") + return stale, nil + } + return nil, fmt.Errorf("list resources: %w", err) + } + + var listResult struct { + Resources []resourceEntry `json:"resources"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal resources: %w", err) + } + + var skills []skillEntry + for _, r := range listResult.Resources { + slug := extractSkillSlug(r.URI) + if slug == "" { + continue + } + skills = append(skills, skillEntry{ + Slug: slug, + Name: r.Name, + Description: r.Description, + URI: r.URI, + }) + } + + slugs := make([]string, len(skills)) + for i, s := range skills { + slugs[i] = s.Slug + } + c.logger.WithFields(logrus.Fields{ + "skill_count": len(skills), + "skill_slugs": slugs, + }).Info("mcp skills discovered") + + c.skills.set(skills) + return skills, nil +} + +// extractSkillSlug converts a resource URI like "skills/dca-setup.md" to "dca-setup". +// Returns "" if the URI doesn't match the skill pattern. +func extractSkillSlug(uri string) string { + if !strings.HasPrefix(uri, "skills/") || !strings.HasSuffix(uri, ".md") { + return "" + } + return strings.TrimSuffix(strings.TrimPrefix(uri, "skills/"), ".md") +} + +// ReadSkill fetches the content of a specific skill by slug. +func (c *Client) ReadSkill(ctx context.Context, slug string) (string, error) { + // Check in-memory content cache first + if cached, ok := c.skillContent.Load(slug); ok { + return cached.(string), nil + } + + uri := "skills/" + slug + ".md" + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "uri": uri, + }).Debug("mcp reading skill via resources/read") + + result, err := c.call(ctx, "resources/read", readResourceParams{URI: uri}) + if err != nil { + return "", fmt.Errorf("read skill %s: %w", slug, err) + } + + var readResult readResourceResult + if err := json.Unmarshal(result, &readResult); err != nil { + return "", fmt.Errorf("unmarshal skill content: %w", err) + } + + if len(readResult.Contents) == 0 { + return "", fmt.Errorf("skill %s: empty content", slug) + } + + text := readResult.Contents[0].Text + c.skillContent.Store(slug, text) + + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "content_len": len(text), + }).Info("mcp skill loaded") + + return text, nil +} + +// SkillSummary returns a formatted list of available skills for injection into the system prompt. +// Returns "" if no skills are available. Triggers a background refresh if cache is stale. +func (c *Client) SkillSummary(ctx context.Context) string { + skills, fresh := c.skills.get() + + if !fresh && skills != nil { + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListSkills(refreshCtx) + }() + } + + if len(skills) == 0 { + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Available Skills\n\n") + b.WriteString("You have access to specialized skill guides that provide detailed instructions for specific workflows. ") + b.WriteString("Use the `get_skill` tool to load a skill's full instructions when it is relevant to the user's request.\n\n") + b.WriteString("**IMPORTANT**: Only load skills that are directly relevant to what the user is asking. Do not load all skills.\n\n") + for _, s := range skills { + b.WriteString("- **") + b.WriteString(s.Slug) + b.WriteString("**") + if s.Description != "" { + b.WriteString(": ") + b.WriteString(s.Description) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 09a6bda..6f89382 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -29,12 +29,14 @@ type PluginSkillsProvider interface { GetSkills(ctx context.Context) []PluginSkill } -// MCPToolProvider provides tools discovered from an MCP server. +// MCPToolProvider provides tools and skills discovered from an MCP server. type MCPToolProvider interface { GetAnthropicTools(ctx context.Context) []ai.Tool ToolNames() []string CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) ToolDescriptions() string + SkillSummary(ctx context.Context) string + ReadSkill(ctx context.Context, slug string) (string, error) } type SwapTxBuilder interface { @@ -170,6 +172,11 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub if mcpDesc != "" { basePrompt += mcpDesc } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt") + basePrompt += skillSummary + } } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, @@ -199,6 +206,11 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub s.logger.Warn("mcp provider active but no tools returned") } tools = append(tools, mcpTools...) + + // Add get_skill tool if skills are available + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } } var toolResp *ToolResponse @@ -370,6 +382,11 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI if mcpDesc != "" { basePrompt += mcpDesc } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt (stream)") + basePrompt += skillSummary + } } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, @@ -389,6 +406,9 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI if len(mcpTools) > 0 { tools = append(tools, mcpTools...) } + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } } var toolResp *ToolResponse diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index 0686f89..8bf2785 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -30,6 +30,8 @@ func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name s return s.execCreateSuggestion(ctx, input) case "update_memory": return s.execUpdateMemory(ctx, input, req) + case "get_skill": + return s.execGetSkill(ctx, input) case "set_vault": return s.execSetVault(ctx, convID, input, req) default: @@ -102,6 +104,30 @@ func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input return string(result), nil } +// execGetSkill loads a skill's full instructions from the MCP server. +func (s *AgentService) execGetSkill(ctx context.Context, input json.RawMessage) (string, error) { + var params struct { + SkillName string `json:"skill_name"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + if params.SkillName == "" { + return jsonError("skill_name is required"), nil + } + if s.mcpProvider == nil { + return jsonError("skills not available"), nil + } + + content, err := s.mcpProvider.ReadSkill(ctx, params.SkillName) + if err != nil { + s.logger.WithError(err).WithField("skill", params.SkillName).Warn("failed to load skill") + return jsonError("failed to load skill: " + err.Error()), nil + } + + return content, nil +} + // truncateKey returns the first 12 chars of a key for logging. func truncateKey(key string) string { if len(key) <= 12 { diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index f5c132f..3db6bce 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -130,6 +130,25 @@ var SetVaultTool = ai.Tool{ }, } +// GetSkillTool loads a specific skill's full instructions on demand. +// Added to the tool list dynamically only when skills are available from MCP. +var GetSkillTool = ai.Tool{ + Name: "get_skill", + Description: "Load the full instructions for a specific skill. " + + "Use this when you identify a skill from the Available Skills list that is relevant to the user's request. " + + "Only load skills that are directly needed — do not speculatively load skills.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "skill_name": map[string]any{ + "type": "string", + "description": "The slug name of the skill to load (as listed in Available Skills).", + }, + }, + "required": []string{"skill_name"}, + }, +} + // agentTools returns all granular tools for the decision loop. func agentTools() []ai.Tool { return []ai.Tool{ From ece74112129a9a52e959aafa8705235950ab5833 Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Tue, 24 Feb 2026 04:06:30 +0800 Subject: [PATCH 6/7] fix skill URI parsing for skill:// scheme The MCP server returns skill resources with URIs like "skill://vultisig/evm-contract-call.md" but extractSkillSlug was filtering for the "skills/" prefix, discarding all entries. - extractSkillSlug now extracts the last path segment before .md, handling any URI scheme (skill://, skills/, etc.) - ReadSkill now looks up the full URI from the skill cache instead of constructing it, so it works with any URI format Co-Authored-By: Claude Opus 4.6 --- internal/mcp/client.go | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 93d4e7a..9f3873d 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -526,13 +526,21 @@ func (c *Client) ListSkills(ctx context.Context) ([]skillEntry, error) { return skills, nil } -// extractSkillSlug converts a resource URI like "skills/dca-setup.md" to "dca-setup". -// Returns "" if the URI doesn't match the skill pattern. +// extractSkillSlug extracts a slug from a skill resource URI. +// Handles various URI formats: +// - "skill://vultisig/evm-contract-call.md" → "evm-contract-call" +// - "skills/evm-contract-call.md" → "evm-contract-call" +// +// Returns "" if the URI doesn't end in .md. func extractSkillSlug(uri string) string { - if !strings.HasPrefix(uri, "skills/") || !strings.HasSuffix(uri, ".md") { + if !strings.HasSuffix(uri, ".md") { return "" } - return strings.TrimSuffix(strings.TrimPrefix(uri, "skills/"), ".md") + base := strings.TrimSuffix(uri, ".md") + if idx := strings.LastIndex(base, "/"); idx >= 0 { + return base[idx+1:] + } + return base } // ReadSkill fetches the content of a specific skill by slug. @@ -542,7 +550,12 @@ func (c *Client) ReadSkill(ctx context.Context, slug string) (string, error) { return cached.(string), nil } - uri := "skills/" + slug + ".md" + // Look up the full URI from the skill cache + uri := c.skillURI(slug) + if uri == "" { + return "", fmt.Errorf("skill %q not found in skill list", slug) + } + c.logger.WithFields(logrus.Fields{ "skill": slug, "uri": uri, @@ -573,6 +586,17 @@ func (c *Client) ReadSkill(ctx context.Context, slug string) (string, error) { return text, nil } +// skillURI looks up the full resource URI for a skill slug from the cache. +func (c *Client) skillURI(slug string) string { + skills, _ := c.skills.get() + for _, s := range skills { + if s.Slug == slug { + return s.URI + } + } + return "" +} + // SkillSummary returns a formatted list of available skills for injection into the system prompt. // Returns "" if no skills are available. Triggers a background refresh if cache is stale. func (c *Client) SkillSummary(ctx context.Context) string { From ed20297877acbe4667276c2b59082ef9f207ee5f Mon Sep 17 00:00:00 2001 From: Raghav Sood Date: Tue, 24 Feb 2026 14:26:28 +0800 Subject: [PATCH 7/7] convert AI client from Anthropic format to OpenAI Chat Completions format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit OpenRouter's universal format is OpenAI Chat Completions — it translates that to whatever the underlying model needs. Non-Anthropic models were rejecting our Anthropic Messages API format with "Invalid Anthropic Messages API request". - Rewrite client.go types and wire format (ToolCall, AssistantMessage, ToolMessage, ToolChoice with custom MarshalJSON) - URL: /messages → /chat/completions - System prompt: moved from request field to system message - Tool definitions: InputSchema wrapped as function parameters on wire - Response: parsed from choices[0].message with convenience fields - Streaming: OpenAI SSE format (data: chunks, data: [DONE]) - Update agent.go ProcessMessage/ProcessMessageStream for new types - Rename GetAnthropicTools → GetTools in MCP client and interface Co-Authored-By: Claude Opus 4.6 --- internal/ai/client.go | 427 ++++++++++++++++++----------- internal/mcp/client.go | 6 +- internal/service/agent/agent.go | 150 ++++------ internal/service/agent/memory.go | 6 +- internal/service/agent/starters.go | 9 +- 5 files changed, 324 insertions(+), 274 deletions(-) diff --git a/internal/ai/client.go b/internal/ai/client.go index 2471318..486e84b 100644 --- a/internal/ai/client.go +++ b/internal/ai/client.go @@ -18,7 +18,7 @@ const ( baseRetryDelay = 1 * time.Second ) -// Client is an OpenRouter-compatible AI API client. +// Client is an OpenRouter-compatible AI API client using OpenAI Chat Completions format. type Client struct { apiKey string model string @@ -30,32 +30,42 @@ type Client struct { // Message represents a simple conversation message with string content. type Message struct { - Role string `json:"role"` // "user" or "assistant" + Role string `json:"role"` // "user", "assistant", or "system" Content string `json:"content"` } -// AssistantMessage represents an assistant response with tool_use blocks. +// AssistantMessage represents an assistant response with optional tool calls. // Used to replay assistant tool calls in the conversation history. type AssistantMessage struct { - Role string `json:"role"` // "assistant" - Content []ContentBlock `json:"content"` + Role string `json:"role"` // "assistant" + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -// ToolResultMessage represents a user message containing tool results. -type ToolResultMessage struct { - Role string `json:"role"` // "user" - Content []ToolResultBlock `json:"content"` +// ToolMessage represents a single tool result in OpenAI format. +type ToolMessage struct { + Role string `json:"role"` // "tool" + ToolCallID string `json:"tool_call_id"` + Content string `json:"content"` } -// ToolResultBlock is a single tool result in a ToolResultMessage. -type ToolResultBlock struct { - Type string `json:"type"` // "tool_result" - ToolUseID string `json:"tool_use_id"` - Content string `json:"content"` - IsError bool `json:"is_error,omitempty"` +// ToolCall represents a tool invocation from the assistant. +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` // "function" + Function FunctionCall `json:"function"` +} + +// FunctionCall contains the function name and arguments. +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` } // Tool represents a tool that the model can use. +// InputSchema is kept as the internal field name; the client wraps it in the +// OpenAI function format ({"type":"function","function":{...,"parameters":...}}) +// when building the wire request. type Tool struct { Name string `json:"name"` Description string `json:"description"` @@ -64,46 +74,56 @@ type Tool struct { // ToolChoice specifies how the model should use tools. type ToolChoice struct { - Type string `json:"type"` // "auto", "any", or "tool" - Name string `json:"name,omitempty"` // Required when type is "tool" + Type string // "auto", "none", or "function" (specific tool) + Name string // required when Type is "function" } -// Request is the request body for the messages API. -// Messages accepts []any to support Message, AssistantMessage, and ToolResultMessage types. -type Request struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []any `json:"messages"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice *ToolChoice `json:"tool_choice,omitempty"` +// MarshalJSON implements custom marshaling for ToolChoice. +// "auto" and "none" marshal as the bare string; a specific tool marshals as +// {"type":"function","function":{"name":"..."}}. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + switch tc.Type { + case "auto", "none": + return json.Marshal(tc.Type) + default: + return json.Marshal(struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + }{ + Type: "function", + Function: struct { + Name string `json:"name"` + }{Name: tc.Name}, + }) + } } -// Response is the response from the messages API. -type Response struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []ContentBlock `json:"content"` - StopReason string `json:"stop_reason"` - Usage Usage `json:"usage"` +// Request is the request payload for the AI API. +// The client translates this into the OpenAI Chat Completions wire format. +type Request struct { + Model string + MaxTokens int + System string // prepended as a system message + Messages []any // Message, AssistantMessage, ToolMessage + Tools []Tool + ToolChoice *ToolChoice } -// ContentBlock represents a content block in the response. -type ContentBlock struct { - Type string `json:"type"` // "text" or "tool_use" - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input json.RawMessage `json:"input,omitempty"` - - partialInput string +// Response is the parsed response from the AI API. +type Response struct { + Content string + ToolCalls []ToolCall + FinishReason string + Usage Usage } // Usage contains token usage information. type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` } // APIError represents an error from the AI API. @@ -117,6 +137,47 @@ func (e *APIError) Error() string { return fmt.Sprintf("ai: %s: %s", e.Type, e.Message) } +// --- Wire format types (internal) --- + +type openAIFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +type openAITool struct { + Type string `json:"type"` // "function" + Function openAIFunction `json:"function"` +} + +type openAIRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []any `json:"messages"` + Tools []openAITool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type openAIResponse struct { + ID string `json:"id"` + Choices []openAIChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +type openAIChoice struct { + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// --- Client --- + // NewClient creates a new AI client. func NewClient(apiKey, model, baseURL, appName, appURL string) *Client { return &Client{ @@ -131,8 +192,8 @@ func NewClient(apiKey, model, baseURL, appName, appURL string) *Client { } } -// SendMessage sends a message to the model and returns the response. -func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { +// buildWireRequest converts the public Request into the OpenAI wire format. +func (c *Client) buildWireRequest(req *Request, stream bool) ([]byte, error) { if req.Model == "" { req.Model = c.model } @@ -140,7 +201,44 @@ func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, erro req.MaxTokens = defaultMaxTokens } - body, err := json.Marshal(req) + messages := make([]any, 0, len(req.Messages)+1) + if req.System != "" { + messages = append(messages, Message{Role: "system", Content: req.System}) + } + messages = append(messages, req.Messages...) + + var tools []openAITool + for _, t := range req.Tools { + tools = append(tools, openAITool{ + Type: "function", + Function: openAIFunction{ + Name: t.Name, + Description: t.Description, + Parameters: t.InputSchema, + }, + }) + } + + var toolChoice any + if req.ToolChoice != nil { + toolChoice = req.ToolChoice + } + + wireReq := openAIRequest{ + Model: req.Model, + MaxTokens: req.MaxTokens, + Messages: messages, + Tools: tools, + ToolChoice: toolChoice, + Stream: stream, + } + + return json.Marshal(wireReq) +} + +// SendMessage sends a message to the model and returns the response. +func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { + body, err := c.buildWireRequest(req, false) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } @@ -171,7 +269,7 @@ func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, erro } func (c *Client) doRequest(ctx context.Context, body []byte) (*Response, error) { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/messages", bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -218,12 +316,20 @@ func (c *Client) doRequest(ctx context.Context, body []byte) (*Response, error) return nil, &parsed.Error } - var result Response - if err := json.Unmarshal(respBody, &result); err != nil { + var oaiResp openAIResponse + if err := json.Unmarshal(respBody, &oaiResp); err != nil { return nil, fmt.Errorf("unmarshal response: %w", err) } - return &result, nil + result := &Response{Usage: oaiResp.Usage} + if len(oaiResp.Choices) > 0 { + choice := oaiResp.Choices[0] + result.Content = choice.Message.Content + result.ToolCalls = choice.Message.ToolCalls + result.FinishReason = choice.FinishReason + } + + return result, nil } func isRetryable(statusCode int) bool { @@ -234,47 +340,39 @@ func retryDelay(_ int, attempt int) time.Duration { return baseRetryDelay * time.Duration(1< 0 { - result.Usage.OutputTokens = md.Usage.OutputTokens - } + if tc.ID != "" { + existing.ID = tc.ID } - - case StreamEventError: - var errPayload struct { - Error APIError `json:"error"` + if tc.Function.Name != "" { + existing.Function.Name += tc.Function.Name } - if err := json.Unmarshal(data, &errPayload); err == nil { - return nil, &errPayload.Error + if tc.Function.Arguments != "" { + existing.Function.Arguments += tc.Function.Arguments } - return nil, &APIError{Type: "stream_error", Message: string(data)} + + cbDelta.ToolCalls = append(cbDelta.ToolCalls, ToolCallDelta{ + Index: tc.Index, + ID: tc.ID, + Function: FunctionCallDelta{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + if choice.FinishReason != nil { + result.FinishReason = *choice.FinishReason + } + + if chunk.Usage != nil { + result.Usage = *chunk.Usage + } + + if cbDelta.Content != "" || len(cbDelta.ToolCalls) > 0 { + callback(cbDelta) } } @@ -428,12 +521,12 @@ func (c *Client) readSSEStream(body io.Reader, callback StreamCallback) (*Respon return nil, fmt.Errorf("read stream: %w", err) } - for i := 0; i < len(contentBlocks); i++ { - if cb, ok := contentBlocks[i]; ok { - result.Content = append(result.Content, *cb) + // Collect accumulated tool calls in index order + for i := 0; i < len(toolCallsMap); i++ { + if tc, ok := toolCallsMap[i]; ok { + result.ToolCalls = append(result.ToolCalls, *tc) } } return &result, nil } - diff --git a/internal/mcp/client.go b/internal/mcp/client.go index 9f3873d..63adf6a 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -401,15 +401,15 @@ func (c *Client) CallTool(ctx context.Context, name string, arguments json.RawMe return text, nil } -// GetAnthropicTools returns cached MCP tools converted to Anthropic tool format. +// GetTools returns cached MCP tools converted to AI tool format. // If the cache is stale, it attempts a background refresh. -func (c *Client) GetAnthropicTools(ctx context.Context) []ai.Tool { +func (c *Client) GetTools(ctx context.Context) []ai.Tool { tools, fresh := c.cache.get() c.logger.WithFields(logrus.Fields{ "mcp_cache_count": len(tools), "mcp_cache_fresh": fresh, - }).Debug("mcp GetAnthropicTools called") + }).Debug("mcp GetTools called") if !fresh && tools != nil { c.logger.Debug("mcp cache stale, starting background refresh") diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 6f89382..18ffcd2 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -31,7 +31,7 @@ type PluginSkillsProvider interface { // MCPToolProvider provides tools and skills discovered from an MCP server. type MCPToolProvider interface { - GetAnthropicTools(ctx context.Context) []ai.Tool + GetTools(ctx context.Context) []ai.Tool ToolNames() []string CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) ToolDescriptions() string @@ -192,7 +192,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub tools := agentTools() tools = append(tools, s.memoryTools()...) if s.mcpProvider != nil { - mcpTools := s.mcpProvider.GetAnthropicTools(ctx) + mcpTools := s.mcpProvider.GetTools(ctx) if len(mcpTools) > 0 { mcpNames := make([]string, len(mcpTools)) for i, t := range mcpTools { @@ -235,55 +235,47 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) - var assistantText string - var toolCalls []ai.ContentBlock - for _, block := range resp.Content { - switch block.Type { - case "text": - assistantText += block.Text - case "tool_use": - toolCalls = append(toolCalls, block) - } - } + assistantText := resp.Content + toolCalls := resp.ToolCalls - if resp.StopReason == "end_turn" || len(toolCalls) == 0 { + if resp.FinishReason == "stop" || len(toolCalls) == 0 { textContent = assistantText break } - var toolResults []ai.ToolResultBlock + var toolMessages []ai.ToolMessage for _, tc := range toolCalls { - if tc.Name == "respond_to_user" { + if tc.Function.Name == "respond_to_user" { var tr ToolResponse - if err := json.Unmarshal(tc.Input, &tr); err == nil { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { toolResp = &tr } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: `{"ok": true}`, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, }) continue } - result, err := s.executeTool(ctx, convID, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) if err != nil { result = jsonError(err.Error()) } s.logger.WithFields(logrus.Fields{ - "tool": tc.Name, + "tool": tc.Function.Name, "tool_id": tc.ID, }).Debug("tool executed") - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: result, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, }) // Track find_token results for structured passthrough - if tc.Name == "find_token" { + if tc.Function.Name == "find_token" { if parsed := extractTokens(result); parsed != nil { tokens = parsed s.logger.WithField("token_count", len(parsed.Tokens)).Info("tokens extracted from find_token result") @@ -294,13 +286,13 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub } messages = append(messages, ai.AssistantMessage{ - Role: "assistant", - Content: resp.Content, - }) - messages = append(messages, ai.ToolResultMessage{ - Role: "user", - Content: toolResults, + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } if toolResp != nil { break @@ -402,7 +394,7 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI tools := agentTools() tools = append(tools, s.memoryTools()...) if s.mcpProvider != nil { - mcpTools := s.mcpProvider.GetAnthropicTools(ctx) + mcpTools := s.mcpProvider.GetTools(ctx) if len(mcpTools) > 0 { tools = append(tools, mcpTools...) } @@ -426,29 +418,16 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI } extractor := ai.NewResponseFieldExtractor() - callback := func(ev ai.StreamEvent) { - if ev.Type != ai.StreamEventContentBlockDelta { - return + callback := func(delta ai.StreamDelta) { + if delta.Content != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: delta.Content}} } - var delta struct { - Delta struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - PartialJSON string `json:"partial_json,omitempty"` - } `json:"delta"` - } - if err := json.Unmarshal(ev.Data, &delta); err != nil { - return - } - switch delta.Delta.Type { - case "text_delta": - if delta.Delta.Text != "" { - eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: delta.Delta.Text}} - } - case "input_json_delta": - text := extractor.Feed(delta.Delta.PartialJSON) - if text != "" { - eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: text}} + for _, tc := range delta.ToolCalls { + if tc.Function.Arguments != "" { + text := extractor.Feed(tc.Function.Arguments) + if text != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: text}} + } } } } @@ -461,56 +440,48 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) - var assistantText string - var toolCalls []ai.ContentBlock - for _, block := range resp.Content { - switch block.Type { - case "text": - assistantText += block.Text - case "tool_use": - toolCalls = append(toolCalls, block) - } - } + assistantText := resp.Content + toolCalls := resp.ToolCalls - if resp.StopReason == "end_turn" || len(toolCalls) == 0 { + if resp.FinishReason == "stop" || len(toolCalls) == 0 { textContent = assistantText break } - var toolResults []ai.ToolResultBlock + var toolMessages []ai.ToolMessage for _, tc := range toolCalls { - if tc.Name == "respond_to_user" { + if tc.Function.Name == "respond_to_user" { var tr ToolResponse - if err := json.Unmarshal(tc.Input, &tr); err == nil { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { toolResp = &tr } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: `{"ok": true}`, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, }) continue } - result, err := s.executeTool(ctx, convID, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) if err != nil { result = jsonError(err.Error()) } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: result, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, }) } messages = append(messages, ai.AssistantMessage{ - Role: "assistant", - Content: resp.Content, - }) - messages = append(messages, ai.ToolResultMessage{ - Role: "user", - Content: toolResults, + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } if toolResp != nil { break @@ -1213,14 +1184,7 @@ func (s *AgentService) summarizeOldMessages(ctx context.Context, convID uuid.UUI return fmt.Errorf("call ai: %w", err) } - var summaryText string - for _, block := range resp.Content { - if block.Type == "text" { - summaryText = block.Text - break - } - } - + summaryText := resp.Content if summaryText == "" { return fmt.Errorf("empty response from ai") } diff --git a/internal/service/agent/memory.go b/internal/service/agent/memory.go index 53adda0..f2a1ef0 100644 --- a/internal/service/agent/memory.go +++ b/internal/service/agent/memory.go @@ -62,10 +62,10 @@ func (s *AgentService) persistMemoryUpdate(ctx context.Context, publicKey string } func (s *AgentService) extractMemoryUpdate(resp *ai.Response) *updateMemoryInput { - for _, block := range resp.Content { - if block.Type == "tool_use" && block.Name == "update_memory" { + for _, tc := range resp.ToolCalls { + if tc.Function.Name == "update_memory" { var mu updateMemoryInput - if err := json.Unmarshal(block.Input, &mu); err == nil && mu.Content != "" { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &mu); err == nil && mu.Content != "" { return &mu } } diff --git a/internal/service/agent/starters.go b/internal/service/agent/starters.go index bbd252a..afb5479 100644 --- a/internal/service/agent/starters.go +++ b/internal/service/agent/starters.go @@ -81,14 +81,7 @@ func (s *AgentService) GenerateStarters(ctx context.Context, req *GetStartersReq return empty } - var text string - for _, block := range resp.Content { - if block.Type == "text" { - text = block.Text - break - } - } - + text := resp.Content if text == "" { return empty }