From 815f9942d3268cc474a0516f27f21fcfb043718c Mon Sep 17 00:00:00 2001 From: Sam Ruby Date: Fri, 13 Jun 2025 08:58:50 -0400 Subject: [PATCH 1/2] prevent SSE upgrade with HTTP Streaming servers It appears that the current code attempts to "upgrade" to SSE if progress notifications are used. This behavior is incompatible with the TypeScript MCP SDK. --- server/streamable_http.go | 23 ++++++- server/streamable_http_test.go | 119 +++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 3 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index e9a011fb1..2e41b82e4 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -90,6 +90,16 @@ func WithLogger(logger util.Logger) StreamableHTTPOption { } } +// WithDisableSSEUpgrade disables automatic upgrade to SSE when notifications are sent. +// When enabled, responses will always be returned as direct JSON responses, +// making it compatible with HTTP streaming clients like the TypeScript MCP SDK. +// The default is false (SSE upgrade enabled for backward compatibility). +func WithDisableSSEUpgrade(disable bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.disableSSEUpgrade = disable + } +} + // StreamableHTTPServer implements a Streamable-http based MCP server. // It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http @@ -127,6 +137,7 @@ type StreamableHTTPServer struct { sessionIdManager SessionIdManager listenHeartbeatInterval time.Duration logger util.Logger + disableSSEUpgrade bool } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -253,7 +264,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s) // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -363,7 +374,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s) if err := s.server.RegisterSession(r.Context(), session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) return @@ -547,13 +558,15 @@ type streamableHttpSession struct { notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore upgradeToSSE atomic.Bool + server *StreamableHTTPServer // reference to server for configuration access } -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, server *StreamableHTTPServer) *streamableHttpSession { return &streamableHttpSession{ sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), tools: toolStore, + server: server, } } @@ -588,6 +601,10 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { var _ SessionWithTools = (*streamableHttpSession)(nil) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + // Check if SSE upgrade is disabled on the server + if s.server != nil && s.server.disableSSEUpgrade { + return // Don't upgrade to SSE + } s.upgradeToSSE.Store(true) } diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index aad48fc3a..3a3dfe37a 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -775,6 +775,125 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) { }) } +// TestStreamableHTTPServer_DisableSSEUpgrade tests that notifications don't upgrade to SSE +// when WithDisableSSEUpgrade(true) is set, ensuring compatibility with HTTP streaming clients +func TestStreamableHTTPServer_DisableSSEUpgrade(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + mcpServer.AddTool(mcp.Tool{ + Name: "test_tool", + Description: "Test tool that sends notifications", + InputSchema: mcp.ToolInputSchema{Type: "object"}, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Send a notification during tool execution + server := ServerFromContext(ctx) + err := server.SendNotificationToClient(ctx, "test/notification", map[string]any{ + "message": "test notification", + }) + if err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{Type: "text", Text: "tool completed"}, + }, + }, nil + }) + + // Create server with SSE upgrade disabled + server := NewTestStreamableHTTPServer(mcpServer, WithDisableSSEUpgrade(true)) + defer server.Close() + + // Send initialize request + initResp, err := postJSON(server.URL, map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{"name": "test", "version": "1.0.0"}, + }, + }) + if err != nil { + t.Fatalf("Failed to send initialize request: %v", err) + } + defer initResp.Body.Close() + + if initResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for initialize, got %d", initResp.StatusCode) + } + + sessionID := initResp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Fatal("Expected session ID in response header") + } + + // Send tool call request that triggers notification + toolReq, _ := http.NewRequest("POST", server.URL, strings.NewReader(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "test_tool", + "arguments": {} + } + }`)) + toolReq.Header.Set("Content-Type", "application/json") + toolReq.Header.Set("Mcp-Session-Id", sessionID) + + resp, err := http.DefaultClient.Do(toolReq) + if err != nil { + t.Fatalf("Failed to send tool call request: %v", err) + } + defer resp.Body.Close() + + // Should receive JSON response (200 OK), not SSE (202 Accepted) + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Errorf("Expected status 200, got %d. Response: %s", resp.StatusCode, string(bodyBytes)) + } + + // Should be JSON, not SSE + contentType := resp.Header.Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected content-type application/json, got %s", contentType) + } + + // Read and verify the response contains the tool result + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + var response map[string]any + if err := json.Unmarshal(responseBody, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + // Check it's a proper JSON-RPC response + if response["id"].(float64) != 2 { + t.Errorf("Expected id 2, got %v", response["id"]) + } + + if response["jsonrpc"] != "2.0" { + t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"]) + } + + // Verify the tool result is present + result, ok := response["result"] + if !ok { + t.Error("Expected result field in response") + } else { + resultMap := result.(map[string]any) + content := resultMap["content"].([]any) + firstContent := content[0].(map[string]any) + if firstContent["text"] != "tool completed" { + t.Errorf("Expected tool completed text, got %v", firstContent["text"]) + } + } +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) From 38db0f65eb99803c12f7bd281415694ec5b2855a Mon Sep 17 00:00:00 2001 From: Sam Ruby Date: Fri, 13 Jun 2025 09:22:02 -0400 Subject: [PATCH 2/2] Address issue detected by CodeRabbit --- server/streamable_http.go | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index 2e41b82e4..106881f23 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -297,19 +297,23 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } }() - // if there's notifications, upgradedHeader to SSE response - if !upgradedHeader { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) - upgradedHeader = true - } - err := writeSSEEvent(w, nt) - if err != nil { - s.logger.Errorf("Failed to write SSE event: %v", err) - return + // if there's notifications and SSE upgrade is not disabled, upgrade to SSE response + if !s.disableSSEUpgrade { + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } + err := writeSSEEvent(w, nt) + if err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } } + // If SSE upgrade is disabled, notifications are dropped in POST mode + // (they can still be sent via separate GET connection if needed) }() case <-done: return @@ -336,7 +340,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request return } // If client-server communication already upgraded to SSE stream - if session.upgradeToSSE.Load() { + // Double-check that SSE upgrade is not disabled before performing the upgrade + if session.upgradeToSSE.Load() && !s.disableSSEUpgrade { if !upgradedHeader { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive")