From b2c2b50e6d879ad58d1bb468a9f3eea0fa3de802 Mon Sep 17 00:00:00 2001 From: Matt Leon Date: Thu, 30 Oct 2025 10:56:56 +0100 Subject: [PATCH 1/2] rpcclient: support canceling in-flight http requests --- rpcclient/infrastructure.go | 42 +++++++++++----- rpcclient/infrastructure_test.go | 83 ++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 12 deletions(-) diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index e34fc1827e..1637cfa1a1 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -766,7 +766,7 @@ out: // handleSendPostMessage handles performing the passed HTTP request, reading the // result, unmarshalling it, and delivering the unmarshalled result to the // provided response channel. -func (c *Client) handleSendPostMessage(jReq *jsonRequest) { +func (c *Client) handleSendPostMessage(ctx context.Context, jReq *jsonRequest) { var ( lastErr error backoff time.Duration @@ -782,12 +782,17 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { } tries := 10 +retryloop: for i := 0; i < tries; i++ { var httpReq *http.Request bodyReader := bytes.NewReader(jReq.marshalledJSON) - httpReq, err = http.NewRequest("POST", httpURL, bodyReader) + httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader) if err != nil { + // We must observe the contract that shutdown returns ErrClientShutdown. + if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) { + err = ErrClientShutdown + } jReq.responseChan <- &Response{result: nil, err: err} return } @@ -812,6 +817,11 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { break } + // We must observe the contract that shutdown returns ErrClientShutdown. + if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) { + err = ErrClientShutdown + } + // Save the last error for the case where we backoff further, // retry and get an invalid response but no error. If this // happens the saved last error will be used to enrich the error @@ -830,8 +840,13 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { select { case <-time.After(backoff): - case <-c.shutdown: - return + case <-ctx.Done(): + err = ctx.Err() + // maintain our contract: shutdown errors are ErrClientShutdown + if errors.Is(context.Cause(ctx), ErrClientShutdown) { + err = ErrClientShutdown + } + break retryloop } } if err != nil { @@ -891,30 +906,28 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { // in HTTP POST mode. It uses a buffered channel to serialize output messages // while allowing the sender to continue running asynchronously. It must be run // as a goroutine. -func (c *Client) sendPostHandler() { +func (c *Client) sendPostHandler(ctx context.Context) { out: for { // Send any messages ready for send until the shutdown channel // is closed. select { case jReq := <-c.sendPostChan: - c.handleSendPostMessage(jReq) + c.handleSendPostMessage(ctx, jReq) - case <-c.shutdown: + case <-ctx.Done(): break out } } + err := context.Cause(ctx) // Drain any wait channels before exiting so nothing is left waiting // around to send. cleanup: for { select { case jReq := <-c.sendPostChan: - jReq.responseChan <- &Response{ - result: nil, - err: ErrClientShutdown, - } + jReq.responseChan <- &Response{result: nil, err: err} default: break cleanup @@ -1178,8 +1191,13 @@ func (c *Client) start() { // Start the I/O processing handlers depending on whether the client is // in HTTP POST mode or the default websocket mode. if c.config.HTTPPostMode { + ctx, cancel := context.WithCancelCause(context.Background()) c.wg.Add(1) - go c.sendPostHandler() + go c.sendPostHandler(ctx) + go func() { + <-c.shutdown + cancel(ErrClientShutdown) + }() } else { c.wg.Add(3) go func() { diff --git a/rpcclient/infrastructure_test.go b/rpcclient/infrastructure_test.go index 8416b7ad3c..9d88b6338a 100644 --- a/rpcclient/infrastructure_test.go +++ b/rpcclient/infrastructure_test.go @@ -1,8 +1,12 @@ package rpcclient import ( + "io" + "net" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -108,3 +112,82 @@ func TestParseAddressString(t *testing.T) { }) } } + +// TestHTTPPostShutdownInterruptsPendingRequest ensures that a client operating +// in HTTP POST mode can interrupt an in-flight request during shutdown. +func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + requestAccepted := make(chan struct{}) + serverDone := make(chan struct{}) + + go func() { + defer close(serverDone) + + conn, err := listener.Accept() + if err != nil { + return + } + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + close(requestAccepted) + + _, _ = io.Copy(io.Discard, conn) + }() + + t.Cleanup(func() { + err := listener.Close() + require.NoError(t, err) + <-serverDone + }) + + connCfg := &ConnConfig{ + Host: listener.Addr().String(), + User: "user", + Pass: "pass", + DisableTLS: true, + HTTPPostMode: true, + } + + client, err := New(connCfg, nil) + require.NoError(t, err) + t.Cleanup(client.Shutdown) + + future := client.GetBlockCountAsync() + + select { + case <-requestAccepted: + case <-time.After(2 * time.Second): + t.Fatalf("server did not accept client connection") + } + + select { + case <-future: + t.Fatalf("expected request to remain pending until shutdown") + case <-time.After(100 * time.Millisecond): + } + + client.Shutdown() + + waitDone := make(chan struct{}) + go func() { + client.WaitForShutdown() + close(waitDone) + }() + + select { + case <-waitDone: + case <-time.After(5 * time.Second): + t.Fatalf("client shutdown did not complete") + } + + result, err := future.Receive() + require.Zero(t, result) + require.ErrorContains(t, err, ErrClientShutdown.Error()) +} From f87d91434ee7912a128a4606a9401427a70ad96a Mon Sep 17 00:00:00 2001 From: Matt Leon Date: Wed, 21 Jan 2026 12:18:11 -0300 Subject: [PATCH 2/2] rpcclient: refine http post shutdown test Drop unreachable context-canceled mapping after request creation. Clarify the HTTP POST shutdown test flow with brief comments. --- rpcclient/infrastructure.go | 4 ---- rpcclient/infrastructure_test.go | 3 +++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index 1637cfa1a1..b0a961be99 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -789,10 +789,6 @@ retryloop: bodyReader := bytes.NewReader(jReq.marshalledJSON) httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader) if err != nil { - // We must observe the contract that shutdown returns ErrClientShutdown. - if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) { - err = ErrClientShutdown - } jReq.responseChan <- &Response{result: nil, err: err} return } diff --git a/rpcclient/infrastructure_test.go b/rpcclient/infrastructure_test.go index 9d88b6338a..214348985f 100644 --- a/rpcclient/infrastructure_test.go +++ b/rpcclient/infrastructure_test.go @@ -161,12 +161,14 @@ func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) { future := client.GetBlockCountAsync() + // Ensure the server sees the request before we initiate shutdown. select { case <-requestAccepted: case <-time.After(2 * time.Second): t.Fatalf("server did not accept client connection") } + // The request should remain pending until shutdown is requested. select { case <-future: t.Fatalf("expected request to remain pending until shutdown") @@ -181,6 +183,7 @@ func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) { close(waitDone) }() + // Wait for shutdown to complete before asserting the final error. select { case <-waitDone: case <-time.After(5 * time.Second):