diff --git a/.gitignore b/.gitignore index 1d4114b..ffe9094 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ vendor/ .env .env.local .env.*.local +.devenv # Config files (use examples) config.json diff --git a/cmd/server/main.go b/cmd/server/main.go index e05fced..d2b177a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -17,6 +17,7 @@ import ( "github.com/vultisig/agent-backend/internal/api" "github.com/vultisig/agent-backend/internal/cache/redis" "github.com/vultisig/agent-backend/internal/config" + "github.com/vultisig/agent-backend/internal/mcp" "github.com/vultisig/agent-backend/internal/service/agent" mcpclient "github.com/vultisig/agent-backend/internal/service/mcp" "github.com/vultisig/agent-backend/internal/service/plugin" @@ -72,6 +73,37 @@ func main() { msgRepo := postgres.NewMessageRepository(db.Pool()) memRepo := postgres.NewMemoryRepository(db.Pool()) + // Initialize MCP JSON-RPC client for tool discovery and vault operations (optional) + var mcpProvider agent.MCPToolProvider + if cfg.MCP.ServerURL != "" { + cacheTTL := time.Duration(cfg.MCP.ToolCacheTTLSec) * time.Second + mcpClient := mcp.NewClient(cfg.MCP.ServerURL, cacheTTL, logger) + + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) + defer mcpCancel() + + if err := mcpClient.Initialize(mcpCtx); err != nil { + logger.WithError(err).Warn("failed to initialize mcp client, continuing without mcp tools") + } else { + tools, err := mcpClient.ListTools(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") + } else { + logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") + mcpProvider = mcpClient + } + + // Pre-warm skill cache (non-fatal) + skills, err := mcpClient.ListSkills(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp skills, continuing without skills") + } else { + logger.WithField("skill_count", len(skills)).Info("mcp skills loaded") + } + } + } + + // Initialize MCP swap tx builder (optional) var swapTxBuilder agent.SwapTxBuilder if cfg.MCP.URL != "" { mcpCl := mcpclient.NewClient(cfg.MCP.URL) @@ -80,7 +112,7 @@ func main() { } // Initialize agent service - agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context) + agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, mcpProvider, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context) // Initialize API server server := api.NewServer( diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ebd3112 --- /dev/null +++ b/flake.lock @@ -0,0 +1,306 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "devenv" + ], + "flake-compat": [ + "devenv", + "flake-compat" + ], + "git-hooks": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760971495, + "narHash": "sha256-IwnNtbNVrlZIHh7h4Wz6VP0Furxg9Hh0ycighvL5cZc=", + "owner": "cachix", + "repo": "cachix", + "rev": "c5bfd933d1033672f51a863c47303fc0e093c2d2", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "git-hooks": "git-hooks", + "nix": "nix", + "nixd": "nixd", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1771419682, + "narHash": "sha256-NAemVgEJeZjGl3+438M4rUL8ms9QdDFMYthU12F70FQ=", + "owner": "cachix", + "repo": "devenv", + "rev": "f77fc4de35c184d9ef9a32d5d7e9033351bcdfdc", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760948891, + "narHash": "sha256-TmWcdiUUaWk8J4lpjzu4gCGxWY6/Ok7mOK4fIFfBuU4=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "864599284fc7c0ba6357ed89ed5e2cd5040f0c04", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-root": { + "locked": { + "lastModified": 1723604017, + "narHash": "sha256-rBtQ8gg+Dn4Sx/s+pvjdq3CB2wQNzx9XGFq/JVGCB6k=", + "owner": "srid", + "repo": "flake-root", + "rev": "b759a56851e10cb13f6b8e5698af7b59c44be26e", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "flake-root", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760663237, + "narHash": "sha256-BflA6U4AM1bzuRMR8QqzPXqh8sWVCNDzOdsxXEguJIc=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-parts": [ + "devenv", + "flake-parts" + ], + "git-hooks-nix": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-23-11": [ + "devenv" + ], + "nixpkgs-regression": [ + "devenv" + ] + }, + "locked": { + "lastModified": 1770395975, + "narHash": "sha256-zg0AEZn8d4rqIIsw5XrkVL5p1y6fBj2L57awfUg+gNA=", + "owner": "cachix", + "repo": "nix", + "rev": "ccb6019ce2bd11f5de5fe4617c0079d8cb1ed057", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "devenv-2.32", + "repo": "nix", + "type": "github" + } + }, + "nixd": { + "inputs": { + "flake-parts": [ + "devenv", + "flake-parts" + ], + "flake-root": "flake-root", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1763964548, + "narHash": "sha256-JTRoaEWvPsVIMFJWeS4G2isPo15wqXY/otsiHPN0zww=", + "owner": "nix-community", + "repo": "nixd", + "rev": "d4bf15e56540422e2acc7bc26b20b0a0934e3f5e", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixd", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1761313199, + "narHash": "sha256-wCIACXbNtXAlwvQUo1Ed++loFALPjYUA3dpcUJiXO44=", + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "d1c30452ebecfc55185ae6d1c983c09da0c274ff", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1771177547, + "narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "nixpkgs": "nixpkgs_2", + "systems": "systems" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "devenv", + "nixd", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734704479, + "narHash": "sha256-MMi74+WckoyEWBRcg/oaGRvXC9BVVxDZNRMpL+72wBI=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "65712f5af67234dad91a5a4baee986a8b62dbf8f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..502e258 --- /dev/null +++ b/flake.nix @@ -0,0 +1,42 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + systems.url = "github:nix-systems/default"; + devenv.url = "github:cachix/devenv"; + }; + + outputs = { self, nixpkgs, devenv, systems, ... } @ inputs: + let + forEachSystem = nixpkgs.lib.genAttrs (import systems); + in + { + devShells = forEachSystem + (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + languages.go = { + enable = true; + }; + + packages = with pkgs; [ + postgresql + go-ethereum + sqlc + redis + ]; + + enterShell = '' + echo "agent-backend shell started!" + ''; + } + ]; + }; + }); + }; +} diff --git a/internal/ai/client.go b/internal/ai/client.go index 2471318..486e84b 100644 --- a/internal/ai/client.go +++ b/internal/ai/client.go @@ -18,7 +18,7 @@ const ( baseRetryDelay = 1 * time.Second ) -// Client is an OpenRouter-compatible AI API client. +// Client is an OpenRouter-compatible AI API client using OpenAI Chat Completions format. type Client struct { apiKey string model string @@ -30,32 +30,42 @@ type Client struct { // Message represents a simple conversation message with string content. type Message struct { - Role string `json:"role"` // "user" or "assistant" + Role string `json:"role"` // "user", "assistant", or "system" Content string `json:"content"` } -// AssistantMessage represents an assistant response with tool_use blocks. +// AssistantMessage represents an assistant response with optional tool calls. // Used to replay assistant tool calls in the conversation history. type AssistantMessage struct { - Role string `json:"role"` // "assistant" - Content []ContentBlock `json:"content"` + Role string `json:"role"` // "assistant" + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -// ToolResultMessage represents a user message containing tool results. -type ToolResultMessage struct { - Role string `json:"role"` // "user" - Content []ToolResultBlock `json:"content"` +// ToolMessage represents a single tool result in OpenAI format. +type ToolMessage struct { + Role string `json:"role"` // "tool" + ToolCallID string `json:"tool_call_id"` + Content string `json:"content"` } -// ToolResultBlock is a single tool result in a ToolResultMessage. -type ToolResultBlock struct { - Type string `json:"type"` // "tool_result" - ToolUseID string `json:"tool_use_id"` - Content string `json:"content"` - IsError bool `json:"is_error,omitempty"` +// ToolCall represents a tool invocation from the assistant. +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` // "function" + Function FunctionCall `json:"function"` +} + +// FunctionCall contains the function name and arguments. +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` } // Tool represents a tool that the model can use. +// InputSchema is kept as the internal field name; the client wraps it in the +// OpenAI function format ({"type":"function","function":{...,"parameters":...}}) +// when building the wire request. type Tool struct { Name string `json:"name"` Description string `json:"description"` @@ -64,46 +74,56 @@ type Tool struct { // ToolChoice specifies how the model should use tools. type ToolChoice struct { - Type string `json:"type"` // "auto", "any", or "tool" - Name string `json:"name,omitempty"` // Required when type is "tool" + Type string // "auto", "none", or "function" (specific tool) + Name string // required when Type is "function" } -// Request is the request body for the messages API. -// Messages accepts []any to support Message, AssistantMessage, and ToolResultMessage types. -type Request struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []any `json:"messages"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice *ToolChoice `json:"tool_choice,omitempty"` +// MarshalJSON implements custom marshaling for ToolChoice. +// "auto" and "none" marshal as the bare string; a specific tool marshals as +// {"type":"function","function":{"name":"..."}}. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + switch tc.Type { + case "auto", "none": + return json.Marshal(tc.Type) + default: + return json.Marshal(struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + }{ + Type: "function", + Function: struct { + Name string `json:"name"` + }{Name: tc.Name}, + }) + } } -// Response is the response from the messages API. -type Response struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []ContentBlock `json:"content"` - StopReason string `json:"stop_reason"` - Usage Usage `json:"usage"` +// Request is the request payload for the AI API. +// The client translates this into the OpenAI Chat Completions wire format. +type Request struct { + Model string + MaxTokens int + System string // prepended as a system message + Messages []any // Message, AssistantMessage, ToolMessage + Tools []Tool + ToolChoice *ToolChoice } -// ContentBlock represents a content block in the response. -type ContentBlock struct { - Type string `json:"type"` // "text" or "tool_use" - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input json.RawMessage `json:"input,omitempty"` - - partialInput string +// Response is the parsed response from the AI API. +type Response struct { + Content string + ToolCalls []ToolCall + FinishReason string + Usage Usage } // Usage contains token usage information. type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` } // APIError represents an error from the AI API. @@ -117,6 +137,47 @@ func (e *APIError) Error() string { return fmt.Sprintf("ai: %s: %s", e.Type, e.Message) } +// --- Wire format types (internal) --- + +type openAIFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +type openAITool struct { + Type string `json:"type"` // "function" + Function openAIFunction `json:"function"` +} + +type openAIRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []any `json:"messages"` + Tools []openAITool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type openAIResponse struct { + ID string `json:"id"` + Choices []openAIChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +type openAIChoice struct { + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// --- Client --- + // NewClient creates a new AI client. func NewClient(apiKey, model, baseURL, appName, appURL string) *Client { return &Client{ @@ -131,8 +192,8 @@ func NewClient(apiKey, model, baseURL, appName, appURL string) *Client { } } -// SendMessage sends a message to the model and returns the response. -func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { +// buildWireRequest converts the public Request into the OpenAI wire format. +func (c *Client) buildWireRequest(req *Request, stream bool) ([]byte, error) { if req.Model == "" { req.Model = c.model } @@ -140,7 +201,44 @@ func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, erro req.MaxTokens = defaultMaxTokens } - body, err := json.Marshal(req) + messages := make([]any, 0, len(req.Messages)+1) + if req.System != "" { + messages = append(messages, Message{Role: "system", Content: req.System}) + } + messages = append(messages, req.Messages...) + + var tools []openAITool + for _, t := range req.Tools { + tools = append(tools, openAITool{ + Type: "function", + Function: openAIFunction{ + Name: t.Name, + Description: t.Description, + Parameters: t.InputSchema, + }, + }) + } + + var toolChoice any + if req.ToolChoice != nil { + toolChoice = req.ToolChoice + } + + wireReq := openAIRequest{ + Model: req.Model, + MaxTokens: req.MaxTokens, + Messages: messages, + Tools: tools, + ToolChoice: toolChoice, + Stream: stream, + } + + return json.Marshal(wireReq) +} + +// SendMessage sends a message to the model and returns the response. +func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { + body, err := c.buildWireRequest(req, false) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } @@ -171,7 +269,7 @@ func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, erro } func (c *Client) doRequest(ctx context.Context, body []byte) (*Response, error) { - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/messages", bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -218,12 +316,20 @@ func (c *Client) doRequest(ctx context.Context, body []byte) (*Response, error) return nil, &parsed.Error } - var result Response - if err := json.Unmarshal(respBody, &result); err != nil { + var oaiResp openAIResponse + if err := json.Unmarshal(respBody, &oaiResp); err != nil { return nil, fmt.Errorf("unmarshal response: %w", err) } - return &result, nil + result := &Response{Usage: oaiResp.Usage} + if len(oaiResp.Choices) > 0 { + choice := oaiResp.Choices[0] + result.Content = choice.Message.Content + result.ToolCalls = choice.Message.ToolCalls + result.FinishReason = choice.FinishReason + } + + return result, nil } func isRetryable(statusCode int) bool { @@ -234,47 +340,39 @@ func retryDelay(_ int, attempt int) time.Duration { return baseRetryDelay * time.Duration(1< 0 { - result.Usage.OutputTokens = md.Usage.OutputTokens - } + if tc.ID != "" { + existing.ID = tc.ID } - - case StreamEventError: - var errPayload struct { - Error APIError `json:"error"` + if tc.Function.Name != "" { + existing.Function.Name += tc.Function.Name } - if err := json.Unmarshal(data, &errPayload); err == nil { - return nil, &errPayload.Error + if tc.Function.Arguments != "" { + existing.Function.Arguments += tc.Function.Arguments } - return nil, &APIError{Type: "stream_error", Message: string(data)} + + cbDelta.ToolCalls = append(cbDelta.ToolCalls, ToolCallDelta{ + Index: tc.Index, + ID: tc.ID, + Function: FunctionCallDelta{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + if choice.FinishReason != nil { + result.FinishReason = *choice.FinishReason + } + + if chunk.Usage != nil { + result.Usage = *chunk.Usage + } + + if cbDelta.Content != "" || len(cbDelta.ToolCalls) > 0 { + callback(cbDelta) } } @@ -428,12 +521,12 @@ func (c *Client) readSSEStream(body io.Reader, callback StreamCallback) (*Respon return nil, fmt.Errorf("read stream: %w", err) } - for i := 0; i < len(contentBlocks); i++ { - if cb, ok := contentBlocks[i]; ok { - result.Content = append(result.Content, *cb) + // Collect accumulated tool calls in index order + for i := 0; i < len(toolCallsMap); i++ { + if tc, ok := toolCallsMap[i]; ok { + result.ToolCalls = append(result.ToolCalls, *tc) } } return &result, nil } - diff --git a/internal/api/conversation.go b/internal/api/conversation.go index 7e28ff8..710a7bc 100644 --- a/internal/api/conversation.go +++ b/internal/api/conversation.go @@ -46,11 +46,6 @@ func (s *Server) CreateConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - conv, err := s.convRepo.Create(c.Request().Context(), req.PublicKey) if err != nil { s.logger.WithError(err).Error("failed to create conversation") @@ -67,11 +62,6 @@ func (s *Server) ListConversations(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - // Default pagination if req.Take <= 0 { req.Take = 20 @@ -109,11 +99,6 @@ func (s *Server) GetConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - conv, err := s.convRepo.GetWithMessages(c.Request().Context(), id, req.PublicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { @@ -143,11 +128,6 @@ func (s *Server) DeleteConversation(c echo.Context) error { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - err = s.convRepo.Archive(c.Request().Context(), id, req.PublicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { diff --git a/internal/config/config.go b/internal/config/config.go index f8c65bf..4ec7143 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -63,8 +63,11 @@ type VerifierConfig struct { URL string `envconfig:"VERIFIER_URL" required:"true"` } +// MCPConfig holds MCP (Model Context Protocol) server configuration. type MCPConfig struct { - URL string `envconfig:"MCP_URL" default:""` + ServerURL string `envconfig:"MCP_SERVER_URL"` + ToolCacheTTLSec int `envconfig:"MCP_TOOL_CACHE_TTL_SECONDS" default:"300"` + URL string `envconfig:"MCP_URL" default:""` } // TODO: Add MetricsConfig for Prometheus metrics when metrics are implemented. diff --git a/internal/mcp/client.go b/internal/mcp/client.go new file mode 100644 index 0000000..63adf6a --- /dev/null +++ b/internal/mcp/client.go @@ -0,0 +1,633 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" +) + +// JSON-RPC 2.0 types + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *jsonRPCError) Error() string { + return fmt.Sprintf("mcp rpc error %d: %s", e.Code, e.Message) +} + +// MCP-specific types + +// ToolError is returned when an MCP tool sets IsError: true. +// It carries the tool's text content so callers can still parse structured data from it. +type ToolError struct { + ToolName string + Text string +} + +func (e *ToolError) Error() string { + return fmt.Sprintf("mcp tool %s error: %s", e.ToolName, e.Text) +} + +// MCPTool represents a tool definition from the MCP server. +type MCPTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema any `json:"inputSchema"` +} + +type callToolParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` +} + +type callToolResult struct { + Content []callToolContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type callToolContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// toolCache holds cached MCP tools with a TTL. +type toolCache struct { + mu sync.RWMutex + tools []MCPTool + fetchedAt time.Time + ttl time.Duration +} + +func (tc *toolCache) get() ([]MCPTool, bool) { + tc.mu.RLock() + defer tc.mu.RUnlock() + if tc.tools == nil { + return nil, false + } + fresh := time.Since(tc.fetchedAt) < tc.ttl + return tc.tools, fresh +} + +func (tc *toolCache) set(tools []MCPTool) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.tools = tools + tc.fetchedAt = time.Now() +} + +// MCP resource types (resources/list, resources/read) + +type resourceEntry struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +type readResourceParams struct { + URI string `json:"uri"` +} + +type readResourceResult struct { + Contents []resourceContent `json:"contents"` +} + +type resourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` +} + +// skillEntry is an MCP skill discovered via resources/list. +type skillEntry struct { + Slug string + Name string + Description string + URI string +} + +// skillCache holds cached skill metadata with a TTL. +type skillCache struct { + mu sync.RWMutex + skills []skillEntry + fetchedAt time.Time + ttl time.Duration +} + +func (sc *skillCache) get() ([]skillEntry, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.skills == nil { + return nil, false + } + fresh := time.Since(sc.fetchedAt) < sc.ttl + return sc.skills, fresh +} + +func (sc *skillCache) set(skills []skillEntry) { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.skills = skills + sc.fetchedAt = time.Now() +} + +// Client is an MCP JSON-RPC 2.0 client using Streamable HTTP transport. +type Client struct { + serverURL string + httpClient *http.Client + sessionID string + requestID atomic.Int64 + cache toolCache + skills skillCache + skillContent sync.Map // slug → string (cached skill markdown) + logger *logrus.Logger +} + +// NewClient creates a new MCP client. +func NewClient(serverURL string, cacheTTL time.Duration, logger *logrus.Logger) *Client { + return &Client{ + serverURL: strings.TrimRight(serverURL, "/"), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + cache: toolCache{ttl: cacheTTL}, + skills: skillCache{ttl: cacheTTL}, + logger: logger, + } +} + +// call performs a JSON-RPC 2.0 call over HTTP. +func (c *Client) call(ctx context.Context, method string, params any) (json.RawMessage, error) { + id := c.requestID.Add(1) + reqBody := jsonRPCRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := c.serverURL + "/mcp" + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_url": url, + "mcp_session": c.sessionID, + }).Debug("mcp request sending") + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + if c.sessionID != "" { + httpReq.Header.Set("Mcp-Session-Id", c.sessionID) + } + + start := time.Now() + resp, err := c.httpClient.Do(httpReq) + elapsed := time.Since(start) + if err != nil { + c.logger.WithError(err).WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_elapsed": elapsed.String(), + }).Error("mcp request failed") + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + // Track session ID from response + if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" { + if c.sessionID != sid { + c.logger.WithFields(logrus.Fields{ + "mcp_session_old": c.sessionID, + "mcp_session_new": sid, + }).Debug("mcp session id updated") + } + c.sessionID = sid + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_status": resp.StatusCode, + "mcp_elapsed": elapsed.String(), + "mcp_response_len": len(respBody), + }).Debug("mcp response received") + + if resp.StatusCode != http.StatusOK { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_status": resp.StatusCode, + "mcp_body": string(respBody), + }).Error("mcp server returned non-200 status") + return nil, fmt.Errorf("mcp server returned status %d: %s", resp.StatusCode, string(respBody)) + } + + var rpcResp jsonRPCResponse + if err := json.Unmarshal(respBody, &rpcResp); err != nil { + c.logger.WithError(err).WithField("mcp_body", string(respBody)).Error("mcp response unmarshal failed") + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + if rpcResp.Error != nil { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_error_code": rpcResp.Error.Code, + "mcp_error_msg": rpcResp.Error.Message, + }).Error("mcp rpc error") + return nil, rpcResp.Error + } + + return rpcResp.Result, nil +} + +// Initialize performs the MCP initialize handshake. +func (c *Client) Initialize(ctx context.Context) error { + c.logger.WithField("mcp_url", c.serverURL).Info("mcp initializing") + + params := map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{}, + "clientInfo": map[string]string{ + "name": "vultisig-agent-backend", + "version": "1.0.0", + }, + } + + result, err := c.call(ctx, "initialize", params) + if err != nil { + return fmt.Errorf("initialize: %w", err) + } + + c.logger.WithField("mcp_init_result", string(result)).Info("mcp initialized successfully") + + // Send initialized notification — best-effort, some servers don't handle it + if _, err := c.call(ctx, "notifications/initialized", nil); err != nil { + c.logger.WithError(err).Debug("mcp notifications/initialized not supported by server (harmless)") + } + + return nil +} + +// ListTools fetches the tool list from the MCP server and updates the cache. +func (c *Client) ListTools(ctx context.Context) ([]MCPTool, error) { + c.logger.Debug("mcp listing tools") + + result, err := c.call(ctx, "tools/list", nil) + if err != nil { + // Return stale cache on error + if stale, _ := c.cache.get(); stale != nil { + c.logger.WithError(err).WithField("stale_count", len(stale)).Warn("mcp list tools failed, using stale cache") + return stale, nil + } + return nil, fmt.Errorf("list tools: %w", err) + } + + var listResult struct { + Tools []MCPTool `json:"tools"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal tools: %w", err) + } + + names := make([]string, len(listResult.Tools)) + for i, t := range listResult.Tools { + names[i] = t.Name + } + c.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(listResult.Tools), + "mcp_tool_names": names, + }).Info("mcp tools discovered") + + c.cache.set(listResult.Tools) + return listResult.Tools, nil +} + +// CallTool invokes a tool on the MCP server. +func (c *Client) CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_arguments": string(arguments), + }).Info("mcp calling tool") + + var args map[string]any + if len(arguments) > 0 { + if err := json.Unmarshal(arguments, &args); err != nil { + return "", fmt.Errorf("unmarshal arguments: %w", err) + } + } + + params := callToolParams{ + Name: name, + Arguments: args, + } + + result, err := c.call(ctx, "tools/call", params) + if err != nil { + c.logger.WithError(err).WithField("mcp_tool", name).Error("mcp tool call failed") + return "", fmt.Errorf("call tool %s: %w", name, err) + } + + var callResult callToolResult + if err := json.Unmarshal(result, &callResult); err != nil { + return "", fmt.Errorf("unmarshal tool result: %w", err) + } + + // Collect text content from result + var texts []string + for _, c := range callResult.Content { + if c.Type == "text" && c.Text != "" { + texts = append(texts, c.Text) + } + } + + text := strings.Join(texts, "\n") + + if callResult.IsError { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_error": text, + }).Error("mcp tool returned error") + // Return the text with a ToolError so callers can still access the content. + return text, &ToolError{ToolName: name, Text: text} + } + + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_result_len": len(text), + }).Info("mcp tool call succeeded") + + return text, nil +} + +// GetTools returns cached MCP tools converted to AI tool format. +// If the cache is stale, it attempts a background refresh. +func (c *Client) GetTools(ctx context.Context) []ai.Tool { + tools, fresh := c.cache.get() + + c.logger.WithFields(logrus.Fields{ + "mcp_cache_count": len(tools), + "mcp_cache_fresh": fresh, + }).Debug("mcp GetTools called") + + if !fresh && tools != nil { + c.logger.Debug("mcp cache stale, starting background refresh") + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListTools(refreshCtx) + }() + } + + if tools == nil { + c.logger.Warn("mcp cache empty, no tools available") + return nil + } + + result := make([]ai.Tool, len(tools)) + for i, t := range tools { + result[i] = ai.Tool{ + Name: t.Name, + Description: t.Description, + InputSchema: t.InputSchema, + } + } + return result +} + +// ToolNames returns the names of all cached MCP tools. +func (c *Client) ToolNames() []string { + tools, _ := c.cache.get() + names := make([]string, len(tools)) + for i, t := range tools { + names[i] = t.Name + } + return names +} + +// ToolDescriptions returns a formatted string describing all MCP tools for the system prompt. +func (c *Client) ToolDescriptions() string { + tools, _ := c.cache.get() + if len(tools) == 0 { + c.logger.Warn("mcp ToolDescriptions called but cache is empty") + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Vultisig Tools\n\n") + b.WriteString("You have the following tools provided by the Vultisig platform. ") + b.WriteString("These are core capabilities you MUST include when listing your available tools or describing what you can do:\n\n") + for _, t := range tools { + b.WriteString("- **") + b.WriteString(t.Name) + b.WriteString("**") + if t.Description != "" { + b.WriteString(": ") + b.WriteString(t.Description) + } + b.WriteString("\n") + } + + desc := b.String() + c.logger.WithField("mcp_desc_len", len(desc)).Debug("mcp ToolDescriptions generated") + return desc +} + +// --------------------------------------------------------------------------- +// MCP Resources — skill discovery and loading +// --------------------------------------------------------------------------- + +// ListSkills fetches available skills from the MCP server via resources/list. +// Skills are resources with URIs matching "skills/*.md". +func (c *Client) ListSkills(ctx context.Context) ([]skillEntry, error) { + c.logger.Debug("mcp listing skills via resources/list") + + result, err := c.call(ctx, "resources/list", nil) + if err != nil { + if stale, _ := c.skills.get(); stale != nil { + c.logger.WithError(err).Warn("mcp resources/list failed, using stale skill cache") + return stale, nil + } + return nil, fmt.Errorf("list resources: %w", err) + } + + var listResult struct { + Resources []resourceEntry `json:"resources"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal resources: %w", err) + } + + var skills []skillEntry + for _, r := range listResult.Resources { + slug := extractSkillSlug(r.URI) + if slug == "" { + continue + } + skills = append(skills, skillEntry{ + Slug: slug, + Name: r.Name, + Description: r.Description, + URI: r.URI, + }) + } + + slugs := make([]string, len(skills)) + for i, s := range skills { + slugs[i] = s.Slug + } + c.logger.WithFields(logrus.Fields{ + "skill_count": len(skills), + "skill_slugs": slugs, + }).Info("mcp skills discovered") + + c.skills.set(skills) + return skills, nil +} + +// extractSkillSlug extracts a slug from a skill resource URI. +// Handles various URI formats: +// - "skill://vultisig/evm-contract-call.md" → "evm-contract-call" +// - "skills/evm-contract-call.md" → "evm-contract-call" +// +// Returns "" if the URI doesn't end in .md. +func extractSkillSlug(uri string) string { + if !strings.HasSuffix(uri, ".md") { + return "" + } + base := strings.TrimSuffix(uri, ".md") + if idx := strings.LastIndex(base, "/"); idx >= 0 { + return base[idx+1:] + } + return base +} + +// ReadSkill fetches the content of a specific skill by slug. +func (c *Client) ReadSkill(ctx context.Context, slug string) (string, error) { + // Check in-memory content cache first + if cached, ok := c.skillContent.Load(slug); ok { + return cached.(string), nil + } + + // Look up the full URI from the skill cache + uri := c.skillURI(slug) + if uri == "" { + return "", fmt.Errorf("skill %q not found in skill list", slug) + } + + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "uri": uri, + }).Debug("mcp reading skill via resources/read") + + result, err := c.call(ctx, "resources/read", readResourceParams{URI: uri}) + if err != nil { + return "", fmt.Errorf("read skill %s: %w", slug, err) + } + + var readResult readResourceResult + if err := json.Unmarshal(result, &readResult); err != nil { + return "", fmt.Errorf("unmarshal skill content: %w", err) + } + + if len(readResult.Contents) == 0 { + return "", fmt.Errorf("skill %s: empty content", slug) + } + + text := readResult.Contents[0].Text + c.skillContent.Store(slug, text) + + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "content_len": len(text), + }).Info("mcp skill loaded") + + return text, nil +} + +// skillURI looks up the full resource URI for a skill slug from the cache. +func (c *Client) skillURI(slug string) string { + skills, _ := c.skills.get() + for _, s := range skills { + if s.Slug == slug { + return s.URI + } + } + return "" +} + +// SkillSummary returns a formatted list of available skills for injection into the system prompt. +// Returns "" if no skills are available. Triggers a background refresh if cache is stale. +func (c *Client) SkillSummary(ctx context.Context) string { + skills, fresh := c.skills.get() + + if !fresh && skills != nil { + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListSkills(refreshCtx) + }() + } + + if len(skills) == 0 { + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Available Skills\n\n") + b.WriteString("You have access to specialized skill guides that provide detailed instructions for specific workflows. ") + b.WriteString("Use the `get_skill` tool to load a skill's full instructions when it is relevant to the user's request.\n\n") + b.WriteString("**IMPORTANT**: Only load skills that are directly relevant to what the user is asking. Do not load all skills.\n\n") + for _, s := range skills { + b.WriteString("- **") + b.WriteString(s.Slug) + b.WriteString("**") + if s.Description != "" { + b.WriteString(": ") + b.WriteString(s.Description) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 82b7533..18ffcd2 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -29,6 +29,16 @@ type PluginSkillsProvider interface { GetSkills(ctx context.Context) []PluginSkill } +// MCPToolProvider provides tools and skills discovered from an MCP server. +type MCPToolProvider interface { + GetTools(ctx context.Context) []ai.Tool + ToolNames() []string + CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) + ToolDescriptions() string + SkillSummary(ctx context.Context) string + ReadSkill(ctx context.Context, slug string) (string, error) +} + type SwapTxBuilder interface { BuildSwapTx(ctx context.Context, req SwapTxBuildRequest) (*SwapTxBuildResponse, error) } @@ -65,6 +75,7 @@ type AgentService struct { redis *redis.Client verifier *verifier.Client pluginProvider PluginSkillsProvider + mcpProvider MCPToolProvider swapTxBuilder SwapTxBuilder logger *logrus.Logger summaryModel string @@ -87,6 +98,7 @@ func NewAgentService( redisClient *redis.Client, verifierClient *verifier.Client, pluginProvider PluginSkillsProvider, + mcpProvider MCPToolProvider, swapTxBuilder SwapTxBuilder, logger *logrus.Logger, summaryModel string, @@ -100,6 +112,7 @@ func NewAgentService( redis: redisClient, verifier: verifierClient, pluginProvider: pluginProvider, + mcpProvider: mcpProvider, swapTxBuilder: swapTxBuilder, logger: logger, summaryModel: summaryModel, @@ -113,7 +126,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) defer cancel() - _, err := s.convRepo.GetByID(ctx, convID, publicKey) + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { return nil, fmt.Errorf("conversation not found") @@ -121,6 +134,16 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub return nil, fmt.Errorf("get conversation: %w", err) } + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } else { + s.logger.Debug("mcp session primed with vault info") + } + } + window, err := s.getConversationWindow(ctx, convID, publicKey) if err != nil { return nil, fmt.Errorf("get conversation window: %w", err) @@ -141,6 +164,20 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub fullCtx := s.resolveContext(ctx, convID, req.Context) basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) + } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt") + basePrompt += skillSummary + } + } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, window.summary, @@ -154,9 +191,31 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub tools := agentTools() tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetTools(ctx) + if len(mcpTools) > 0 { + mcpNames := make([]string, len(mcpTools)) + for i, t := range mcpTools { + mcpNames[i] = t.Name + } + s.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(mcpTools), + "mcp_tool_names": mcpNames, + }).Debug("appending mcp tools to ai request") + } else { + s.logger.Warn("mcp provider active but no tools returned") + } + tools = append(tools, mcpTools...) + + // Add get_skill tool if skills are available + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } + } var toolResp *ToolResponse var textContent string + var tokens *TokenSearchResult for i := 0; i < maxLoopIterations; i++ { aiReq := &ai.Request{ @@ -176,62 +235,64 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) - var assistantText string - var toolCalls []ai.ContentBlock - for _, block := range resp.Content { - switch block.Type { - case "text": - assistantText += block.Text - case "tool_use": - toolCalls = append(toolCalls, block) - } - } + assistantText := resp.Content + toolCalls := resp.ToolCalls - if resp.StopReason == "end_turn" || len(toolCalls) == 0 { + if resp.FinishReason == "stop" || len(toolCalls) == 0 { textContent = assistantText break } - var toolResults []ai.ToolResultBlock + var toolMessages []ai.ToolMessage for _, tc := range toolCalls { - if tc.Name == "respond_to_user" { + if tc.Function.Name == "respond_to_user" { var tr ToolResponse - if err := json.Unmarshal(tc.Input, &tr); err == nil { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { toolResp = &tr } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: `{"ok": true}`, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, }) continue } - result, err := s.executeTool(ctx, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) if err != nil { result = jsonError(err.Error()) } s.logger.WithFields(logrus.Fields{ - "tool": tc.Name, + "tool": tc.Function.Name, "tool_id": tc.ID, }).Debug("tool executed") - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: result, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, }) + + // Track find_token results for structured passthrough + if tc.Function.Name == "find_token" { + if parsed := extractTokens(result); parsed != nil { + tokens = parsed + s.logger.WithField("token_count", len(parsed.Tokens)).Info("tokens extracted from find_token result") + } else { + s.logger.WithField("result_preview", truncateResult(result, 200)).Warn("find_token result could not be parsed as token data") + } + } } messages = append(messages, ai.AssistantMessage{ - Role: "assistant", - Content: resp.Content, - }) - messages = append(messages, ai.ToolResultMessage{ - Role: "user", - Content: toolResults, + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } if toolResp != nil { break @@ -239,10 +300,20 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub } if toolResp != nil { - return s.buildLoopResponse(ctx, convID, req, toolResp, window) + resp, err := s.buildLoopResponse(ctx, convID, req, toolResp, window) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil } if textContent != "" { - return s.buildTextResponse(ctx, convID, textContent) + resp, err := s.buildTextResponse(ctx, convID, textContent) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil } return nil, errors.New("no response content from Claude") @@ -258,12 +329,20 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) defer cancel() - _, err := s.convRepo.GetByID(ctx, convID, publicKey) + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) if err != nil { sendErr("conversation not found") return } + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } + } + window, err := s.getConversationWindow(ctx, convID, publicKey) if err != nil { sendErr("failed to load conversation") @@ -287,6 +366,20 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI fullCtx := s.resolveContext(ctx, convID, req.Context) basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) + } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt (stream)") + basePrompt += skillSummary + } + } systemPrompt := BuildSystemPromptWithSummary( basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, window.summary, @@ -300,6 +393,15 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI tools := agentTools() tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetTools(ctx) + if len(mcpTools) > 0 { + tools = append(tools, mcpTools...) + } + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } + } var toolResp *ToolResponse var textContent string @@ -316,29 +418,16 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI } extractor := ai.NewResponseFieldExtractor() - callback := func(ev ai.StreamEvent) { - if ev.Type != ai.StreamEventContentBlockDelta { - return + callback := func(delta ai.StreamDelta) { + if delta.Content != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: delta.Content}} } - var delta struct { - Delta struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - PartialJSON string `json:"partial_json,omitempty"` - } `json:"delta"` - } - if err := json.Unmarshal(ev.Data, &delta); err != nil { - return - } - switch delta.Delta.Type { - case "text_delta": - if delta.Delta.Text != "" { - eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: delta.Delta.Text}} - } - case "input_json_delta": - text := extractor.Feed(delta.Delta.PartialJSON) - if text != "" { - eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: text}} + for _, tc := range delta.ToolCalls { + if tc.Function.Arguments != "" { + text := extractor.Feed(tc.Function.Arguments) + if text != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: text}} + } } } } @@ -351,56 +440,48 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) - var assistantText string - var toolCalls []ai.ContentBlock - for _, block := range resp.Content { - switch block.Type { - case "text": - assistantText += block.Text - case "tool_use": - toolCalls = append(toolCalls, block) - } - } + assistantText := resp.Content + toolCalls := resp.ToolCalls - if resp.StopReason == "end_turn" || len(toolCalls) == 0 { + if resp.FinishReason == "stop" || len(toolCalls) == 0 { textContent = assistantText break } - var toolResults []ai.ToolResultBlock + var toolMessages []ai.ToolMessage for _, tc := range toolCalls { - if tc.Name == "respond_to_user" { + if tc.Function.Name == "respond_to_user" { var tr ToolResponse - if err := json.Unmarshal(tc.Input, &tr); err == nil { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { toolResp = &tr } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: `{"ok": true}`, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, }) continue } - result, err := s.executeTool(ctx, tc.Name, tc.Input, req) + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) if err != nil { result = jsonError(err.Error()) } - toolResults = append(toolResults, ai.ToolResultBlock{ - Type: "tool_result", - ToolUseID: tc.ID, - Content: result, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, }) } messages = append(messages, ai.AssistantMessage{ - Role: "assistant", - Content: resp.Content, - }) - messages = append(messages, ai.ToolResultMessage{ - Role: "user", - Content: toolResults, + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } if toolResp != nil { break @@ -696,6 +777,44 @@ func (s *AgentService) emitTextResponse(ctx context.Context, convID uuid.UUID, t eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: *assistantMsg}} } +// extractTokens tries to parse a TokenSearchResult from an MCP tool result. +// MCP text content may not be pure JSON (e.g., multiple text blocks joined with \n, +// or descriptive text surrounding JSON), so we try multiple strategies. +func extractTokens(result string) *TokenSearchResult { + // Strategy 1: direct unmarshal (pure JSON) + var direct TokenSearchResult + if err := json.Unmarshal([]byte(result), &direct); err == nil && len(direct.Tokens) > 0 { + return &direct + } + + // Strategy 2: the result may contain non-JSON text around the JSON object. + // Scan for the first '{' and try to decode from there. json.Decoder + // stops after the first complete JSON value, ignoring trailing text. + for i := strings.IndexByte(result, '{'); i >= 0 && i < len(result); { + var candidate TokenSearchResult + dec := json.NewDecoder(strings.NewReader(result[i:])) + if err := dec.Decode(&candidate); err == nil && len(candidate.Tokens) > 0 { + return &candidate + } + // Try the next '{' occurrence + next := strings.IndexByte(result[i+1:], '{') + if next < 0 { + break + } + i = i + 1 + next + } + + return nil +} + +// truncateResult returns the first n bytes of a string for log previews. +func truncateResult(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + func (s *AgentService) autoContinueAfterInstall(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, window *conversationWindow, resp *SendMessageResponse) { pendingKey := fmt.Sprintf("pending_build:%s", convID) suggID, err := s.redis.Get(ctx, pendingKey) @@ -1065,14 +1184,7 @@ func (s *AgentService) summarizeOldMessages(ctx context.Context, convID uuid.UUI return fmt.Errorf("call ai: %w", err) } - var summaryText string - for _, block := range resp.Content { - if block.Type == "text" { - summaryText = block.Text - break - } - } - + summaryText := resp.Content if summaryText == "" { return fmt.Errorf("empty response from ai") } diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index 15a094c..8bf2785 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -3,16 +3,20 @@ package agent import ( "context" "encoding/json" + "errors" "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/mcp" ) const suggestionTTL = 1 * time.Hour // executeTool dispatches a tool call to the appropriate handler. // Returns a JSON string result for Claude. Errors are returned as JSON {"error": "..."} so the LLM can communicate them naturally. -func (s *AgentService) executeTool(ctx context.Context, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { +func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { switch name { case "check_plugin_installed": return s.execCheckPluginInstalled(ctx, input, req) @@ -26,11 +30,112 @@ func (s *AgentService) executeTool(ctx context.Context, name string, input json. return s.execCreateSuggestion(ctx, input) case "update_memory": return s.execUpdateMemory(ctx, input, req) + case "get_skill": + return s.execGetSkill(ctx, input) + case "set_vault": + return s.execSetVault(ctx, convID, input, req) default: + // Check MCP tools before returning unknown + if s.mcpProvider != nil { + for _, mcpName := range s.mcpProvider.ToolNames() { + if mcpName == name { + result, err := s.mcpProvider.CallTool(ctx, name, input) + if err != nil { + // ToolError carries the text content — return it so + // trackToolResult can still extract structured data + // and Claude can narrate the error to the user. + var toolErr *mcp.ToolError + if errors.As(err, &toolErr) && toolErr.Text != "" { + s.logger.WithField("tool", name).Warn("mcp tool returned isError with content") + return toolErr.Text, nil + } + s.logger.WithError(err).WithField("tool", name).Warn("mcp tool call failed") + return jsonError("mcp tool error: " + err.Error()), nil + } + return result, nil + } + } + } return jsonError("unknown tool: " + name), nil } } +// execSetVault stores vault keys for this conversation and primes the MCP session. +func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChainCode string `json:"chain_code"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + if params.ECDSAPublicKey == "" || params.EDDSAPublicKey == "" || params.ChainCode == "" { + return jsonError("all three fields are required: ecdsa_public_key, eddsa_public_key, chain_code"), nil + } + + // Store in DB + if err := s.convRepo.UpdateVaultInfo(ctx, convID, req.PublicKey, params.ECDSAPublicKey, params.EDDSAPublicKey, params.ChainCode); err != nil { + s.logger.WithError(err).Error("failed to store vault info") + return jsonError("failed to store vault info: " + err.Error()), nil + } + + s.logger.WithFields(logrus.Fields{ + "conversation_id": convID, + "ecdsa_prefix": truncateKey(params.ECDSAPublicKey), + "eddsa_prefix": truncateKey(params.EDDSAPublicKey), + "chain_code_prefix": truncateKey(params.ChainCode), + }).Info("vault info set for conversation") + + // Prime MCP session with vault info + if s.mcpProvider != nil { + mcpInput, _ := json.Marshal(params) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", mcpInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + // Non-fatal: vault is stored locally, MCP will be primed on next request + } + } + + result, _ := json.Marshal(map[string]any{ + "ok": true, + "message": "Vault set for this conversation.", + }) + return string(result), nil +} + +// execGetSkill loads a skill's full instructions from the MCP server. +func (s *AgentService) execGetSkill(ctx context.Context, input json.RawMessage) (string, error) { + var params struct { + SkillName string `json:"skill_name"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + if params.SkillName == "" { + return jsonError("skill_name is required"), nil + } + if s.mcpProvider == nil { + return jsonError("skills not available"), nil + } + + content, err := s.mcpProvider.ReadSkill(ctx, params.SkillName) + if err != nil { + s.logger.WithError(err).WithField("skill", params.SkillName).Warn("failed to load skill") + return jsonError("failed to load skill: " + err.Error()), nil + } + + return content, nil +} + +// truncateKey returns the first 12 chars of a key for logging. +func truncateKey(key string) string { + if len(key) <= 12 { + return key + } + return key[:12] + "..." +} + // execCheckPluginInstalled checks if a plugin is installed for the user's vault. func (s *AgentService) execCheckPluginInstalled(ctx context.Context, input json.RawMessage, req *SendMessageRequest) (string, error) { var params struct { diff --git a/internal/service/agent/memory.go b/internal/service/agent/memory.go index 53adda0..f2a1ef0 100644 --- a/internal/service/agent/memory.go +++ b/internal/service/agent/memory.go @@ -62,10 +62,10 @@ func (s *AgentService) persistMemoryUpdate(ctx context.Context, publicKey string } func (s *AgentService) extractMemoryUpdate(resp *ai.Response) *updateMemoryInput { - for _, block := range resp.Content { - if block.Type == "tool_use" && block.Name == "update_memory" { + for _, tc := range resp.ToolCalls { + if tc.Function.Name == "update_memory" { var mu updateMemoryInput - if err := json.Unmarshal(block.Input, &mu); err == nil && mu.Content != "" { + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &mu); err == nil && mu.Content != "" { return &mu } } diff --git a/internal/service/agent/prompt.go b/internal/service/agent/prompt.go index b5c694b..d15ca02 100644 --- a/internal/service/agent/prompt.go +++ b/internal/service/agent/prompt.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/vultisig/agent-backend/internal/ai" + "github.com/vultisig/agent-backend/internal/types" ) // ActionsTable is the shared actions reference used by both the system prompt and starters. @@ -512,6 +513,19 @@ func BuildFullPrompt(msgCtx *MessageContext, plugins []PluginSkill) string { return sb.String() } +// BuildVaultInfoSection returns a system prompt section describing the active vault. +func BuildVaultInfoSection(v *types.VaultInfo) string { + if v == nil { + return "" + } + return "\n\n## Active Vault\n\n" + + "This conversation has a vault bound to it. The MCP server has been primed with these keys " + + "so tools like get_eth_balance and get_token_balance will derive addresses automatically.\n\n" + + "- ECDSA public key: `" + v.ECDSAPublicKey + "`\n" + + "- EdDSA public key: `" + v.EDDSAPublicKey + "`\n" + + "- Chaincode: `" + v.ChaincodeHex + "`\n" +} + // UpdateMemoryTool is the tool definition for updating the user's memory document. var UpdateMemoryTool = ai.Tool{ Name: "update_memory", diff --git a/internal/service/agent/starters.go b/internal/service/agent/starters.go index bbd252a..afb5479 100644 --- a/internal/service/agent/starters.go +++ b/internal/service/agent/starters.go @@ -81,14 +81,7 @@ func (s *AgentService) GenerateStarters(ctx context.Context, req *GetStartersReq return empty } - var text string - for _, block := range resp.Content { - if block.Type == "text" { - text = block.Text - break - } - } - + text := resp.Content if text == "" { return empty } diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index ef864a3..3db6bce 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -102,6 +102,53 @@ var CheckBillingStatusTool = ai.Tool{ }, } +// SetVaultTool sets the active vault for this conversation. +var SetVaultTool = ai.Tool{ + Name: "set_vault", + Description: "Set the active vault for this conversation. " + + "Call this when the user provides their vault's public keys and chaincode. " + + "This binds the vault to the conversation so tools like get_eth_balance " + + "and get_token_balance can derive addresses automatically. " + + "The user may switch vaults during a conversation by calling this again.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "ecdsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's ECDSA (secp256k1) public key in hex.", + }, + "eddsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's EdDSA (ed25519) public key in hex.", + }, + "chain_code": map[string]any{ + "type": "string", + "description": "The vault's chaincode in hex, used for key derivation.", + }, + }, + "required": []string{"ecdsa_public_key", "eddsa_public_key", "chain_code"}, + }, +} + +// GetSkillTool loads a specific skill's full instructions on demand. +// Added to the tool list dynamically only when skills are available from MCP. +var GetSkillTool = ai.Tool{ + Name: "get_skill", + Description: "Load the full instructions for a specific skill. " + + "Use this when you identify a skill from the Available Skills list that is relevant to the user's request. " + + "Only load skills that are directly needed — do not speculatively load skills.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "skill_name": map[string]any{ + "type": "string", + "description": "The slug name of the skill to load (as listed in Available Skills).", + }, + }, + "required": []string{"skill_name"}, + }, +} + // agentTools returns all granular tools for the decision loop. func agentTools() []ai.Tool { return []ai.Tool{ @@ -111,5 +158,6 @@ func agentTools() []ai.Tool { GetRecipeSchemaTool, SuggestPolicyTool, CreateSuggestionTool, + SetVaultTool, } } diff --git a/internal/service/agent/types.go b/internal/service/agent/types.go index 0b004ca..b7ed301 100644 --- a/internal/service/agent/types.go +++ b/internal/service/agent/types.go @@ -72,13 +72,27 @@ type ActionResult struct { // SendMessageResponse is the response for sending a message. type SendMessageResponse struct { - Message types.Message `json:"message"` - Title *string `json:"title,omitempty"` - Suggestions []Suggestion `json:"suggestions,omitempty"` - Actions []Action `json:"actions,omitempty"` - PolicyReady *PolicyReady `json:"policy_ready,omitempty"` - InstallRequired *InstallRequired `json:"install_required,omitempty"` - TxReady *TxReady `json:"tx_ready,omitempty"` + Message types.Message `json:"message"` + Title *string `json:"title,omitempty"` + Suggestions []Suggestion `json:"suggestions,omitempty"` + Actions []Action `json:"actions,omitempty"` + PolicyReady *PolicyReady `json:"policy_ready,omitempty"` + InstallRequired *InstallRequired `json:"install_required,omitempty"` + TxReady *TxReady `json:"tx_ready,omitempty"` + Transactions []Transaction `json:"transactions,omitempty"` + Tokens *TokenSearchResult `json:"tokens,omitempty"` +} + +// Transaction represents an unsigned transaction returned by an MCP tool +// that the wallet must sign and broadcast. +type Transaction struct { + Sequence int `json:"sequence"` + Chain string `json:"chain"` + ChainID string `json:"chain_id"` + Action string `json:"action"` + SigningMode string `json:"signing_mode"` + UnsignedTxHex string `json:"unsigned_tx_hex"` + TxDetails map[string]string `json:"tx_details"` } type TxReady struct { @@ -196,3 +210,25 @@ type ToolAction struct { Params map[string]any `json:"params,omitempty"` AutoExecute bool `json:"auto_execute"` } + +// TokenSearchResult contains tokens returned by the find_token MCP tool. +type TokenSearchResult struct { + Tokens []Token `json:"tokens"` +} + +// Token represents a cryptocurrency token with its on-chain deployments. +type Token struct { + ID string `json:"id"` + Name string `json:"name"` + Symbol string `json:"symbol"` + MarketCapRank int `json:"market_cap_rank"` + Logo string `json:"logo"` + Deployments []TokenDeployment `json:"deployments"` +} + +// TokenDeployment represents a token's deployment on a specific chain. +type TokenDeployment struct { + Chain string `json:"chain"` + ContractAddress string `json:"contract_address"` + Decimals int `json:"decimals"` +} diff --git a/internal/storage/postgres/conversation.go b/internal/storage/postgres/conversation.go index 367baf1..b462709 100644 --- a/internal/storage/postgres/conversation.go +++ b/internal/storage/postgres/conversation.go @@ -135,6 +135,46 @@ func (r *ConversationRepository) UpdateSummaryWithCursor(ctx context.Context, id return nil } +// UpdateVaultInfo updates the vault keys for a conversation. +func (r *ConversationRepository) UpdateVaultInfo(ctx context.Context, id uuid.UUID, publicKey string, ecdsa, eddsa, chaincode string) error { + rowsAffected, err := r.q.UpdateVaultInfo(ctx, &queries.UpdateVaultInfoParams{ + EcdsaPublicKey: stringPtrToPgtext(&ecdsa), + EddsaPublicKey: stringPtrToPgtext(&eddsa), + ChaincodeHex: stringPtrToPgtext(&chaincode), + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + return fmt.Errorf("update vault info: %w", err) + } + if rowsAffected == 0 { + return ErrNotFound + } + return nil +} + +// GetVaultInfo returns the vault keys for a conversation. +func (r *ConversationRepository) GetVaultInfo(ctx context.Context, id uuid.UUID, publicKey string) (*types.VaultInfo, error) { + row, err := r.q.GetVaultInfo(ctx, &queries.GetVaultInfoParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("get vault info: %w", err) + } + if !row.EcdsaPublicKey.Valid || !row.EddsaPublicKey.Valid || !row.ChaincodeHex.Valid { + return nil, nil + } + return &types.VaultInfo{ + ECDSAPublicKey: row.EcdsaPublicKey.String, + EDDSAPublicKey: row.EddsaPublicKey.String, + ChaincodeHex: row.ChaincodeHex.String, + }, nil +} + // GetSummaryWithCursor returns the summary and summary_up_to cursor of a conversation. func (r *ConversationRepository) GetSummaryWithCursor(ctx context.Context, id uuid.UUID, publicKey string) (*string, *time.Time, error) { row, err := r.q.GetConversationSummaryWithCursor(ctx, &queries.GetConversationSummaryWithCursorParams{ diff --git a/internal/storage/postgres/convert.go b/internal/storage/postgres/convert.go index 282126c..ee02afa 100644 --- a/internal/storage/postgres/convert.go +++ b/internal/storage/postgres/convert.go @@ -72,7 +72,7 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { if c == nil { return nil } - return &types.Conversation{ + conv := &types.Conversation{ ID: pgtypeToUUID(c.ID), PublicKey: c.PublicKey, Title: pgtextToStringPtr(c.Title), @@ -82,6 +82,14 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { UpdatedAt: pgtimestamptzToTime(c.UpdatedAt), ArchivedAt: pgtimestamptzToTimePtr(c.ArchivedAt), } + if c.EcdsaPublicKey.Valid && c.EddsaPublicKey.Valid && c.ChaincodeHex.Valid { + conv.VaultInfo = &types.VaultInfo{ + ECDSAPublicKey: c.EcdsaPublicKey.String, + EDDSAPublicKey: c.EddsaPublicKey.String, + ChaincodeHex: c.ChaincodeHex.String, + } + } + return conv } func conversationsFromDB(cs []*queries.AgentConversation) []types.Conversation { diff --git a/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql new file mode 100644 index 0000000..d84e2ef --- /dev/null +++ b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql @@ -0,0 +1,11 @@ +-- +goose Up +ALTER TABLE agent_conversations + ADD COLUMN ecdsa_public_key TEXT, + ADD COLUMN eddsa_public_key TEXT, + ADD COLUMN chaincode_hex TEXT; + +-- +goose Down +ALTER TABLE agent_conversations + DROP COLUMN IF EXISTS ecdsa_public_key, + DROP COLUMN IF EXISTS eddsa_public_key, + DROP COLUMN IF EXISTS chaincode_hex; diff --git a/internal/storage/postgres/queries/conversations.sql.go b/internal/storage/postgres/queries/conversations.sql.go index 819c0a1..423a5af 100644 --- a/internal/storage/postgres/queries/conversations.sql.go +++ b/internal/storage/postgres/queries/conversations.sql.go @@ -46,7 +46,7 @@ const createConversation = `-- name: CreateConversation :one INSERT INTO agent_conversations (public_key) VALUES ($1) -RETURNING id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at +RETURNING id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at ` // Conversations table queries @@ -59,6 +59,9 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -67,7 +70,7 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag } const getConversationByID = `-- name: GetConversationByID :one -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE id = $1 AND public_key = $2 AND archived_at IS NULL ` @@ -85,6 +88,9 @@ func (q *Queries) GetConversationByID(ctx context.Context, arg *GetConversationB &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -114,8 +120,31 @@ func (q *Queries) GetConversationSummaryWithCursor(ctx context.Context, arg *Get return &i, err } +const getVaultInfo = `-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL +` + +type GetVaultInfoParams struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +type GetVaultInfoRow struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` +} + +func (q *Queries) GetVaultInfo(ctx context.Context, arg *GetVaultInfoParams) (*GetVaultInfoRow, error) { + row := q.db.QueryRow(ctx, getVaultInfo, arg.ID, arg.PublicKey) + var i GetVaultInfoRow + err := row.Scan(&i.EcdsaPublicKey, &i.EddsaPublicKey, &i.ChaincodeHex) + return &i, err +} + const listConversations = `-- name: ListConversations :many -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE public_key = $1 AND archived_at IS NULL ORDER BY updated_at DESC LIMIT $2 OFFSET $3 @@ -142,6 +171,9 @@ func (q *Queries) ListConversations(ctx context.Context, arg *ListConversationsP &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -201,3 +233,31 @@ func (q *Queries) UpdateConversationTitle(ctx context.Context, arg *UpdateConver } return result.RowsAffected(), nil } + +const updateVaultInfo = `-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL +` + +type UpdateVaultInfoParams struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) UpdateVaultInfo(ctx context.Context, arg *UpdateVaultInfoParams) (int64, error) { + result, err := q.db.Exec(ctx, updateVaultInfo, + arg.EcdsaPublicKey, + arg.EddsaPublicKey, + arg.ChaincodeHex, + arg.ID, + arg.PublicKey, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} diff --git a/internal/storage/postgres/queries/models.go b/internal/storage/postgres/queries/models.go index abfb866..00e4ab1 100644 --- a/internal/storage/postgres/queries/models.go +++ b/internal/storage/postgres/queries/models.go @@ -55,14 +55,17 @@ func (ns NullAgentMessageRole) Value() (driver.Value, error) { } type AgentConversation struct { - ID pgtype.UUID `json:"id"` - PublicKey string `json:"public_key"` - Title pgtype.Text `json:"title"` - Summary pgtype.Text `json:"summary"` - SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - ArchivedAt pgtype.Timestamptz `json:"archived_at"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` + Title pgtype.Text `json:"title"` + Summary pgtype.Text `json:"summary"` + SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ArchivedAt pgtype.Timestamptz `json:"archived_at"` } type AgentMessage struct { diff --git a/internal/storage/postgres/schema/schema.sql b/internal/storage/postgres/schema/schema.sql index ece039b..e0fadbb 100644 --- a/internal/storage/postgres/schema/schema.sql +++ b/internal/storage/postgres/schema/schema.sql @@ -9,6 +9,9 @@ CREATE TABLE agent_conversations ( title TEXT, summary TEXT, summary_up_to TIMESTAMPTZ, + ecdsa_public_key TEXT, + eddsa_public_key TEXT, + chaincode_hex TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), archived_at TIMESTAMPTZ diff --git a/internal/storage/postgres/sqlc/conversations.sql b/internal/storage/postgres/sqlc/conversations.sql index d6c3c9c..32b9473 100644 --- a/internal/storage/postgres/sqlc/conversations.sql +++ b/internal/storage/postgres/sqlc/conversations.sql @@ -37,3 +37,12 @@ WHERE id = $3 AND public_key = $4; -- name: GetConversationSummaryWithCursor :one SELECT summary, summary_up_to FROM agent_conversations WHERE id = $1 AND public_key = $2; + +-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL; + +-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL; diff --git a/internal/types/conversation.go b/internal/types/conversation.go index 08306dd..cfc62b7 100644 --- a/internal/types/conversation.go +++ b/internal/types/conversation.go @@ -23,11 +23,19 @@ type Conversation struct { Title *string `json:"title"` Summary *string `json:"summary,omitempty"` SummaryUpTo *time.Time `json:"summary_up_to,omitempty"` + VaultInfo *VaultInfo `json:"vault_info,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` ArchivedAt *time.Time `json:"archived_at,omitempty"` } +// VaultInfo holds the cryptographic keys for a vault bound to a conversation. +type VaultInfo struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChaincodeHex string `json:"chaincode_hex"` +} + // Message represents a single message in a conversation. type Message struct { ID uuid.UUID `json:"id"`