diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index e34fc1827e..b0a961be99 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -766,7 +766,7 @@ out: // handleSendPostMessage handles performing the passed HTTP request, reading the // result, unmarshalling it, and delivering the unmarshalled result to the // provided response channel. -func (c *Client) handleSendPostMessage(jReq *jsonRequest) { +func (c *Client) handleSendPostMessage(ctx context.Context, jReq *jsonRequest) { var ( lastErr error backoff time.Duration @@ -782,11 +782,12 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { } tries := 10 +retryloop: for i := 0; i < tries; i++ { var httpReq *http.Request bodyReader := bytes.NewReader(jReq.marshalledJSON) - httpReq, err = http.NewRequest("POST", httpURL, bodyReader) + httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader) if err != nil { jReq.responseChan <- &Response{result: nil, err: err} return @@ -812,6 +813,11 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { break } + // We must observe the contract that shutdown returns ErrClientShutdown. + if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) { + err = ErrClientShutdown + } + // Save the last error for the case where we backoff further, // retry and get an invalid response but no error. If this // happens the saved last error will be used to enrich the error @@ -830,8 +836,13 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { select { case <-time.After(backoff): - case <-c.shutdown: - return + case <-ctx.Done(): + err = ctx.Err() + // maintain our contract: shutdown errors are ErrClientShutdown + if errors.Is(context.Cause(ctx), ErrClientShutdown) { + err = ErrClientShutdown + } + break retryloop } } if err != nil { @@ -891,30 +902,28 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) { // in HTTP POST mode. It uses a buffered channel to serialize output messages // while allowing the sender to continue running asynchronously. It must be run // as a goroutine. -func (c *Client) sendPostHandler() { +func (c *Client) sendPostHandler(ctx context.Context) { out: for { // Send any messages ready for send until the shutdown channel // is closed. select { case jReq := <-c.sendPostChan: - c.handleSendPostMessage(jReq) + c.handleSendPostMessage(ctx, jReq) - case <-c.shutdown: + case <-ctx.Done(): break out } } + err := context.Cause(ctx) // Drain any wait channels before exiting so nothing is left waiting // around to send. cleanup: for { select { case jReq := <-c.sendPostChan: - jReq.responseChan <- &Response{ - result: nil, - err: ErrClientShutdown, - } + jReq.responseChan <- &Response{result: nil, err: err} default: break cleanup @@ -1178,8 +1187,13 @@ func (c *Client) start() { // Start the I/O processing handlers depending on whether the client is // in HTTP POST mode or the default websocket mode. if c.config.HTTPPostMode { + ctx, cancel := context.WithCancelCause(context.Background()) c.wg.Add(1) - go c.sendPostHandler() + go c.sendPostHandler(ctx) + go func() { + <-c.shutdown + cancel(ErrClientShutdown) + }() } else { c.wg.Add(3) go func() { diff --git a/rpcclient/infrastructure_test.go b/rpcclient/infrastructure_test.go index 8416b7ad3c..214348985f 100644 --- a/rpcclient/infrastructure_test.go +++ b/rpcclient/infrastructure_test.go @@ -1,8 +1,12 @@ package rpcclient import ( + "io" + "net" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -108,3 +112,85 @@ func TestParseAddressString(t *testing.T) { }) } } + +// TestHTTPPostShutdownInterruptsPendingRequest ensures that a client operating +// in HTTP POST mode can interrupt an in-flight request during shutdown. +func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + requestAccepted := make(chan struct{}) + serverDone := make(chan struct{}) + + go func() { + defer close(serverDone) + + conn, err := listener.Accept() + if err != nil { + return + } + defer func() { + err := conn.Close() + assert.NoError(t, err) + }() + + close(requestAccepted) + + _, _ = io.Copy(io.Discard, conn) + }() + + t.Cleanup(func() { + err := listener.Close() + require.NoError(t, err) + <-serverDone + }) + + connCfg := &ConnConfig{ + Host: listener.Addr().String(), + User: "user", + Pass: "pass", + DisableTLS: true, + HTTPPostMode: true, + } + + client, err := New(connCfg, nil) + require.NoError(t, err) + t.Cleanup(client.Shutdown) + + future := client.GetBlockCountAsync() + + // Ensure the server sees the request before we initiate shutdown. + select { + case <-requestAccepted: + case <-time.After(2 * time.Second): + t.Fatalf("server did not accept client connection") + } + + // The request should remain pending until shutdown is requested. + select { + case <-future: + t.Fatalf("expected request to remain pending until shutdown") + case <-time.After(100 * time.Millisecond): + } + + client.Shutdown() + + waitDone := make(chan struct{}) + go func() { + client.WaitForShutdown() + close(waitDone) + }() + + // Wait for shutdown to complete before asserting the final error. + select { + case <-waitDone: + case <-time.After(5 * time.Second): + t.Fatalf("client shutdown did not complete") + } + + result, err := future.Receive() + require.Zero(t, result) + require.ErrorContains(t, err, ErrClientShutdown.Error()) +}