Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions rpcclient/infrastructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ out:
// handleSendPostMessage handles performing the passed HTTP request, reading the
// result, unmarshalling it, and delivering the unmarshalled result to the
// provided response channel.
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
func (c *Client) handleSendPostMessage(ctx context.Context, jReq *jsonRequest) {
var (
lastErr error
backoff time.Duration
Expand All @@ -782,11 +782,12 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
}

tries := 10
retryloop:
for i := 0; i < tries; i++ {
var httpReq *http.Request

bodyReader := bytes.NewReader(jReq.marshalledJSON)
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader)
if err != nil {
jReq.responseChan <- &Response{result: nil, err: err}
return
Expand All @@ -812,6 +813,11 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
break
}

// We must observe the contract that shutdown returns ErrClientShutdown.
if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) {
err = ErrClientShutdown
}

// Save the last error for the case where we backoff further,
// retry and get an invalid response but no error. If this
// happens the saved last error will be used to enrich the error
Expand All @@ -830,8 +836,13 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
select {
case <-time.After(backoff):

case <-c.shutdown:
return
case <-ctx.Done():
err = ctx.Err()
// maintain our contract: shutdown errors are ErrClientShutdown
if errors.Is(context.Cause(ctx), ErrClientShutdown) {
err = ErrClientShutdown
}
break retryloop
}
}
if err != nil {
Expand Down Expand Up @@ -891,30 +902,28 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
// in HTTP POST mode. It uses a buffered channel to serialize output messages
// while allowing the sender to continue running asynchronously. It must be run
// as a goroutine.
func (c *Client) sendPostHandler() {
func (c *Client) sendPostHandler(ctx context.Context) {
out:
for {
// Send any messages ready for send until the shutdown channel
// is closed.
select {
case jReq := <-c.sendPostChan:
c.handleSendPostMessage(jReq)
c.handleSendPostMessage(ctx, jReq)

case <-c.shutdown:
case <-ctx.Done():
break out
}
}

err := context.Cause(ctx)
// Drain any wait channels before exiting so nothing is left waiting
// around to send.
cleanup:
for {
select {
case jReq := <-c.sendPostChan:
jReq.responseChan <- &Response{
result: nil,
err: ErrClientShutdown,
}
jReq.responseChan <- &Response{result: nil, err: err}

default:
break cleanup
Expand Down Expand Up @@ -1178,8 +1187,13 @@ func (c *Client) start() {
// Start the I/O processing handlers depending on whether the client is
// in HTTP POST mode or the default websocket mode.
if c.config.HTTPPostMode {
ctx, cancel := context.WithCancelCause(context.Background())
c.wg.Add(1)
go c.sendPostHandler()
go c.sendPostHandler(ctx)
go func() {
<-c.shutdown
cancel(ErrClientShutdown)
}()
} else {
c.wg.Add(3)
go func() {
Expand Down
86 changes: 86 additions & 0 deletions rpcclient/infrastructure_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package rpcclient

import (
"io"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -108,3 +112,85 @@ func TestParseAddressString(t *testing.T) {
})
}
}

// TestHTTPPostShutdownInterruptsPendingRequest ensures that a client operating
// in HTTP POST mode can interrupt an in-flight request during shutdown.
func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) {
t.Parallel()

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

requestAccepted := make(chan struct{})
serverDone := make(chan struct{})

go func() {
defer close(serverDone)

conn, err := listener.Accept()
if err != nil {
return
}
defer func() {
err := conn.Close()
assert.NoError(t, err)
}()

close(requestAccepted)

_, _ = io.Copy(io.Discard, conn)
}()

t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
<-serverDone
})

connCfg := &ConnConfig{
Host: listener.Addr().String(),
User: "user",
Pass: "pass",
DisableTLS: true,
HTTPPostMode: true,
}

client, err := New(connCfg, nil)
require.NoError(t, err)
t.Cleanup(client.Shutdown)

future := client.GetBlockCountAsync()

// Ensure the server sees the request before we initiate shutdown.
select {
case <-requestAccepted:
case <-time.After(2 * time.Second):
t.Fatalf("server did not accept client connection")
}

// The request should remain pending until shutdown is requested.
select {
case <-future:
t.Fatalf("expected request to remain pending until shutdown")
case <-time.After(100 * time.Millisecond):
}

client.Shutdown()

waitDone := make(chan struct{})
go func() {
client.WaitForShutdown()
close(waitDone)
}()

// Wait for shutdown to complete before asserting the final error.
select {
case <-waitDone:
case <-time.After(5 * time.Second):
t.Fatalf("client shutdown did not complete")
}

result, err := future.Receive()
require.Zero(t, result)
require.ErrorContains(t, err, ErrClientShutdown.Error())
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose to add more test coverage of the edge cases:

  • Cancellation while waiting in retry backoff path (ctx.Done() select branch).
  • Cancellation on final retry attempt (i == tries-1).
  • Cancellation during response body read after a successful Do.

Loading