diff --git a/client.go b/client.go index 0f3d485..0419791 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ package mcp import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -128,6 +129,10 @@ type Client struct { // Roots support. rootsProvider RootsProvider // Provider for roots information. rootsMu sync.RWMutex // Mutex for protecting the rootsProvider. + + // Sampling support + samplingEnabled bool + samplingHandler SamplingHandler } // ClientOption client option function @@ -200,6 +205,11 @@ type transportConfig struct { httpReqHandlerOptions []HTTPReqHandlerOption } +// SamplingHandler processes server to client sampling/createMessage requests +type SamplingHandler interface { + CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) +} + // newDefaultTransportConfig creates a default transport configuration. func newDefaultTransportConfig() *transportConfig { return &transportConfig{ @@ -307,6 +317,24 @@ func WithHTTPReqHandlerOption(options ...HTTPReqHandlerOption) ClientOption { } } +// WithSampling enables client-side sampling capability and registers a handler. +func WithSampling(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingEnabled = true + c.samplingHandler = handler + + if c.capabilities == nil { + c.capabilities = make(map[string]interface{}) + } + c.capabilities["sampling"] = map[string]interface{}{} + } +} + +// RegisterSamplingHandler allows changing the handler at runtime. +func (c *Client) RegisterSamplingHandler(handler SamplingHandler) { + c.samplingHandler = handler +} + // GetState returns the current client state. func (c *Client) GetState() State { return c.state @@ -667,3 +695,33 @@ func (c *Client) SendRootsListChangedNotification(ctx context.Context) error { func isZeroStruct(x interface{}) bool { return reflect.ValueOf(x).IsZero() } + +// dispatchReverseRequest is called by transports when the server sends a request to the client +// It handles sampling/createMessage and returns (resultRaw, handled, error) +func (c *Client) dispatchReverseRequest(ctx context.Context, req *JSONRPCRequest) (*json.RawMessage, bool, error) { + switch req.Method { + case MethodSamplingCreateMessage: + if !c.samplingEnabled || c.samplingHandler == nil { + return nil, true, fmt.Errorf("sampling not enabled on client") + } + var params CreateMessageParams + if req.Params != nil { + b, _ := json.Marshal(req.Params) + if err := json.Unmarshal(b, ¶ms); err != nil { + return nil, true, fmt.Errorf("invalid sampling params: %w", err) + } + } + res, err := c.samplingHandler.CreateMessage(ctx, ¶ms) + if err != nil { + return nil, true, err + } + buf, err := json.Marshal(res) + if err != nil { + return nil, true, fmt.Errorf("marshal CreateMessageResult failed: %w", err) + } + raw := json.RawMessage(buf) + return &raw, true, nil + default: + return nil, false, nil + } +} diff --git a/examples/sampling/README.md b/examples/sampling/README.md new file mode 100644 index 0000000..f564ca7 --- /dev/null +++ b/examples/sampling/README.md @@ -0,0 +1,34 @@ +# Sampling Example + +A demonstration of MCP server-to-client sampling functionality with bidirectional communication. + +## Features + +- **Server-to-Client Sampling**: Server requests AI model inference from connected clients +- **Bidirectional Communication**: Full duplex via HTTP sessions and SSE connections +- **Mock AI Handler**: Client-side sampling handler with intelligent prompt processing +- **Session Management**: Proper session routing for multi-client scenarios + +## Quick Start + +**Start the server:** +```bash +cd server +go run main.go +``` +Server will start on `localhost:3002/mcp` + +**Run the client:** +```bash +cd client +go run main.go +``` + +## What it demonstrates + +1. **Sampling-Enabled Server**: Creating an MCP server with sampling capability +2. **Tool Registration**: Tools that trigger server-to-client sampling requests +3. **Client Sampling Handler**: Processing sampling requests and returning AI responses +4. **Bidirectional Flow**: Complete request cycle from client tool call to AI response +5. **Content Parsing**: Handling JSON-serialized message content across protocol boundaries +6. **Session Context**: Using session information to route requests correctly \ No newline at end of file diff --git a/examples/sampling/client/main.go b/examples/sampling/client/main.go new file mode 100644 index 0000000..537a628 --- /dev/null +++ b/examples/sampling/client/main.go @@ -0,0 +1,242 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +// MockSamplingHandler implements a simple sampling handler for demonstration +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, params *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) { + log.Printf("[Client] Received sampling request with %d messages", len(params.Messages)) + + if len(params.Messages) == 0 { + return &mcp.CreateMessageResult{ + Model: "mock-model-v1", + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: []mcp.TextContent{{ + Type: mcp.ContentTypeText, + Text: "No messages provided", + }}, + }, + }, nil + } + + // Get the user's prompt from the first message + userMessage := params.Messages[0] + var userText string + + // Enhanced content parsing + switch content := userMessage.Content.(type) { + case []mcp.TextContent: + if len(content) > 0 { + userText = content[0].Text + } + case mcp.TextContent: + userText = content.Text + case []interface{}: + // Handle JSON-deserialized content + if len(content) > 0 { + if item, ok := content[0].(map[string]interface{}); ok { + if text, exists := item["text"]; exists { + if textStr, ok := text.(string); ok { + userText = textStr + } + } + } + } + if userText == "" { + userText = "Could not parse interface{} content" + } + case []map[string]interface{}: + // Handle another JSON structure + if len(content) > 0 { + if text, exists := content[0]["text"]; exists { + if textStr, ok := text.(string); ok { + userText = textStr + } + } + } + if userText == "" { + userText = "Could not parse map content" + } + default: + log.Printf("[Client] Unknown content type: %T, value: %+v", content, content) + userText = fmt.Sprintf("Unknown content type: %T", content) + } + + log.Printf("[Client] Processing prompt: %q", userText) + + // Simple mock responses based on prompt content + var response string + switch { + case contains(userText, "capital", "France"): + response = "The capital of France is Paris." + case contains(userText, "15.5", "3.2", "calculate"): + response = "15.5 × 3.2 = 49.6. This calculation multiplies 15.5 by 3.2 to get 49.6." + default: + response = fmt.Sprintf("I received your message: %s. This is a mock response from the client-side sampling handler.", userText) + } + + return &mcp.CreateMessageResult{ + Model: "mock-model-v1", + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: []mcp.TextContent{{ + Type: mcp.ContentTypeText, + Text: response, + }}, + }, + }, nil +} + +// Helper function to check if text contains any of the given substrings (case insensitive) +func contains(text string, substrings ...string) bool { + lowerText := strings.ToLower(text) + for _, substr := range substrings { + if strings.Contains(lowerText, strings.ToLower(substr)) { + return true + } + } + return false +} + +func main() { + log.Println("Starting Sampling Demo Client...") + + ctx := context.Background() + + // Initialize client + client, err := initializeClient(ctx) + if err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + defer client.Close() + + // Demonstrate sampling functionality + if err := demonstrateSamplingTools(ctx, client); err != nil { + log.Fatalf("Demo failed: %v", err) + } + + log.Println("✅ Sampling demo completed successfully!") +} + +func initializeClient(ctx context.Context) (*mcp.Client, error) { + log.Println("===== Initialize Client =====") + + serverURL := "http://localhost:3002/mcp" + + // Create client with sampling handler + mcpClient, err := mcp.NewClient( + serverURL, + mcp.Implementation{ + Name: "Sampling-Demo-Client", + Version: "1.0.0", + }, + mcp.WithSampling(&MockSamplingHandler{}), // Add sampling handler + ) + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + initResp, err := mcpClient.Initialize(ctx, &mcp.InitializeRequest{}) + if err != nil { + mcpClient.Close() + return nil, fmt.Errorf("initialization failed: %w", err) + } + + log.Printf("Server: %s %s", initResp.ServerInfo.Name, initResp.ServerInfo.Version) + + // Only print session ID for HTTP clients + if sessionID := mcpClient.GetSessionID(); sessionID != "" { + log.Printf("Session ID: %s", sessionID) + } else { + log.Println("Client Type: stdio (no session)") + } + + return mcpClient, nil +} + +func demonstrateSamplingTools(ctx context.Context, client *mcp.Client) error { + log.Println("===== List Available Tools =====") + + listToolsResp, err := client.ListTools(ctx, &mcp.ListToolsRequest{}) + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + + log.Printf("Found %d tools:", len(listToolsResp.Tools)) + for _, tool := range listToolsResp.Tools { + log.Printf(" • %s: %s", tool.Name, tool.Description) + + // Show the generated schemas + if tool.InputSchema != nil { + inputJSON, _ := json.MarshalIndent(tool.InputSchema, " ", " ") + log.Printf(" Input Schema: %s", string(inputJSON)) + } + if tool.OutputSchema != nil { + outputJSON, _ := json.MarshalIndent(tool.OutputSchema, " ", " ") + log.Printf(" Output Schema: %s", string(outputJSON)) + } + log.Println("") + } + + // Demo: Trigger Sampling Tool + log.Println("===== Demo: Trigger Sampling Tool =====") + log.Println("Calling trigger_sampling tool to demonstrate server→client sampling...") + + samplingResult, err := client.CallTool(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "trigger_sampling", + Arguments: map[string]any{ + "prompt": "What is the capital of France? Please answer in one sentence.", + }, + }, + }) + if err != nil { + return fmt.Errorf("trigger_sampling tool failed: %w", err) + } + + log.Printf("✅ Sampling result received:") + if samplingResult.StructuredContent != nil { + structuredJSON, _ := json.MarshalIndent(samplingResult.StructuredContent, " ", " ") + log.Printf(" Structured Content: %s", string(structuredJSON)) + } + + // Demo another sampling call with a different prompt + log.Println("\n===== Demo 2: Another Sampling Call =====") + log.Println("Calling trigger_sampling with a math question...") + + mathResult, err := client.CallTool(ctx, &mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "trigger_sampling", + Arguments: map[string]any{ + "prompt": "Calculate 15.5 * 3.2 and explain the calculation.", + }, + }, + }) + if err != nil { + return fmt.Errorf("second trigger_sampling call failed: %w", err) + } + + log.Printf("✅ Second sampling result received:") + if mathResult.StructuredContent != nil { + structuredJSON, _ := json.MarshalIndent(mathResult.StructuredContent, " ", " ") + log.Printf(" Structured Content: %s", string(structuredJSON)) + } + + return nil +} diff --git a/examples/sampling/server/main.go b/examples/sampling/server/main.go new file mode 100644 index 0000000..162c0c3 --- /dev/null +++ b/examples/sampling/server/main.go @@ -0,0 +1,135 @@ +// Tencent is pleased to support the open source community by making trpc-mcp-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-mcp-go is licensed under the Apache License Version 2.0. + +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "time" + + mcp "trpc.group/trpc-go/trpc-mcp-go" +) + +type TriggerSamplingInput struct { + Prompt string `json:"prompt" jsonschema:"required,description=User prompt for sampling"` +} +type TriggerSamplingOutput struct { + Model string `json:"model"` + Message string `json:"message"` +} + +func main() { + log.Println("Starting Sampling Demo Server...") + + server := mcp.NewServer( + "Sampling-Demo-Server", + "1.0.0", + mcp.WithServerAddress(":3002"), + mcp.WithServerPath("/mcp"), + mcp.WithSamplingEnabled(true), + ) + + // Register a tool + // When the client calls it, the server triggers sampling/createMessage under the session ctx + tool := mcp.NewTool( + "trigger_sampling", + mcp.WithDescription("Trigger server→client sampling using current session"), + mcp.WithInputStruct[TriggerSamplingInput](), + mcp.WithOutputStruct[TriggerSamplingOutput](), + ) + handler := mcp.NewTypedToolHandler(func(ctx context.Context, req *mcp.CallToolRequest, in TriggerSamplingInput) (TriggerSamplingOutput, error) { + log.Printf("[Server] tool called, will request sampling with prompt: %q", in.Prompt) + + params := &mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + {Role: mcp.RoleUser, Content: []mcp.TextContent{{Type: mcp.ContentTypeText, Text: in.Prompt}}}, + }, + // These build parameters are optional. + MaxTokens: 128, + Temperature: 0.7, + } + + // Use the "ctx of the current request" to initiate sampling (ctx contains the session) + // and you can route it to the correct client. + cres, err := server.RequestSampling(ctx, params) + if err != nil { + return TriggerSamplingOutput{}, fmt.Errorf("RequestSampling failed: %w", err) + } + + // Extract text from the response with better handling + text := extractTextFromContent(cres.SamplingMessage.Content) + + log.Printf("[Server] sampling done. model=%s, text=%q", cres.Model, text) + return TriggerSamplingOutput{ + Model: cres.Model, + Message: text, + }, nil + }) + server.RegisterTool(tool, handler) + + // Graceful exit + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + go func() { + <-stop + log.Println("Shutting down server...") + os.Exit(0) + }() + + log.Println("Server listening on http://localhost:3002/mcp") + if err := server.Start(); err != nil { + log.Fatalf("Server failed: %v", err) + } + + _ = time.Second +} + +// extractTextFromContent extracts text from various content types +func extractTextFromContent(content interface{}) string { + switch c := content.(type) { + case []mcp.TextContent: + if len(c) > 0 { + return c[0].Text + } + return "empty text content array" + case mcp.TextContent: + return c.Text + case string: + return c + case []interface{}: + // Handle JSON-deserialized content + if len(c) > 0 { + if item, ok := c[0].(map[string]interface{}); ok { + if text, exists := item["text"]; exists { + if textStr, ok := text.(string); ok { + return textStr + } + } + } + } + log.Printf("[Server] Empty or invalid []interface{} content: %+v", c) + return "empty interface array" + case []map[string]interface{}: + // Handle another possible JSON structure + if len(c) > 0 { + if text, exists := c[0]["text"]; exists { + if textStr, ok := text.(string); ok { + return textStr + } + } + } + log.Printf("[Server] Empty or invalid []map content: %+v", c) + return "empty map array" + default: + log.Printf("[Server] Unknown content type: %T, value: %+v", content, content) + return fmt.Sprintf("unknown content type: %T", content) + } +} diff --git a/handler.go b/handler.go index 947abcb..c869d16 100644 --- a/handler.go +++ b/handler.go @@ -42,6 +42,9 @@ type mcpHandler struct { // Server reference for notification handling. server serverNotificationDispatcher + + // Whether server enables Sampling capability + serverSamplingEnabled bool } // serverNotificationDispatcher defines the interface for dispatching notifications to handlers. @@ -92,6 +95,11 @@ func newMCPHandler(options ...func(*mcpHandler)) *mcpHandler { h.lifecycleManager.withResourceManager(h.resourceManager) h.lifecycleManager.withPromptManager(h.promptManager) + // Also pass sampling flag if already set by options + if h.serverSamplingEnabled { + h.lifecycleManager.withServerSamplingEnabled(true) + } + return h } @@ -123,6 +131,16 @@ func withPromptManager(manager *promptManager) func(*mcpHandler) { } } +// withServerSamplingEnabled sets whether the server exposes Sampling capability +func withServerSamplingEnabled(enabled bool) func(*mcpHandler) { + return func(h *mcpHandler) { + h.serverSamplingEnabled = enabled + if h.lifecycleManager != nil { + h.lifecycleManager.withServerSamplingEnabled(enabled) + } + } +} + // Definition: request dispatch table type type requestHandlerFunc func(ctx context.Context, req *JSONRPCRequest, session Session) (JSONRPCMessage, error) diff --git a/manager_lifecycle.go b/manager_lifecycle.go index eb7e8c3..d2e5297 100644 --- a/manager_lifecycle.go +++ b/manager_lifecycle.go @@ -46,6 +46,9 @@ type lifecycleManager struct { // Mutex for concurrent access mu sync.RWMutex + + // Whether sampling is enabled + samplingEnabled bool } // newLifecycleManager creates a lifecycle manager @@ -112,6 +115,12 @@ func (m *lifecycleManager) withStatelessMode(isStateless bool) *lifecycleManager return m } +// withServerSamplingEnabled sets the sampling capability flag +func (m *lifecycleManager) withServerSamplingEnabled(enabled bool) *lifecycleManager { + m.samplingEnabled = enabled + return m +} + // updateCapabilities updates the server capability information func (m *lifecycleManager) updateCapabilities() { // Use map as an intermediate variable @@ -141,6 +150,10 @@ func (m *lifecycleManager) updateCapabilities() { capMap["experimental"] = exp } + if m.samplingEnabled { + capMap["sampling"] = map[string]interface{}{} + } + // Update capabilities m.capabilities = capMap } @@ -271,6 +284,11 @@ func convertToServerCapabilities(capMap map[string]interface{}) ServerCapabiliti capabilities.Experimental = expMap } + // Handle sampling capability + if _, ok := capMap["sampling"].(map[string]interface{}); ok { + capabilities.Sampling = &SamplingCapability{} + } + return capabilities } diff --git a/mcp_messages.go b/mcp_messages.go index a9d7175..5a076c1 100644 --- a/mcp_messages.go +++ b/mcp_messages.go @@ -57,6 +57,10 @@ type ServerCapabilities struct { // autocompletion suggestions."} Completions *CompletionsCapability `json:"completions,omitempty"` + // Sampling indicates whether the server supports requesting client-side sampling + // Corresponds to schema: "sampling": {"description": "Present if the server may request client-side sampling."} + Sampling *SamplingCapability `json:"sampling,omitempty"` + // Experimental indicates non-standard experimental capabilities that the server supports // Corresponds to schema: "experimental": {"description": "Experimental, non-standard capabilities // that the server supports."} @@ -198,6 +202,9 @@ const ( // Utilities MethodLoggingSetLevel = "logging/setLevel" MethodPing = "ping" + + // Sampling + MethodSamplingCreateMessage = "sampling/createMessage" ) // Protocol version constants diff --git a/mcp_types.go b/mcp_types.go index 37f14ab..c985dad 100644 --- a/mcp_types.go +++ b/mcp_types.go @@ -343,3 +343,64 @@ func (p *DefaultRootsProvider) GetRoots() []Root { copy(result, p.roots) return result } + +// CreateMessage defines a server to client request to sample an LLM via the client +// The client selects/guards the model and may include human-in-the-loop review +type CreateMessage struct { + Request + CreateMessageParams `json:"params"` +} + +// CreateMessageParms carries the prompt/messages and optional preferences +type CreateMessageParams struct { + // Ordered chat history/messages (roles and multimodal content) + Messages []SamplingMessage `json:"messages"` + + // Optional model selection hints/priorities + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + + // Optional system prompt to steer behavior + SystemPrompt string `json:"systemPrompt,omitempty"` + + // Whether to include MCP context (e.g., this server/all servers) + IncludeContext string `json:"includeContext,omitempty"` + + // Usual generation controls + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + + // Free-form metadata (echoed back if useful to the client) + Metadata any `json:"metadata,omitempty"` +} + +// CreateMessageResult is the client->server response with the sampled message +type CreateMessageResult struct { + Result + SamplingMessage + Model string `json:"model"` + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingMessage represents a chat message for/from the LLM. +type SamplingMessage struct { + Role Role `json:"role"` + // Can be single item or array; we accept interface{} for flexibility. + // If you prefer strict typing, you can switch to []Content and marshal accordingly. + Content any `json:"content"` +} + +// ModelPreferences guides the client when choosing a model. +type ModelPreferences struct { + // Ordered hints; the client SHOULD check in order, then consider priorities. + Hints []ModelHint `json:"hints,omitempty"` + // 0..1 weights for trade-offs (optional) + CostPriority float64 `json:"costPriority,omitempty"` + SpeedPriority float64 `json:"speedPriority,omitempty"` + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` +} + +// ModelHint defines an extensible hint (e.g., substring of a model name). +type ModelHint struct { + Name string `json:"name,omitempty"` +} diff --git a/server.go b/server.go index 216deba..4bdf66f 100644 --- a/server.go +++ b/server.go @@ -85,6 +85,9 @@ type serverConfig struct { // Method name modifier for external customization. methodNameModifier MethodNameModifier + + // Whether to enable the Sampling function + enableSampling bool } // ServerNotificationHandler defines a function that handles notifications on the server side. @@ -190,6 +193,7 @@ func (s *Server) initComponents() { withResourceManager(resourceManager), withPromptManager(promptManager), withServer(s), // Set the server reference for notification handling. + withServerSamplingEnabled(s.config.enableSampling), ) // Collect HTTP handler options. @@ -377,6 +381,13 @@ func WithServerAddress(addr string) ServerOption { } } +// WithSamplingEnabled enables or disables the Sampling capability on the server. +func WithSamplingEnabled(enabled bool) ServerOption { + return func(s *Server) { + s.config.enableSampling = enabled + } +} + // Start starts the server func (s *Server) Start() error { if s.customServer != nil { @@ -721,3 +732,46 @@ func (s *Server) handleServerNotification(ctx context.Context, notification *JSO } return nil } + +// RequestSampling sends a sampling/createMessage request to the client and waits for the result +func (s *Server) RequestSampling(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) { + if s.config.isStateless { + return nil, ErrStatelessMode + } + + // Get session + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, ErrSessionNotFound + } + sessionID := session.GetID() + if sessionID == "" { + return nil, ErrSessionNotFound + } + + // Assemble a JSON-RPC request + requestID := s.requestID.Add(1) + req := &JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: requestID, + Request: Request{ + Method: MethodSamplingCreateMessage, + }, + } + if params != nil { + req.Params = *params + } + + // Send and wait response + raw, err := s.SendRequest(ctx, sessionID, req) + if err != nil { + return nil, fmt.Errorf("failed to send %s: %w", MethodSamplingCreateMessage, err) + } + + // Parse result + var result CreateMessageResult + if err := json.Unmarshal(*raw, &result); err != nil { + return nil, fmt.Errorf("failed to parse CreateMessageResult: %w", err) + } + return &result, nil +} diff --git a/stdio_client.go b/stdio_client.go index d88f31a..118c1cc 100644 --- a/stdio_client.go +++ b/stdio_client.go @@ -46,6 +46,10 @@ type StdioClient struct { // Roots support. rootsProvider RootsProvider // Provider for roots information. rootsMu sync.RWMutex // Mutex for protecting the rootsProvider. + + // Sampling support + samplingEnabled bool + samplingHandler SamplingHandler } // StdioClientOption defines configuration options for StdioClient. @@ -115,6 +119,23 @@ func WithStdioCapabilities(capabilities map[string]interface{}) StdioClientOptio } } +// WithStdioSampling enables sampling capabilities and registers the processor +func WithStdioSampling(handler SamplingHandler) StdioClientOption { + return func(c *StdioClient) { + c.samplingEnabled = true + c.samplingHandler = handler + if c.capabilities == nil { + c.capabilities = make(map[string]interface{}) + } + c.capabilities["sampling"] = map[string]interface{}{} + } +} + +// RegisterStdioSamplingHandler replaces the sampling processor at runtime +func (c *StdioClient) RegisterStdioSamplingHandler(handler SamplingHandler) { + c.samplingHandler = handler +} + // Initialize initializes the client connection func (c *StdioClient) Initialize(ctx context.Context, req *InitializeRequest) (*InitializeResult, error) { if c.initialized.Load() { diff --git a/streamable_client.go b/streamable_client.go index 9aba314..56a2c0f 100644 --- a/streamable_client.go +++ b/streamable_client.go @@ -797,6 +797,26 @@ func (t *streamableHTTPClientTransport) handleIncomingRequest(request *JSONRPCRe switch request.Method { case MethodRootsList: t.handleRootsListRequest(request) + case MethodSamplingCreateMessage: + // Server to Client request, hand over the request to the upper client's dispatchReverseRequest + if t.client != nil { + resRaw, handled, err := t.client.dispatchReverseRequest(context.Background(), request) + if handled { + if err != nil { + t.sendErrorResponse(request, ErrCodeInternal, err.Error()) + return + } + // Construct JSON-RPC response and POST to server + resp := &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: request.ID, + Result: resRaw, + } + t.sendResponseToServer(resp) + return + } + } + t.sendErrorResponse(request, ErrCodeMethodNotFound, "sampling not enabled") default: // Send method not found error. t.sendErrorResponse(request, ErrCodeMethodNotFound, fmt.Sprintf("Method not found: %s", request.Method)) diff --git a/transport_stdio.go b/transport_stdio.go index 5985db4..b6d09c4 100644 --- a/transport_stdio.go +++ b/transport_stdio.go @@ -480,6 +480,26 @@ func (t *stdioClientTransport) handleIncomingRequest(rawMessage json.RawMessage) switch request.Method { case MethodRootsList: t.handleRootsListRequest(&request) + case MethodSamplingCreateMessage: + if t.client != nil { + resRaw, handled, err := t.client.dispatchReverseRequest(context.Background(), &request) + if handled { + if err != nil { + t.sendErrorResponse(&request, ErrCodeInternal, err.Error()) + return + } + resp := &JSONRPCResponse{ + JSONRPC: JSONRPCVersion, + ID: request.ID, + Result: resRaw, + } + if err := t.sendResponse(context.Background(), resp); err != nil { + t.logger.Errorf("Client handleIncomingRequest: Failed to send sampling response: %v", err) + } + return + } + } + t.sendErrorResponse(&request, ErrCodeMethodNotFound, "sampling not enabled") default: t.logger.Warnf("Client handleIncomingRequest: Unknown method: %s", request.Method) // Send method not found error @@ -699,3 +719,33 @@ func (t *stdioClientTransport) isProcessRunning() bool { // On Unix systems, sending signal 0 checks if process exists. return t.process.Process.Signal(syscall.Signal(0)) == nil } + +// dispatchReverseRequest is called by transports when the server sends a request to the client +// It handles sampling/createMessage and returns (resultRaw, handled, error) +func (c *StdioClient) dispatchReverseRequest(ctx context.Context, req *JSONRPCRequest) (*json.RawMessage, bool, error) { + switch req.Method { + case MethodSamplingCreateMessage: + if !c.samplingEnabled || c.samplingHandler == nil { + return nil, true, fmt.Errorf("sampling not enabled on client") + } + var params CreateMessageParams + if req.Params != nil { + b, _ := json.Marshal(req.Params) + if err := json.Unmarshal(b, ¶ms); err != nil { + return nil, true, fmt.Errorf("invalid sampling params: %w", err) + } + } + res, err := c.samplingHandler.CreateMessage(ctx, ¶ms) + if err != nil { + return nil, true, err + } + buf, err := json.Marshal(res) + if err != nil { + return nil, true, fmt.Errorf("marshal CreateMessageResult failed: %w", err) + } + raw := json.RawMessage(buf) + return &raw, true, nil + default: + return nil, false, nil + } +}