diff --git a/README.md b/README.md index 6ddc03e29..f047c3f47 100644 --- a/README.md +++ b/README.md @@ -537,7 +537,7 @@ For examples, see the [`examples/`](examples/) directory. ### Transports -MCP-Go supports stdio, SSE and streamable-HTTP transport layers. +MCP-Go supports stdio, SSE and streamable-HTTP transport layers. For SSE transport, you can use `SetConnectionLostHandler()` to detect and handle HTTP/2 idle timeout disconnections (NO_ERROR) for implementing reconnection logic. ### Session Management diff --git a/client/client.go b/client/client.go index 5e00f2e5c..cda7665ef 100644 --- a/client/client.go +++ b/client/client.go @@ -113,6 +113,17 @@ func (c *Client) OnNotification( c.notifications = append(c.notifications, handler) } +// OnConnectionLost registers a handler function to be called when the connection is lost. +// This is useful for handling HTTP2 idle timeout disconnections that should not be treated as errors. +func (c *Client) OnConnectionLost(handler func(error)) { + type connectionLostSetter interface { + SetConnectionLostHandler(func(error)) + } + if setter, ok := c.transport.(connectionLostSetter); ok { + setter.SetConnectionLostHandler(handler) + } +} + // sendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *Client) sendRequest( diff --git a/client/transport/sse.go b/client/transport/sse.go index 97f78192f..92f1de416 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -34,10 +34,12 @@ type SSE struct { headers map[string]string headerFunc HTTPHeaderFunc - started atomic.Bool - closed atomic.Bool - cancelSSEStream context.CancelFunc - protocolVersion atomic.Value // string + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc + protocolVersion atomic.Value // string + onConnectionLost func(error) + connectionLostMu sync.RWMutex // OAuth support oauthHandler *OAuthHandler @@ -204,6 +206,19 @@ func (c *SSE) readSSE(reader io.ReadCloser) { } break } + // Checking whether the connection was terminated due to NO_ERROR in HTTP2 based on RFC9113 + // Only handle NO_ERROR specially if onConnectionLost handler is set to maintain backward compatibility + if strings.Contains(err.Error(), "NO_ERROR") { + c.connectionLostMu.RLock() + handler := c.onConnectionLost + c.connectionLostMu.RUnlock() + + if handler != nil { + // This is not actually an error - HTTP2 idle timeout disconnection + handler(err) + return + } + } if !c.closed.Load() { fmt.Printf("SSE stream error: %v\n", err) } @@ -294,6 +309,12 @@ func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotifi c.onNotification = handler } +func (c *SSE) SetConnectionLostHandler(handler func(error)) { + c.connectionLostMu.Lock() + defer c.connectionLostMu.Unlock() + c.onConnectionLost = handler +} + // SendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *SSE) SendRequest( diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index f72c8e8c8..ca05180c4 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "errors" + "io" + "strings" "sync" "testing" "time" @@ -15,6 +17,39 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) +// mockReaderWithError is a mock io.ReadCloser that simulates reading some data +// and then returning a specific error +type mockReaderWithError struct { + data []byte + err error + position int + closed bool +} + +func (m *mockReaderWithError) Read(p []byte) (n int, err error) { + if m.closed { + return 0, io.EOF + } + + if m.position >= len(m.data) { + return 0, m.err + } + + n = copy(p, m.data[m.position:]) + m.position += n + + if m.position >= len(m.data) { + return n, m.err + } + + return n, nil +} + +func (m *mockReaderWithError) Close() error { + m.closed = true + return nil +} + // startMockSSEEchoServer starts a test HTTP server that implements // a minimal SSE-based echo server for testing purposes. // It returns the server URL and a function to close the server. @@ -508,6 +543,218 @@ func TestSSE(t *testing.T) { } }) + t.Run("NO_ERROR_WithoutConnectionLostHandler", func(t *testing.T) { + // Test that NO_ERROR without connection lost handler maintains backward compatibility + // When no connection lost handler is set, NO_ERROR should be treated as a regular error + + // Create a mock Reader that simulates NO_ERROR + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("connection closed: NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // DO NOT set connection lost handler to test backward compatibility + + // Capture stderr to verify the error is printed (backward compatible behavior) + // Since we can't easily capture fmt.Printf output in tests, we'll just verify + // that the readSSE method returns without calling any handler + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait for readSSE to complete + time.Sleep(100 * time.Millisecond) + + // The test passes if readSSE completes without panicking or hanging + // In backward compatibility mode, NO_ERROR should be treated as a regular error + t.Log("Backward compatibility test passed: NO_ERROR handled as regular error when no handler is set") + }) + + t.Run("NO_ERROR_ConnectionLost", func(t *testing.T) { + // Test that NO_ERROR in HTTP/2 connection loss is properly handled + // This test verifies that when a connection is lost in a way that produces + // an error message containing "NO_ERROR", the connection lost handler is called + + var connectionLostCalled bool + var connectionLostError error + var mu sync.Mutex + + // Create a mock Reader that simulates connection loss with NO_ERROR + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("http2: stream closed with error code NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // Set connection lost handler + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + connectionLostError = err + }) + + // Directly test the readSSE method with our mock reader that simulates NO_ERROR + go trans.readSSE(mockReader) + + // Wait for connection lost handler to be called + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + t.Fatal("Connection lost handler was not called within timeout for NO_ERROR connection loss") + case <-ticker.C: + mu.Lock() + called := connectionLostCalled + err := connectionLostError + mu.Unlock() + + if called { + if err == nil { + t.Fatal("Expected connection lost error, got nil") + } + + // Verify that the error contains "NO_ERROR" string + if !strings.Contains(err.Error(), "NO_ERROR") { + t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err) + } + + t.Logf("Connection lost handler called with NO_ERROR: %v", err) + return + } + } + } + }) + + t.Run("NO_ERROR_Handling", func(t *testing.T) { + // Test specific NO_ERROR string handling in readSSE method + // This tests the code path at line 209 where NO_ERROR is checked + + // Create a mock Reader that simulates an error containing "NO_ERROR" + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("connection closed: NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + var connectionLostCalled bool + var connectionLostError error + var mu sync.Mutex + + // Set connection lost handler to verify it's called for NO_ERROR + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + connectionLostError = err + }) + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait for connection lost handler to be called + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + t.Fatal("Connection lost handler was not called within timeout for NO_ERROR") + case <-ticker.C: + mu.Lock() + called := connectionLostCalled + err := connectionLostError + mu.Unlock() + + if called { + if err == nil { + t.Fatal("Expected connection lost error with NO_ERROR, got nil") + } + + // Verify that the error contains "NO_ERROR" string + if !strings.Contains(err.Error(), "NO_ERROR") { + t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err) + } + + t.Logf("Successfully handled NO_ERROR: %v", err) + return + } + } + } + }) + + t.Run("RegularError_DoesNotTriggerConnectionLost", func(t *testing.T) { + // Test that regular errors (not containing NO_ERROR) do not trigger connection lost handler + + // Create a mock Reader that simulates a regular error + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("regular connection error"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + var connectionLostCalled bool + var mu sync.Mutex + + // Set connection lost handler - this should NOT be called for regular errors + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + }) + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait and verify connection lost handler is NOT called + time.Sleep(200 * time.Millisecond) + + mu.Lock() + called := connectionLostCalled + mu.Unlock() + + if called { + t.Error("Connection lost handler should not be called for regular errors") + } + }) + } func TestSSEErrors(t *testing.T) {