From fbade9487f280f94b60a1675712f506ce9074be4 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Tue, 2 Dec 2025 20:07:47 +0800 Subject: [PATCH] feat: support ttheader streaming timeout and unify streaming timeout control --- client/callopt/streamcall/call_options.go | 23 +- .../callopt/streamcall/call_options_test.go | 32 +- client/client.go | 12 + client/client_test.go | 74 +- client/option_stream.go | 19 +- client/stream.go | 143 +++- client/stream_test.go | 390 ++++++++- go.mod | 2 + go.sum | 5 +- internal/client/option.go | 2 + .../utils/contextwatcher/contextwatcher.go | 66 ++ .../contextwatcher/contextwatcher_test.go | 467 +++++++++++ pkg/kerrors/kerrors_test.go | 1 + pkg/kerrors/streaming_errors.go | 4 + pkg/remote/trans/nphttp2/client_conn.go | 13 + pkg/remote/trans/nphttp2/status/status.go | 35 +- .../trans/nphttp2/status/status_test.go | 60 +- pkg/remote/trans/nphttp2/stream.go | 4 + pkg/remote/trans/ttstream/client_handler.go | 20 +- .../trans/ttstream/client_handler_test.go | 49 ++ pkg/remote/trans/ttstream/exception.go | 6 +- .../trans/ttstream/exception_builder.go | 58 ++ pkg/remote/trans/ttstream/exception_test.go | 62 ++ pkg/remote/trans/ttstream/stream.go | 21 +- pkg/remote/trans/ttstream/stream_client.go | 32 +- pkg/remote/trans/ttstream/stream_server.go | 30 +- pkg/remote/trans/ttstream/stream_test.go | 22 - pkg/remote/trans/ttstream/transport_server.go | 28 +- .../trans/ttstream/transport_server_test.go | 65 ++ .../trans/ttstream/transport_timeout_test.go | 738 ++++++++++++++++++ pkg/rpcinfo/interface.go | 2 + pkg/rpcinfo/mocks_test.go | 8 + pkg/rpcinfo/mutable.go | 2 + pkg/rpcinfo/rpcconfig.go | 18 + pkg/rpcinfo/rpcconfig_test.go | 18 + pkg/streaming/streamx.go | 3 + pkg/streaming/types/doc.go | 43 + pkg/streaming/types/timeout.go | 35 + pkg/streaming/types/timeout_test.go | 60 ++ 39 files changed, 2540 insertions(+), 132 deletions(-) create mode 100644 internal/utils/contextwatcher/contextwatcher.go create mode 100644 internal/utils/contextwatcher/contextwatcher_test.go create mode 100644 pkg/remote/trans/ttstream/client_handler_test.go create mode 100644 pkg/remote/trans/ttstream/exception_builder.go create mode 100644 pkg/remote/trans/ttstream/transport_server_test.go create mode 100644 pkg/remote/trans/ttstream/transport_timeout_test.go create mode 100644 pkg/streaming/types/doc.go create mode 100644 pkg/streaming/types/timeout.go create mode 100644 pkg/streaming/types/timeout_test.go diff --git a/client/callopt/streamcall/call_options.go b/client/callopt/streamcall/call_options.go index 4380248f7e..402712248f 100644 --- a/client/callopt/streamcall/call_options.go +++ b/client/callopt/streamcall/call_options.go @@ -48,7 +48,6 @@ func WithTag(key, val string) Option { } // WithRecvTimeout add recv timeout for stream.Recv function. -// NOTICE: ONLY effective for ttheader streaming protocol for now. func WithRecvTimeout(d time.Duration) Option { return Option{f: func(o *callopt.CallOptions, di *strings.Builder) { di.WriteString("WithRecvTimeout(") @@ -58,3 +57,25 @@ func WithRecvTimeout(d time.Duration) Option { o.StreamOptions.RecvTimeout = d }} } + +// WithSendTimeout add send timeout for stream.Send function. +func WithSendTimeout(d time.Duration) Option { + return Option{f: func(o *callopt.CallOptions, di *strings.Builder) { + di.WriteString("WithSendTimeout(") + di.WriteString(d.String()) + di.WriteString(")") + + o.StreamOptions.SendTimeout = d + }} +} + +// WithStreamTimeout add timeout for whole stream. 
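+// The stream timeout bounds the lifetime of the whole stream: once it fires,
+// the stream's context is canceled, which ends the stream and fails
+// in-flight Recv/Send calls.
+//
+// A hedged usage sketch, mirroring how this patch's own tests wire call
+// options in through the context (the surrounding ctx/opts variables are
+// illustrative, not part of the API):
+//
+//	opts := streamcall.GetCallOptions([]streamcall.Option{
+//		streamcall.WithStreamTimeout(10 * time.Second), // whole stream
+//		streamcall.WithRecvTimeout(time.Second),        // each Recv call
+//		streamcall.WithSendTimeout(time.Second),        // each Send call
+//	})
+//	ctx = client.NewCtxWithCallOptions(ctx, opts)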
+func WithStreamTimeout(d time.Duration) Option { + return Option{f: func(o *callopt.CallOptions, di *strings.Builder) { + di.WriteString("WithStreamTimeout(") + di.WriteString(d.String()) + di.WriteString(")") + + o.StreamOptions.StreamTimeout = d + }} +} diff --git a/client/callopt/streamcall/call_options_test.go b/client/callopt/streamcall/call_options_test.go index f7eefa8351..2ff37fa937 100644 --- a/client/callopt/streamcall/call_options_test.go +++ b/client/callopt/streamcall/call_options_test.go @@ -25,11 +25,29 @@ import ( "github.com/cloudwego/kitex/internal/test" ) -func TestWithRecvTimeout(t *testing.T) { - var sb strings.Builder - callOpts := callopt.CallOptions{} - testTimeout := 1 * time.Second - WithRecvTimeout(testTimeout).f(&callOpts, &sb) - test.Assert(t, callOpts.StreamOptions.RecvTimeout == testTimeout) - test.Assert(t, sb.String() == "WithRecvTimeout(1s)") +func Test_streamCallTimeoutCallOptions(t *testing.T) { + t.Run("WithRecvTimeout", func(t *testing.T) { + var sb strings.Builder + callOpts := callopt.CallOptions{} + testTimeout := 1 * time.Second + WithRecvTimeout(testTimeout).f(&callOpts, &sb) + test.Assert(t, callOpts.StreamOptions.RecvTimeout == testTimeout) + test.Assert(t, sb.String() == "WithRecvTimeout(1s)") + }) + t.Run("WithSendTimeout", func(t *testing.T) { + var sb strings.Builder + callOpts := callopt.CallOptions{} + testTimeout := 1 * time.Second + WithSendTimeout(testTimeout).f(&callOpts, &sb) + test.Assert(t, callOpts.StreamOptions.SendTimeout == testTimeout) + test.Assert(t, sb.String() == "WithSendTimeout(1s)") + }) + t.Run("WithStreamTimeout", func(t *testing.T) { + var sb strings.Builder + callOpts := callopt.CallOptions{} + testTimeout := 1 * time.Second + WithStreamTimeout(testTimeout).f(&callOpts, &sb) + test.Assert(t, callOpts.StreamOptions.StreamTimeout == testTimeout) + test.Assert(t, sb.String() == "WithStreamTimeout(1s)") + }) } diff --git a/client/client.go b/client/client.go index a57c54fb5d..f613247b77 100644 --- a/client/client.go +++ b/client/client.go @@ -827,6 +827,12 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf if sopt.RecvTimeout > 0 { cfg.SetStreamRecvTimeout(sopt.RecvTimeout) } + if sopt.SendTimeout > 0 { + cfg.SetStreamSendTimeout(sopt.SendTimeout) + } + if sopt.StreamTimeout > 0 { + cfg.SetStreamTimeout(sopt.StreamTimeout) + } ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri) @@ -838,6 +844,12 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf if callOpts.StreamOptions.RecvTimeout != 0 { cfg.SetStreamRecvTimeout(callOpts.StreamOptions.RecvTimeout) } + if callOpts.StreamOptions.SendTimeout != 0 { + cfg.SetStreamSendTimeout(callOpts.StreamOptions.SendTimeout) + } + if callOpts.StreamOptions.StreamTimeout != 0 { + cfg.SetStreamTimeout(callOpts.StreamOptions.StreamTimeout) + } } return ctx, ri, callOpts diff --git a/client/client_test.go b/client/client_test.go index 7a5f32baac..6c7a7b20f1 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1266,22 +1266,66 @@ func Test_initRPCInfoWithStreamClientCallOption(t *testing.T) { mtd := mocks.MockMethod svcInfo := mocks.ServiceInfo() callOptTimeout := 1 * time.Second + cliTimeout := 2 * time.Second testService := "testService" - // config call option - cliIntf, err := NewClient(svcInfo, WithTransportProtocol(transport.TTHeaderStreaming), WithDestService(testService)) - test.Assert(t, err == nil, err) - cli := cliIntf.(*kcFinalizerClient) - ctx := NewCtxWithCallOptions(context.Background(), 
streamcall.GetCallOptions([]streamcall.Option{streamcall.WithRecvTimeout(callOptTimeout)})) - _, ri, _ := cli.initRPCInfo(ctx, mtd, 0, nil, true) - test.Assert(t, ri.Config().StreamRecvTimeout() == callOptTimeout) + testcases := []struct { + desc string + cliOpt StreamOption + callOpt streamcall.Option + verifyFunc func(t *testing.T, ri rpcinfo.RPCInfo, isPureCli bool) + }{ + { + desc: "stream recv timeout", + cliOpt: WithStreamRecvTimeout(cliTimeout), + callOpt: streamcall.WithRecvTimeout(callOptTimeout), + verifyFunc: func(t *testing.T, ri rpcinfo.RPCInfo, isPureCli bool) { + if isPureCli { + test.Assert(t, ri.Config().StreamRecvTimeout() == cliTimeout, ri) + } else { + test.Assert(t, ri.Config().StreamRecvTimeout() == callOptTimeout, ri) + } + }, + }, + { + desc: "stream send timeout", + cliOpt: WithStreamSendTimeout(cliTimeout), + callOpt: streamcall.WithSendTimeout(callOptTimeout), + verifyFunc: func(t *testing.T, ri rpcinfo.RPCInfo, isPureCli bool) { + if isPureCli { + test.Assert(t, ri.Config().StreamSendTimeout() == cliTimeout, ri) + } else { + test.Assert(t, ri.Config().StreamSendTimeout() == callOptTimeout, ri) + } + }, + }, + { + desc: "stream timeout", + cliOpt: WithStreamTimeout(cliTimeout), + callOpt: streamcall.WithStreamTimeout(callOptTimeout), + verifyFunc: func(t *testing.T, ri rpcinfo.RPCInfo, isPureCli bool) { + if isPureCli { + test.Assert(t, ri.Config().StreamTimeout() == cliTimeout, ri) + } else { + test.Assert(t, ri.Config().StreamTimeout() == callOptTimeout, ri) + } + }, + }, + } - // call option has higher priority - cliTimeout := 2 * time.Second - cliIntf, err = NewClient(svcInfo, WithTransportProtocol(transport.TTHeaderStreaming), WithStreamOptions(WithStreamRecvTimeout(cliTimeout)), WithDestService(testService)) - test.Assert(t, err == nil, err) - cli = cliIntf.(*kcFinalizerClient) - ctx = NewCtxWithCallOptions(context.Background(), streamcall.GetCallOptions([]streamcall.Option{streamcall.WithRecvTimeout(callOptTimeout)})) - _, ri, _ = cli.initRPCInfo(ctx, mtd, 0, nil, true) - test.Assert(t, ri.Config().StreamRecvTimeout() == callOptTimeout) + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + // config client option + cliIntf, err := NewClient(svcInfo, WithTransportProtocol(transport.TTHeaderStreaming), WithDestService(testService), WithStreamOptions(tc.cliOpt)) + test.Assert(t, err == nil, err) + cli := cliIntf.(*kcFinalizerClient) + _, ri, _ := cli.initRPCInfo(ctx, mtd, 0, nil, true) + tc.verifyFunc(t, ri, true) + + // call option has higher priority + ctx = NewCtxWithCallOptions(context.Background(), streamcall.GetCallOptions([]streamcall.Option{tc.callOpt})) + _, ri, _ = cli.initRPCInfo(ctx, mtd, 0, nil, true) + tc.verifyFunc(t, ri, false) + }) + } } diff --git a/client/option_stream.go b/client/option_stream.go index 0afd1a99fa..9f715b75ca 100644 --- a/client/option_stream.go +++ b/client/option_stream.go @@ -40,7 +40,6 @@ func WithStreamOptions(opts ...StreamOption) Option { } // WithStreamRecvTimeout add recv timeout for stream.Recv function. -// NOTICE: ONLY effective for ttheader streaming protocol for now. func WithStreamRecvTimeout(d time.Duration) StreamOption { return StreamOption{F: func(o *client.StreamOptions, di *utils.Slice) { di.Push(fmt.Sprintf("WithStreamRecvTimeout(%dms)", d.Milliseconds())) @@ -49,6 +48,24 @@ func WithStreamRecvTimeout(d time.Duration) StreamOption { }} } +// WithStreamSendTimeout add send timeout for stream.Send function. 
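+// Client-level stream timeouts set via WithStreamOptions act as defaults for
+// every stream the client creates; a streamcall option of the same name on an
+// individual call takes precedence, since initRPCInfo applies client options
+// first and then overwrites them with call options.
+//
+// A hedged configuration sketch (service name and svcInfo are illustrative):
+//
+//	cli, err := NewClient(svcInfo,
+//		WithDestService("echo"),
+//		WithTransportProtocol(transport.TTHeaderStreaming),
+//		WithStreamOptions(
+//			WithStreamSendTimeout(time.Second),
+//			WithStreamTimeout(10*time.Second),
+//		),
+//	)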
+func WithStreamSendTimeout(d time.Duration) StreamOption { + return StreamOption{F: func(o *client.StreamOptions, di *utils.Slice) { + di.Push(fmt.Sprintf("WithStreamSendTimeout(%dms)", d.Milliseconds())) + + o.SendTimeout = d + }} +} + +// WithStreamTimeout add timeout for whole stream. +func WithStreamTimeout(d time.Duration) StreamOption { + return StreamOption{F: func(o *client.StreamOptions, di *utils.Slice) { + di.Push(fmt.Sprintf("WithStreamTimeout(%dms)", d.Milliseconds())) + + o.StreamTimeout = d + }} +} + // WithStreamMiddleware add middleware for stream. func WithStreamMiddleware(mw cep.StreamMiddleware) StreamOption { return StreamOption{F: func(o *StreamOptions, di *utils.Slice) { diff --git a/client/stream.go b/client/stream.go index d1741fc9e7..49aab18349 100644 --- a/client/stream.go +++ b/client/stream.go @@ -21,6 +21,9 @@ import ( "fmt" "io" "sync/atomic" + "time" + + "github.com/bytedance/gopkg/util/gopool" internal_stream "github.com/cloudwego/kitex/internal/stream" "github.com/cloudwego/kitex/pkg/endpoint" @@ -28,11 +31,16 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" + "github.com/cloudwego/kitex/pkg/remote/trans/ttstream" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/streaming" + streaming_types "github.com/cloudwego/kitex/pkg/streaming/types" + "github.com/cloudwego/kitex/transport" ) // Streaming client streaming interface for code generate @@ -178,12 +186,21 @@ func (kc *kClient) invokeStreamingEndpoint() (endpoint.Endpoint, error) { return func(ctx context.Context, req, resp interface{}) (err error) { // req and resp as &streaming.Stream ri := rpcinfo.GetRPCInfo(ctx) + var cancel context.CancelFunc + // apply stream timeout + if tm := ri.Config().StreamTimeout(); tm > 0 { + ctx, cancel = context.WithTimeout(ctx, tm) + } st, scm, err := remotecli.NewStream(ctx, ri, handler, kc.opt.RemoteOpt) if err != nil { + if cancel != nil { + cancel() + } return } - clientStream := newStream(ctx, st, scm, kc, ri, ri.Invocation().MethodInfo().StreamingMode(), + clientStream := newStream(ctx, cancel, + st, scm, kc, ri, ri.Invocation().MethodInfo().StreamingMode(), sendEP, recvEP, kc.opt.StreamOptions.EventHandler, grpcSendEP, grpcRecvEP) rresp := resp.(*streaming.Result) rresp.ClientStream = clientStream @@ -201,11 +218,15 @@ type stream struct { ri rpcinfo.RPCInfo eventHandler internal_stream.StreamEventHandler - recv cep.StreamRecvEndpoint - send cep.StreamSendEndpoint + recv cep.StreamRecvEndpoint + recvTm time.Duration + send cep.StreamSendEndpoint + sendTm time.Duration streamingMode serviceinfo.StreamingMode finished uint32 + isGRPC bool + cancelFunc context.CancelFunc } var ( @@ -214,9 +235,12 @@ var ( _ streaming.WithDoFinish = (*grpcStream)(nil) ) -func newStream(ctx context.Context, s streaming.ClientStream, scm *remotecli.StreamConnManager, kc *kClient, ri rpcinfo.RPCInfo, mode serviceinfo.StreamingMode, +func newStream(ctx context.Context, cancel context.CancelFunc, s streaming.ClientStream, scm *remotecli.StreamConnManager, kc *kClient, ri rpcinfo.RPCInfo, mode serviceinfo.StreamingMode, sendEP cep.StreamSendEndpoint, recvEP cep.StreamRecvEndpoint, eventHandler 
internal_stream.StreamEventHandler, grpcSendEP endpoint.SendEndpoint, grpcRecvEP endpoint.RecvEndpoint, ) *stream { + recvTm := ri.Config().StreamRecvTimeout() + sendTm := ri.Config().StreamSendTimeout() + isGRPC := ri.Config().TransportProtocol()&transport.GRPC != 0 st := &stream{ ClientStream: s, ctx: ctx, @@ -225,12 +249,16 @@ func newStream(ctx context.Context, s streaming.ClientStream, scm *remotecli.Str ri: ri, streamingMode: mode, recv: recvEP, + recvTm: recvTm, send: sendEP, + sendTm: sendTm, eventHandler: eventHandler, + isGRPC: isGRPC, + cancelFunc: cancel, } if grpcStreamGetter, ok := s.(streaming.GRPCStreamGetter); ok { if grpcStream := grpcStreamGetter.GetGRPCStream(); grpcStream != nil { - st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, grpcRecvEP) + st.grpcStream = newGRPCStream(grpcStream, grpcSendEP, sendTm, grpcRecvEP, recvTm) st.grpcStream.st = st } } @@ -261,7 +289,7 @@ func (s *stream) RecvMsg(ctx context.Context, m interface{}) (err error) { ctx = rpcinfo.NewCtxWithRPCInfo(ctx, s.ri) } } - err = s.recv(ctx, s.ClientStream, m) + err = s.recvWithTimeout(ctx, m) if err == nil { // BizStatusErr is returned by the server handle, meaning the stream is ended; // And it should be returned to the calling business code for error handling @@ -276,6 +304,22 @@ func (s *stream) RecvMsg(ctx context.Context, m interface{}) (err error) { return } +func (s *stream) recvWithTimeout(ctx context.Context, m interface{}) error { + return callWithTimeout(s.recvTm, + func() error { + return s.recv(ctx, s.ClientStream, m) + }, + func(tm time.Duration) error { + if s.isGRPC { + return status.NewTimeoutStatus(codes.DeadlineExceeded, fmt.Sprintf(recvTimeoutErrTpl, tm), streaming_types.StreamRecvTimeout).Err() + } + return ttstream.NewTimeoutException(streaming_types.StreamRecvTimeout, remote.Client, tm) + }, + func(err error) { + s.Cancel(err) + }) +} + // SendMsg sends a message to the server. // If an error is returned, stream.DoFinish() will be called to record the end of stream func (s *stream) SendMsg(ctx context.Context, m interface{}) (err error) { @@ -286,7 +330,7 @@ func (s *stream) SendMsg(ctx context.Context, m interface{}) (err error) { ctx = rpcinfo.NewCtxWithRPCInfo(ctx, s.ri) } } - err = s.send(ctx, s.ClientStream, m) + err = s.sendWithTimeout(ctx, m) if s.eventHandler != nil { s.eventHandler(s.ctx, stats.StreamSend, err) } @@ -296,6 +340,22 @@ func (s *stream) SendMsg(ctx context.Context, m interface{}) (err error) { return } +func (s *stream) sendWithTimeout(ctx context.Context, m interface{}) error { + return callWithTimeout(s.sendTm, + func() error { + return s.send(ctx, s.ClientStream, m) + }, + func(tm time.Duration) error { + if s.isGRPC { + return status.NewTimeoutStatus(codes.DeadlineExceeded, fmt.Sprintf(sendTimeoutErrTpl, tm), streaming_types.StreamSendTimeout).Err() + } + return ttstream.NewTimeoutException(streaming_types.StreamSendTimeout, remote.Client, tm) + }, + func(err error) { + s.Cancel(err) + }) +} + // DoFinish implements the streaming.WithDoFinish interface, and it records the end of stream // It will release the connection. 
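// With this patch it also releases the whole-stream timeout: the cancel func
// of the context.WithTimeout created in invokeStreamingEndpoint is invoked
// here once the stream ends.
//
// The per-message Recv/Send timeouts are enforced by callWithTimeout (below):
// the blocking endpoint call runs in a pooled goroutine and races a timer; on
// expiry a protocol-specific error is built (a DeadlineExceeded status for
// gRPC, a timeout Exception for ttstream), the stream is canceled so the
// blocked call unblocks, and the timeout error is returned. The buffered
// finishChan lets the helper goroutine exit even after a timeout.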
func (s *stream) DoFinish(err error) { @@ -303,6 +363,10 @@ func (s *stream) DoFinish(err error) { // already called return } + // release stream timeout cancel + if s.cancelFunc != nil { + s.cancelFunc() + } if !isRPCError(err) { // only rpc errors are reported err = nil @@ -320,11 +384,18 @@ func (s *stream) GetGRPCStream() streaming.Stream { return s.grpcStream } -func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, recvEP endpoint.RecvEndpoint) *grpcStream { +const ( + recvTimeoutErrTpl = "stream Recv timeout, timeout config=%v" + sendTimeoutErrTpl = "stream Send timeout, timeout config=%v" +) + +func newGRPCStream(st streaming.Stream, sendEP endpoint.SendEndpoint, sendTm time.Duration, recvEP endpoint.RecvEndpoint, recvTm time.Duration) *grpcStream { return &grpcStream{ Stream: st, sendEndpoint: sendEP, + sendTm: sendTm, recvEndpoint: recvEP, + recvTm: recvTm, } } @@ -334,7 +405,9 @@ type grpcStream struct { st *stream sendEndpoint endpoint.SendEndpoint + sendTm time.Duration recvEndpoint endpoint.RecvEndpoint + recvTm time.Duration } // Header returns the header metadata sent by the server if any. @@ -347,7 +420,7 @@ func (s *grpcStream) Header() (md metadata.MD, err error) { } func (s *grpcStream) RecvMsg(m interface{}) (err error) { - err = s.recvEndpoint(s.Stream, m) + err = s.recvWithTimeout(m) if err == nil { // BizStatusErr is returned by the server handle, meaning the stream is ended; // And it should be returned to the calling business code for error handling @@ -362,8 +435,22 @@ func (s *grpcStream) RecvMsg(m interface{}) (err error) { return } +func (s *grpcStream) recvWithTimeout(m interface{}) error { + return callWithTimeout(s.recvTm, + func() error { + return s.recvEndpoint(s.Stream, m) + }, + func(tm time.Duration) error { + return status.NewTimeoutStatus(codes.DeadlineExceeded, fmt.Sprintf(recvTimeoutErrTpl, tm), streaming_types.StreamRecvTimeout).Err() + }, + func(err error) { + s.st.Cancel(err) + }, + ) +} + func (s *grpcStream) SendMsg(m interface{}) (err error) { - err = s.sendEndpoint(s.Stream, m) + err = s.sendWithTimeout(m) if s.st.eventHandler != nil { s.st.eventHandler(s.st.ctx, stats.StreamSend, err) } @@ -373,6 +460,20 @@ func (s *grpcStream) SendMsg(m interface{}) (err error) { return } +func (s *grpcStream) sendWithTimeout(m interface{}) error { + return callWithTimeout(s.sendTm, + func() error { + return s.sendEndpoint(s.Stream, m) + }, + func(tm time.Duration) error { + return status.NewTimeoutStatus(codes.DeadlineExceeded, fmt.Sprintf(sendTimeoutErrTpl, tm), streaming_types.StreamSendTimeout).Err() + }, + func(err error) { + s.st.Cancel(err) + }, + ) +} + func (s *grpcStream) DoFinish(err error) { s.st.DoFinish(err) } @@ -389,6 +490,28 @@ func isRPCError(err error) bool { return !isBizStatusError } +func callWithTimeout(tm time.Duration, call func() error, buildTmErr func(time.Duration) error, cancel func(error)) error { + if tm <= 0 { + return call() + } + + timer := time.NewTimer(tm) + defer timer.Stop() + finishChan := make(chan error, 1) + gopool.Go(func() { + callErr := call() + finishChan <- callErr + }) + select { + case <-timer.C: + err := buildTmErr(tm) + cancel(err) + return err + case callErr := <-finishChan: + return callErr + } +} + var ( recvEndpoint cep.StreamRecvEndpoint = func(ctx context.Context, stream streaming.ClientStream, m interface{}) error { return stream.RecvMsg(ctx, m) diff --git a/client/stream_test.go b/client/stream_test.go index f90f470b6b..9991307f22 100644 --- a/client/stream_test.go +++ 
b/client/stream_test.go @@ -21,7 +21,9 @@ import ( "errors" "fmt" "io" + "strings" "testing" + "time" "github.com/golang/mock/gomock" @@ -34,10 +36,14 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/remotecli" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" + "github.com/cloudwego/kitex/pkg/remote/trans/ttstream" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/streaming" + streaming_types "github.com/cloudwego/kitex/pkg/streaming/types" "github.com/cloudwego/kitex/pkg/utils" "github.com/cloudwego/kitex/transport" ) @@ -131,7 +137,7 @@ func TestStreaming(t *testing.T) { connpool.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(conn, nil) cliInfo.ConnPool = connpool s, cr, _ := remotecli.NewStream(ctx, mockRPCInfo, new(mocks.MockCliTransHandler), cliInfo) - stream := newStream(ctx, + stream := newStream(ctx, nil, s, cr, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return stream.SendMsg(ctx, message) @@ -208,11 +214,13 @@ func TestClosedClient(t *testing.T) { type mockStream struct { streaming.ClientStream - ctx context.Context - close func() error - header func() (streaming.Header, error) - recv func(ctx context.Context, msg interface{}) error - send func(ctx context.Context, msg interface{}) error + ctx context.Context + close func() error + header func() (streaming.Header, error) + recv func(ctx context.Context, msg interface{}) error + send func(ctx context.Context, msg interface{}) error + cancel func(err error) + gRPCStream *mockGRPCStream } func (s *mockStream) Context() context.Context { @@ -235,10 +243,22 @@ func (s *mockStream) CloseSend(ctx context.Context) error { return s.close() } +func (s *mockStream) Cancel(err error) { + s.cancel(err) +} + +func (s *mockStream) GetGRPCStream() streaming.Stream { + return s.gRPCStream +} + +type mockGRPCStream struct { + streaming.Stream +} + func Test_newStream(t *testing.T) { sendErr := errors.New("send error") recvErr := errors.New("recv error") - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -252,7 +272,7 @@ func Test_newStream(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scr := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, + s := newStream(ctx, nil, st, scr, kc, @@ -306,7 +326,8 @@ func Test_stream_Header(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) scr := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scr, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), nil) + s := newStream(ctx, nil, st, scr, &kClient{}, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) md, err := s.Header() test.Assert(t, err == nil) @@ -316,7 +337,7 @@ func Test_stream_Header(t *testing.T) { t.Run("error", func(t *testing.T) { headerErr := errors.New("header error") - ri := rpcinfo.NewRPCInfo(nil, nil, nil, 
nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ header: func() (streaming.Header, error) { return nil, headerErr @@ -339,7 +360,7 @@ func Test_stream_Header(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scr := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scr, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + s := newStream(ctx, nil, st, scr, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) md, err := s.Header() test.Assert(t, err == headerErr) @@ -356,8 +377,8 @@ func Test_stream_RecvMsg(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) scm := remotecli.NewStreamConnManager(cr) - mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, nil) - s := newStream(ctx, &mockStream{}, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), rpcinfo.NewRPCConfig(), nil) + s := newStream(ctx, nil, &mockStream{}, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return nil }, nil, nil, func(stream streaming.Stream, message interface{}) (err error) { @@ -371,7 +392,7 @@ func Test_stream_RecvMsg(t *testing.T) { }) t.Run("no-error-client-streaming", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -392,7 +413,7 @@ func Test_stream_RecvMsg(t *testing.T) { // client streaming should release connection after RecvMsg cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingClient, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingClient, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return nil }, nil, nil, func(stream streaming.Stream, message interface{}) (err error) { @@ -407,7 +428,7 @@ func Test_stream_RecvMsg(t *testing.T) { t.Run("error", func(t *testing.T) { recvErr := errors.New("recv error") - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -428,7 +449,7 @@ func Test_stream_RecvMsg(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, 
func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return recvErr }, nil, nil, func(stream streaming.Stream, message interface{}) (err error) { @@ -440,16 +461,168 @@ func Test_stream_RecvMsg(t *testing.T) { test.Assert(t, err == recvErr) test.Assert(t, finishCalled) }) + + t.Run("gRPC recv in time", func(t *testing.T) { + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamRecvTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, nil) + st := &mockStream{gRPCStream: &mockGRPCStream{}} + s := newStream(ctx, nil, st, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + // mock recv time + time.Sleep(100 * time.Millisecond) + return nil + }, nil, nil, + func(stream streaming.Stream, message interface{}) (err error) { + // mock recv time + time.Sleep(100 * time.Millisecond) + return nil + }, + ) + + err := s.RecvMsg(context.Background(), nil) + test.Assert(t, err == nil) + oldS := s.GetGRPCStream() + err = oldS.RecvMsg(nil) + test.Assert(t, err == nil) + }) + + t.Run("gRPC recv timeout", func(t *testing.T) { + finishCalled := false + tracer := &mockTracer{ + finish: func(ctx context.Context) { + finishCalled = true + }, + } + ctl := &rpcinfo.TraceController{} + ctl.Append(tracer) + kc := &kClient{ + opt: &client.Options{ + TracerCtl: ctl, + }, + } + + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamRecvTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats()) + + checkGRPCStatus := func(err error) { + gRPCSt, ok := status.FromError(err) + test.Assert(t, ok) + test.Assert(t, gRPCSt != nil) + test.Assert(t, gRPCSt.Code() == codes.DeadlineExceeded, gRPCSt.Code()) + test.Assert(t, strings.Contains(gRPCSt.Message(), "stream Recv timeout, timeout config="), gRPCSt.Message()) + test.Assert(t, gRPCSt.TimeoutType() == streaming_types.StreamRecvTimeout, gRPCSt.TimeoutType()) + } + st := &mockStream{ + cancel: func(err error) { + checkGRPCStatus(err) + }, + gRPCStream: &mockGRPCStream{}, + } + s := newStream(ctx, nil, st, scm, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + // mock recv timeout + time.Sleep(400 * time.Millisecond) + return nil + }, nil, nil, + func(stream streaming.Stream, message interface{}) (err error) { + // mock recv timeout + time.Sleep(400 * time.Millisecond) + return nil + }, + ) + + err := s.RecvMsg(context.Background(), nil) + test.Assert(t, err != nil) + checkGRPCStatus(err) + oldS := s.GetGRPCStream() + err = oldS.RecvMsg(nil) + test.Assert(t, err != nil) + checkGRPCStatus(err) + test.Assert(t, finishCalled) + }) + + t.Run("ttstream recv in time", func(t *testing.T) { + cr := mock_remote.NewMockConnReleaser(ctrl) + 
cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamRecvTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderStreaming) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, nil) + s := newStream(ctx, nil, &mockStream{}, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(100 * time.Millisecond) + return nil + }, nil, nil, nil) + err := s.RecvMsg(context.Background(), nil) + test.Assert(t, err == nil) + }) + + t.Run("ttstream recv timeout", func(t *testing.T) { + finishCalled := false + tracer := &mockTracer{ + finish: func(ctx context.Context) { + finishCalled = true + }, + } + ctl := &rpcinfo.TraceController{} + ctl.Append(tracer) + kc := &kClient{ + opt: &client.Options{ + TracerCtl: ctl, + }, + } + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamRecvTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderStreaming) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats()) + checkTTStreamException := func(err error) { + ttEx, ok := err.(*ttstream.Exception) + test.Assert(t, ok, fmt.Sprintf("expected ttstream.Exception, got %T", err)) + test.Assert(t, ttEx != nil) + test.Assert(t, errors.Is(err, kerrors.ErrStreamingTimeout)) + test.Assert(t, strings.Contains(err.Error(), "stream Recv timeout, timeout config="), err.Error()) + } + st := &mockStream{ + cancel: func(err error) { + checkTTStreamException(err) + }, + } + s := newStream(ctx, nil, st, scm, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, nil, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(400 * time.Millisecond) + return nil + }, nil, nil, nil) + err := s.RecvMsg(context.Background(), nil) + test.Assert(t, err != nil) + checkTTStreamException(err) + test.Assert(t, finishCalled) + }) } func Test_stream_SendMsg(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + t.Run("no-error", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, &mockStream{}, scm, &kClient{}, nil, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), nil) + s := newStream(ctx, nil, &mockStream{}, scm, &kClient{}, ri, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return nil }, nil, nil, func(stream streaming.Stream, message interface{}) (err error) { @@ -464,13 +637,11 @@ func Test_stream_SendMsg(t *testing.T) { }) t.Run("error", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) 
scm := remotecli.NewStreamConnManager(cr) sendErr := errors.New("recv error") - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -488,7 +659,7 @@ func Test_stream_SendMsg(t *testing.T) { }, } - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { return sendErr }, nil, nil, func(stream streaming.Stream, message interface{}) (err error) { @@ -501,6 +672,150 @@ func Test_stream_SendMsg(t *testing.T) { test.Assert(t, err == sendErr) test.Assert(t, finishCalled) }) + + t.Run("gRPC send in time", func(t *testing.T) { + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamSendTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, nil) + st := &mockStream{gRPCStream: &mockGRPCStream{}} + s := newStream(ctx, nil, st, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(100 * time.Millisecond) + return nil + }, nil, nil, + func(stream streaming.Stream, message interface{}) (err error) { + time.Sleep(100 * time.Millisecond) + return nil + }, + nil, + ) + err := s.SendMsg(context.Background(), nil) + test.Assert(t, err == nil) + oldS := s.GetGRPCStream() + err = oldS.SendMsg(nil) + test.Assert(t, err == nil) + }) + + t.Run("gRPC send timeout", func(t *testing.T) { + finishCalled := false + tracer := &mockTracer{ + finish: func(ctx context.Context) { + finishCalled = true + }, + } + ctl := &rpcinfo.TraceController{} + ctl.Append(tracer) + kc := &kClient{ + opt: &client.Options{ + TracerCtl: ctl, + }, + } + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamSendTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.GRPC) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats()) + checkGRPCStatus := func(err error) { + gRPCSt, ok := status.FromError(err) + test.Assert(t, ok) + test.Assert(t, gRPCSt != nil) + test.Assert(t, gRPCSt.Code() == codes.DeadlineExceeded, gRPCSt.Code()) + test.Assert(t, strings.Contains(gRPCSt.Message(), "stream Send timeout, timeout config="), gRPCSt.Message()) + test.Assert(t, gRPCSt.TimeoutType() == streaming_types.StreamSendTimeout, gRPCSt.TimeoutType()) + } + st := &mockStream{ + cancel: func(err error) { + checkGRPCStatus(err) + }, + gRPCStream: &mockGRPCStream{}, + } + s := newStream(ctx, nil, st, scm, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(400 * 
time.Millisecond) + return nil + }, nil, nil, + func(stream streaming.Stream, message interface{}) (err error) { + time.Sleep(400 * time.Millisecond) + return nil + }, + nil, + ) + err := s.SendMsg(context.Background(), nil) + test.Assert(t, err != nil) + checkGRPCStatus(err) + oldS := s.GetGRPCStream() + err = oldS.SendMsg(nil) + test.Assert(t, err != nil) + checkGRPCStatus(err) + test.Assert(t, finishCalled) + }) + + t.Run("ttstream send in time", func(t *testing.T) { + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamSendTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderStreaming) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, nil) + s := newStream(ctx, nil, &mockStream{}, scm, &kClient{}, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(100 * time.Millisecond) + return nil + }, nil, nil, nil, nil) + err := s.SendMsg(context.Background(), nil) + test.Assert(t, err == nil) + }) + + t.Run("ttstream send timeout", func(t *testing.T) { + finishCalled := false + tracer := &mockTracer{ + finish: func(ctx context.Context) { + finishCalled = true + }, + } + ctl := &rpcinfo.TraceController{} + ctl.Append(tracer) + kc := &kClient{ + opt: &client.Options{ + TracerCtl: ctl, + }, + } + cr := mock_remote.NewMockConnReleaser(ctrl) + cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) + scm := remotecli.NewStreamConnManager(cr) + cfg := rpcinfo.NewRPCConfig() + tm := 200 * time.Millisecond + rpcinfo.AsMutableRPCConfig(cfg).SetStreamSendTimeout(tm) + rpcinfo.AsMutableRPCConfig(cfg).SetTransportProtocol(transport.TTHeaderStreaming) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), cfg, rpcinfo.NewRPCStats()) + checkTTStreamException := func(err error) { + ttEx, ok := err.(*ttstream.Exception) + test.Assert(t, ok, fmt.Sprintf("expected ttstream.Exception, got %T", err)) + test.Assert(t, ttEx != nil) + test.Assert(t, errors.Is(err, kerrors.ErrStreamingTimeout)) + test.Assert(t, strings.Contains(err.Error(), "stream Send timeout, timeout config="), err.Error()) + } + st := &mockStream{ + cancel: func(err error) { + checkTTStreamException(err) + }, + } + s := newStream(ctx, nil, st, scm, kc, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + time.Sleep(400 * time.Millisecond) + return nil + }, nil, nil, nil, nil) + err := s.SendMsg(context.Background(), nil) + test.Assert(t, err != nil) + checkTTStreamException(err) + test.Assert(t, finishCalled) + }) } func Test_stream_Close(t *testing.T) { @@ -510,12 +825,13 @@ func Test_stream_Close(t *testing.T) { cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(0) scm := remotecli.NewStreamConnManager(cr) called := false - s := newStream(ctx, &mockStream{ + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), nil) + s := newStream(ctx, nil, &mockStream{ close: func() error { called = true return nil }, - }, scm, &kClient{}, nil, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + }, scm, &kClient{}, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) err := 
s.CloseSend(context.Background()) @@ -528,7 +844,7 @@ func Test_stream_DoFinish(t *testing.T) { defer ctrl.Finish() t.Run("no-error", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -543,7 +859,7 @@ func Test_stream_DoFinish(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) finishCalled := false err := errors.New("any err") @@ -558,7 +874,7 @@ func Test_stream_DoFinish(t *testing.T) { }) t.Run("EOF", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -573,7 +889,7 @@ func Test_stream_DoFinish(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) finishCalled := false err := errors.New("any err") @@ -588,7 +904,7 @@ func Test_stream_DoFinish(t *testing.T) { }) t.Run("biz-status-error", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -603,7 +919,7 @@ func Test_stream_DoFinish(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + s := newStream(ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) finishCalled := false var err error @@ -618,7 +934,7 @@ func Test_stream_DoFinish(t *testing.T) { }) t.Run("error", func(t *testing.T) { - ri := rpcinfo.NewRPCInfo(nil, nil, nil, nil, rpcinfo.NewRPCStats()) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) st := &mockStream{ ctx: rpcinfo.NewCtxWithRPCInfo(context.Background(), ri), } @@ -633,7 +949,7 @@ func Test_stream_DoFinish(t *testing.T) { cr := mock_remote.NewMockConnReleaser(ctrl) cr.EXPECT().ReleaseConn(gomock.Any(), gomock.Any()).Times(1) scm := remotecli.NewStreamConnManager(cr) - s := newStream(st.ctx, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) + s := newStream(st.ctx, nil, st, scm, kc, ri, serviceinfo.StreamingBidirectional, nil, nil, nil, nil, nil) finishCalled := false expectedErr := errors.New("error") @@ -665,7 +981,7 @@ func Test_isRPCError(t *testing.T) { } func TestContextFallback(t *testing.T) { - mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("mock_service", "mock_method"), nil, nil) + mockRPCInfo := rpcinfo.NewRPCInfo(nil, nil, 
rpcinfo.NewInvocation("mock_service", "mock_method"), rpcinfo.NewRPCConfig(), nil) mockSt := &mockStream{ recv: func(ctx context.Context, message interface{}) error { test.Assert(t, ctx == context.Background()) @@ -676,7 +992,7 @@ func TestContextFallback(t *testing.T) { return nil }, } - st := newStream(context.Background(), mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, sendEndpoint, recvEndpoint, nil, nil, nil) + st := newStream(context.Background(), nil, mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, sendEndpoint, recvEndpoint, nil, nil, nil) err := st.RecvMsg(context.Background(), nil) test.Assert(t, err == nil) err = st.SendMsg(context.Background(), nil) @@ -694,7 +1010,7 @@ func TestContextFallback(t *testing.T) { return nil }, } - st = newStream(context.Background(), mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { + st = newStream(context.Background(), nil, mockSt, nil, nil, mockRPCInfo, serviceinfo.StreamingBidirectional, func(ctx context.Context, stream streaming.ClientStream, message interface{}) (err error) { ri := rpcinfo.GetRPCInfo(ctx) test.Assert(t, ri == mockRPCInfo) return sendEndpoint(ctx, stream, message) diff --git a/go.mod b/go.mod index 4ac124de49..706e86c57a 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/cloudwego/kitex go 1.20 +replace github.com/cloudwego/gopkg => github.com/DMwangnima/gopkg v0.1.2-0.20251126081112-5f381e4b62fa + require ( github.com/bytedance/gopkg v0.1.3 github.com/bytedance/sonic v1.14.1 diff --git a/go.sum b/go.sum index 927af6f42e..d7f1f11d78 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DMwangnima/gopkg v0.1.2-0.20251126081112-5f381e4b62fa h1:knwNVNo7UpqSiA+MY/ddg9F5okJxNTwbObFZwxXsvhA= +github.com/DMwangnima/gopkg v0.1.2-0.20251126081112-5f381e4b62fa/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= @@ -19,9 +21,6 @@ github.com/cloudwego/fastpb v0.0.5 h1:vYnBPsfbAtU5TVz5+f9UTlmSCixG9F9vRwaqE0mZPZ github.com/cloudwego/fastpb v0.0.5/go.mod h1:Bho7aAKBUtT9RPD2cNVkTdx4yQumfSv3If7wYnm1izk= github.com/cloudwego/frugal v0.3.0 h1:tgAP0nytiJuyoIM3V3TDOGzjrSNRAIlNG1HHOAzZ3Cs= github.com/cloudwego/frugal v0.3.0/go.mod h1:pMk46fFyAwUbW7q7lfdK7c6HsD6bWtu6/3Vhz63CgsY= -github.com/cloudwego/gopkg v0.1.4/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= -github.com/cloudwego/gopkg v0.1.6 h1:EMlOHg975CxKX1/BtIVYKGW8hxNptTkjjJ7bvfXu4L4= -github.com/cloudwego/gopkg v0.1.6/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= github.com/cloudwego/localsession v0.1.2 h1:RBmeLDO5sKr4ujd8iBp5LTMmuVKLdu88jjIneq/fEZ8= github.com/cloudwego/localsession v0.1.2/go.mod h1:J4uams2YT/2d4t7OI6A7NF7EcG8OlHJsOX2LdPbqoyc= github.com/cloudwego/netpoll v0.7.2 h1:4qDBGQ6CG2SvEXhZSDxMdtqt/NLDxjAVk0PC/biKiJo= diff --git a/internal/client/option.go b/internal/client/option.go index 1b8c766087..bddb40bcef 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -102,6 +102,8 @@ type StreamOption struct { type StreamOptions struct { EventHandler 
stream.StreamEventHandler
	RecvTimeout              time.Duration
+	SendTimeout              time.Duration
+	StreamTimeout            time.Duration
	StreamMiddlewares        []cep.StreamMiddleware
	StreamMiddlewareBuilders []cep.StreamMiddlewareBuilder
	StreamRecvMiddlewares    []cep.StreamRecvMiddleware
diff --git a/internal/utils/contextwatcher/contextwatcher.go b/internal/utils/contextwatcher/contextwatcher.go
new file mode 100644
index 0000000000..0d3e55d278
--- /dev/null
+++ b/internal/utils/contextwatcher/contextwatcher.go
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package contextwatcher
+
+import (
+	"context"
+	"sync"
+
+	"github.com/bytedance/gopkg/util/gopool"
+)
+
+// ContextWatcher tracks registered contexts and runs a callback once each one is done.
+type ContextWatcher struct {
+	ctxs sync.Map
+}
+
+var global *ContextWatcher
+
+func init() {
+	global = &ContextWatcher{}
+}
+
+// RegisterContext watches ctx and invokes callback once ctx is done.
+// Contexts that can never be canceled are ignored, and duplicate registrations are no-ops.
+func RegisterContext(ctx context.Context, callback func(context.Context)) {
+	if ctx.Done() == nil {
+		return
+	}
+
+	finCh := make(chan struct{})
+	_, loaded := global.ctxs.LoadOrStore(ctx, finCh)
+	if loaded {
+		return
+	}
+
+	gopool.Go(func() {
+		select {
+		case <-ctx.Done():
+			callback(ctx)
+			// Clean up the map entry after callback execution
+			global.ctxs.Delete(ctx)
+
+		case <-finCh:
+			return
+		}
+	})
+}
+
+// DeregisterContext stops watching ctx; the callback registered for it will no longer fire.
+func DeregisterContext(ctx context.Context) {
+	rawVal, loaded := global.ctxs.LoadAndDelete(ctx)
+	if loaded {
+		// signal the watcher goroutine to exit
+		close(rawVal.(chan struct{}))
+	}
+}
diff --git a/internal/utils/contextwatcher/contextwatcher_test.go b/internal/utils/contextwatcher/contextwatcher_test.go
new file mode 100644
index 0000000000..19b0737175
--- /dev/null
+++ b/internal/utils/contextwatcher/contextwatcher_test.go
@@ -0,0 +1,467 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package contextwatcher + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func init() { + // Initialize global variable to fix the nil pointer issue + global = &ContextWatcher{} +} + +// TestRegisterContext_BasicCallback tests that callback is invoked when context is cancelled +func TestRegisterContext_BasicCallback(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + called := atomic.Bool{} + var mu sync.Mutex + var callbackCtx context.Context + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + mu.Lock() + callbackCtx = c + mu.Unlock() + }) + + // Cancel the context + cancel() + + // Wait for callback to be invoked + time.Sleep(50 * time.Millisecond) + + if !called.Load() { + t.Error("callback was not called after context cancellation") + } + + mu.Lock() + receivedCtx := callbackCtx + mu.Unlock() + + if receivedCtx != ctx { + t.Error("callback received wrong context") + } +} + +// TestRegisterContext_WithTimeout tests callback with timeout context +func TestRegisterContext_WithTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + called := atomic.Bool{} + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + }) + + // Wait for timeout + time.Sleep(100 * time.Millisecond) + + if !called.Load() { + t.Error("callback was not called after context timeout") + } +} + +// TestRegisterContext_NoDeadline tests that no goroutine is created for contexts without deadline +func TestRegisterContext_NoDeadline(t *testing.T) { + ctx := context.Background() + + called := atomic.Bool{} + + beforeGoroutines := runtime.NumGoroutine() + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + }) + + time.Sleep(50 * time.Millisecond) + + afterGoroutines := runtime.NumGoroutine() + + if called.Load() { + t.Error("callback should not be called for context without deadline") + } + + // No new goroutine should be created + if afterGoroutines > beforeGoroutines { + t.Errorf("goroutine leak detected: before=%d, after=%d", beforeGoroutines, afterGoroutines) + } +} + +// TestRegisterContext_DuplicateRegistration tests that duplicate registrations are ignored +func TestRegisterContext_DuplicateRegistration(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + callCount := atomic.Int32{} + + callback := func(c context.Context) { + callCount.Add(1) + } + + // Register the same context twice + RegisterContext(ctx, callback) + RegisterContext(ctx, callback) + + cancel() + + time.Sleep(50 * time.Millisecond) + + // Should only be called once + if count := callCount.Load(); count != 1 { + t.Errorf("expected callback to be called once, got %d times", count) + } +} + +// TestDeregisterContext_PreventCallback tests that deregistration prevents callback +func TestDeregisterContext_PreventCallback(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + called := atomic.Bool{} + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + }) + + // Deregister before cancellation + DeregisterContext(ctx) + + // Give some time for deregister to take effect + time.Sleep(10 * time.Millisecond) + + cancel() + + // Wait to see if callback is called + time.Sleep(50 * time.Millisecond) + + if called.Load() { + t.Error("callback should not be called after deregistration") + } +} + +// TestDeregisterContext_NonExistent tests deregistering a non-existent 
context +func TestDeregisterContext_NonExistent(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Should not panic + DeregisterContext(ctx) +} + +// TestDeregisterContext_AfterCallback tests deregistering after callback is called +func TestDeregisterContext_AfterCallback(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + called := atomic.Bool{} + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + }) + + cancel() + + time.Sleep(50 * time.Millisecond) + + if !called.Load() { + t.Fatal("callback should have been called") + } + + // Deregister after callback - should not panic + DeregisterContext(ctx) +} + +// TestConcurrentRegistration tests concurrent registrations of different contexts +func TestConcurrentRegistration(t *testing.T) { + const numContexts = 100 + + var wg sync.WaitGroup + callCount := atomic.Int32{} + + for i := 0; i < numContexts; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + RegisterContext(ctx, func(c context.Context) { + callCount.Add(1) + }) + + cancel() + }() + } + + wg.Wait() + time.Sleep(100 * time.Millisecond) + + if count := callCount.Load(); count != numContexts { + t.Errorf("expected %d callbacks, got %d", numContexts, count) + } +} + +// TestConcurrentRegisterDeregister tests concurrent register and deregister operations +func TestConcurrentRegisterDeregister(t *testing.T) { + const numOps = 100 + + var wg sync.WaitGroup + + for i := 0; i < numOps; i++ { + wg.Add(2) + + ctx, cancel := context.WithCancel(context.Background()) + + // Register + go func() { + defer wg.Done() + RegisterContext(ctx, func(c context.Context) { + // callback + }) + }() + + // Deregister + go func() { + defer wg.Done() + time.Sleep(time.Millisecond) + DeregisterContext(ctx) + }() + + cancel() + } + + wg.Wait() + time.Sleep(100 * time.Millisecond) +} + +// TestConcurrentSameContext tests concurrent operations on the same context +func TestConcurrentSameContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + callCount := atomic.Int32{} + + var wg sync.WaitGroup + + // Multiple goroutines trying to register the same context + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + RegisterContext(ctx, func(c context.Context) { + callCount.Add(1) + }) + }() + } + + wg.Wait() + + cancel() + time.Sleep(50 * time.Millisecond) + + // Should only be called once despite multiple registration attempts + if count := callCount.Load(); count != 1 { + t.Errorf("expected callback to be called once, got %d times", count) + } +} + +// TestGoroutineCleanup tests that goroutines are properly cleaned up +func TestGoroutineCleanup(t *testing.T) { + // Allow some buffer for goroutine count fluctuation + runtime.GC() + time.Sleep(100 * time.Millisecond) + + beforeGoroutines := runtime.NumGoroutine() + + const numContexts = 50 + + // Register and cancel contexts + for i := 0; i < numContexts; i++ { + ctx, cancel := context.WithCancel(context.Background()) + RegisterContext(ctx, func(c context.Context) {}) + cancel() + } + + // Wait for all callbacks to be executed + time.Sleep(200 * time.Millisecond) + + runtime.GC() + time.Sleep(100 * time.Millisecond) + + afterGoroutines := runtime.NumGoroutine() + + // Allow some variance but should not have significant goroutine leak + // Note: This test may need adjustment based on the actual implementation + // If callback doesn't 
clean up the map entry, there might be some goroutines left + if afterGoroutines > beforeGoroutines+numContexts { + t.Logf("Warning: Potential goroutine leak - before: %d, after: %d", + beforeGoroutines, afterGoroutines) + } +} + +// TestGoroutineCleanupWithDeregister tests that deregister properly cleans up goroutines +func TestGoroutineCleanupWithDeregister(t *testing.T) { + runtime.GC() + time.Sleep(100 * time.Millisecond) + + beforeGoroutines := runtime.NumGoroutine() + + const numContexts = 50 + contexts := make([]context.Context, numContexts) + cancels := make([]context.CancelFunc, numContexts) + + // Register contexts + for i := 0; i < numContexts; i++ { + ctx, cancel := context.WithCancel(context.Background()) + contexts[i] = ctx + cancels[i] = cancel + RegisterContext(ctx, func(c context.Context) {}) + } + + // Deregister all contexts + for i := 0; i < numContexts; i++ { + DeregisterContext(contexts[i]) + } + + // Clean up + for i := 0; i < numContexts; i++ { + cancels[i]() + } + + time.Sleep(200 * time.Millisecond) + runtime.GC() + time.Sleep(100 * time.Millisecond) + + afterGoroutines := runtime.NumGoroutine() + + // With proper deregister, goroutines should be cleaned up + if afterGoroutines > beforeGoroutines+5 { + t.Errorf("Goroutine leak detected after deregister - before: %d, after: %d", + beforeGoroutines, afterGoroutines) + } +} + +// TestCallbackPanic tests that a panic in callback doesn't affect other operations +func TestCallbackPanic(t *testing.T) { + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + called2 := atomic.Bool{} + + // First context with panicking callback + RegisterContext(ctx1, func(c context.Context) { + panic("test panic") + }) + + // Second context with normal callback + RegisterContext(ctx2, func(c context.Context) { + called2.Store(true) + }) + + cancel1() + time.Sleep(50 * time.Millisecond) + + cancel2() + time.Sleep(50 * time.Millisecond) + + // Second callback should still be called despite first one panicking + if !called2.Load() { + t.Error("second callback was not called after first callback panicked") + } +} + +// TestMultipleDeregister tests that multiple deregistrations don't cause issues +func TestMultipleDeregister(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + RegisterContext(ctx, func(c context.Context) {}) + + // Multiple deregistrations should not panic + DeregisterContext(ctx) + DeregisterContext(ctx) + DeregisterContext(ctx) +} + +// TestContextAlreadyCancelled tests registering an already cancelled context +func TestContextAlreadyCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Context is already cancelled + time.Sleep(10 * time.Millisecond) + + called := atomic.Bool{} + + RegisterContext(ctx, func(c context.Context) { + called.Store(true) + }) + + time.Sleep(50 * time.Millisecond) + + // Callback should still be called for already-cancelled context + if !called.Load() { + t.Error("callback was not called for already-cancelled context") + } +} + +// BenchmarkRegisterContext benchmarks the RegisterContext operation +func BenchmarkRegisterContext(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctxB, cancelB := context.WithCancel(ctx) + RegisterContext(ctxB, func(c context.Context) {}) + cancelB() + } +} + +// BenchmarkRegisterDeregister benchmarks 
register followed by deregister
+func BenchmarkRegisterDeregister(b *testing.B) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		ctxB, cancelB := context.WithCancel(ctx)
+		RegisterContext(ctxB, func(c context.Context) {})
+		DeregisterContext(ctxB)
+		cancelB()
+	}
+}
+
+// BenchmarkConcurrentRegister benchmarks concurrent registrations
+func BenchmarkConcurrentRegister(b *testing.B) {
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			ctx, cancel := context.WithCancel(context.Background())
+			RegisterContext(ctx, func(c context.Context) {})
+			cancel()
+		}
+	})
+}
diff --git a/pkg/kerrors/kerrors_test.go b/pkg/kerrors/kerrors_test.go
index 588dfe2c28..526effe8fd 100644
--- a/pkg/kerrors/kerrors_test.go
+++ b/pkg/kerrors/kerrors_test.go
@@ -50,6 +50,7 @@ func TestIsKitexError(t *testing.T) {
 		// streaming errors
 		ErrStreamingProtocol,
 		ErrStreamingCanceled,
+		ErrStreamingTimeout,
 	}
 	for _, e := range errs {
 		test.Assert(t, IsKitexError(e))
diff --git a/pkg/kerrors/streaming_errors.go b/pkg/kerrors/streaming_errors.go
index 5f997e7983..9e0e331b3c 100644
--- a/pkg/kerrors/streaming_errors.go
+++ b/pkg/kerrors/streaming_errors.go
@@ -22,3 +22,7 @@ var ErrStreamingProtocol = &basicError{"streaming protocol error"}
 
 // ErrStreamingCanceled is the parent type of all streaming canceled errors
 var ErrStreamingCanceled = &basicError{"streaming canceled"}
+
+// ErrStreamingTimeout is the parent type of all streaming timeout errors,
+// including Stream timeout, Recv timeout and Send timeout
+var ErrStreamingTimeout = &basicError{"streaming timeout"}
diff --git a/pkg/remote/trans/nphttp2/client_conn.go b/pkg/remote/trans/nphttp2/client_conn.go
index e855d30bfa..8362d51855 100644
--- a/pkg/remote/trans/nphttp2/client_conn.go
+++ b/pkg/remote/trans/nphttp2/client_conn.go
@@ -31,6 +31,7 @@ import (
 	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
 	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
 	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
+	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
 	"github.com/cloudwego/kitex/pkg/rpcinfo"
 	"github.com/cloudwego/kitex/pkg/serviceinfo"
 )
@@ -174,6 +175,18 @@ func (c *clientConn) Header() (metadata.MD, error) { return c.s.Header() }
 func (c *clientConn) Trailer() metadata.MD         { return c.s.Trailer() }
 func (c *clientConn) GetRecvCompress() string      { return c.s.RecvCompress() }
 
+func (c *clientConn) Cancel(err error) {
+	var finalErr error
+	if err == nil {
+		finalErr = status.Err(codes.Canceled, context.Canceled.Error())
+	} else if _, ok := err.(*status.Error); ok {
+		finalErr = err
+	} else {
+		// use the error text as a plain message, not as a format string
+		finalErr = status.Err(codes.Canceled, err.Error())
+	}
+	c.tr.CloseStream(c.s, finalErr)
+}
+
 type hasGetRecvCompress interface {
 	GetRecvCompress() string
 }
diff --git a/pkg/remote/trans/nphttp2/status/status.go b/pkg/remote/trans/nphttp2/status/status.go
index 130d4425cd..812635a2f2 100644
--- a/pkg/remote/trans/nphttp2/status/status.go
+++ b/pkg/remote/trans/nphttp2/status/status.go
@@ -39,6 +39,7 @@ import (
 	"google.golang.org/protobuf/types/known/anypb"
 
 	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
+	streaming_types "github.com/cloudwego/kitex/pkg/streaming/types"
 )
 
 type Iface interface {
@@ -48,7 +49,8 @@ type Iface interface {
 
 // Status represents an RPC status code, message, and details. It is immutable
 // and should be created with New, Newf, or FromProto.
type Status struct {
-	s *spb.Status
+	s           *spb.Status
+	timeoutType streaming_types.TimeoutType
 }
 
 // New returns a Status representing c and msg.
@@ -61,6 +63,17 @@ func Newf(c codes.Code, format string, a ...interface{}) *Status {
 	return New(c, fmt.Sprintf(format, a...))
 }
 
+// NewTimeoutStatus returns a Status with the specific TimeoutType.
+func NewTimeoutStatus(c codes.Code, msg string, timeoutType streaming_types.TimeoutType) *Status {
+	return &Status{
+		s: &spb.Status{
+			Code:    int32(c),
+			Message: msg,
+		},
+		timeoutType: timeoutType,
+	}
+}
+
 // ErrorProto returns an error representing s. If s.Code is OK, returns nil.
 func ErrorProto(s *spb.Status) error {
 	return FromProto(s).Err()
@@ -97,6 +110,14 @@ func (s *Status) Message() string {
 	return s.s.Message
 }
 
+// TimeoutType returns the specific TimeoutType associated with the Status.
+func (s *Status) TimeoutType() streaming_types.TimeoutType {
+	if s == nil || s.s == nil {
+		return 0
+	}
+	return s.timeoutType
+}
+
 // AppendMessage append extra msg for Status
 func (s *Status) AppendMessage(extraMsg string) *Status {
 	if s == nil || s.s == nil || extraMsg == "" {
@@ -119,7 +140,10 @@ func (s *Status) Err() error {
 	if s.Code() == codes.OK {
 		return nil
 	}
-	return &Error{e: s.Proto()}
+	return &Error{
+		e:           s.Proto(),
+		timeoutType: s.timeoutType,
+	}
 }
 
 // WithDetails returns a new status with the provided details messages appended to the status.
@@ -161,7 +185,8 @@ func (s *Status) Details() []interface{} {
 // Error wraps a pointer of a status proto. It implements error and Status,
 // and a nil *Error should never be returned by this package.
 type Error struct {
-	e *spb.Status
+	e           *spb.Status
+	timeoutType streaming_types.TimeoutType
 }
 
 func (e *Error) Error() string {
@@ -170,7 +195,9 @@ func (e *Error) Error() string {
 
 // GRPCStatus returns the Status represented by se.
 func (e *Error) GRPCStatus() *Status {
-	return FromProto(e.e)
+	st := FromProto(e.e)
+	st.timeoutType = e.timeoutType
+	return st
 }
 
 // Is implements future error.Is functionality.
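The round trip above is the point of this change: Err() stamps the TimeoutType onto the *Error, and GRPCStatus() restores it, so the timeout kind survives crossing the error boundary. A minimal caller-side sketch, assuming only the APIs added in this patch (the main wrapper is for illustration):

package main

import (
	"fmt"

	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
	streaming_types "github.com/cloudwego/kitex/pkg/streaming/types"
)

func main() {
	// Build a DeadlineExceeded status tagged as a whole-stream timeout.
	st := status.NewTimeoutStatus(codes.DeadlineExceeded, "stream timeout exceeded", streaming_types.StreamTimeout)

	// Err() yields a *status.Error that still carries the timeout type.
	err := st.Err()
	if se, ok := err.(*status.Error); ok {
		// GRPCStatus() restores a *Status with the original TimeoutType.
		fmt.Println(se.GRPCStatus().TimeoutType() == streaming_types.StreamTimeout) // true
	}
}

The TestTimeoutStatus cases below exercise exactly this round trip for all three timeout types.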
diff --git a/pkg/remote/trans/nphttp2/status/status_test.go b/pkg/remote/trans/nphttp2/status/status_test.go index 08cd55d82f..5fe836665a 100644 --- a/pkg/remote/trans/nphttp2/status/status_test.go +++ b/pkg/remote/trans/nphttp2/status/status_test.go @@ -25,6 +25,7 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" + streaming_types "github.com/cloudwego/kitex/pkg/streaming/types" ) func TestStatus(t *testing.T) { @@ -70,7 +71,7 @@ func TestError(t *testing.T) { s.Code = 1 s.Message = "test err" - er := &Error{s} + er := &Error{e: s} test.Assert(t, len(er.Error()) > 0) status := er.GRPCStatus() @@ -101,7 +102,7 @@ func TestFromContextError(t *testing.T) { s := new(spb.Status) s.Code = 1 s.Message = "test err" - grpcErr := &Error{s} + grpcErr := &Error{e: s} // grpc err codeGrpcErr := Code(grpcErr) test.Assert(t, codeGrpcErr == codes.Canceled) @@ -114,3 +115,58 @@ func TestFromContextError(t *testing.T) { codeNil := Code(nil) test.Assert(t, codeNil == codes.OK) } + +func TestTimeoutStatus(t *testing.T) { + t.Run("stream timeout status", func(t *testing.T) { + statusMsg := "stream timeout exceeded" + status := NewTimeoutStatus(codes.DeadlineExceeded, statusMsg, streaming_types.StreamTimeout) + test.Assert(t, status != nil) + test.Assert(t, status.Code() == codes.DeadlineExceeded) + test.Assert(t, status.Message() == statusMsg) + test.Assert(t, status.TimeoutType() == streaming_types.StreamTimeout) + + err := status.Err() + test.Assert(t, err != nil) + statusErr, ok := err.(*Error) + test.Assert(t, ok) + test.Assert(t, statusErr.timeoutType == streaming_types.StreamTimeout) + + recoveredStatus := statusErr.GRPCStatus() + test.Assert(t, recoveredStatus.Code() == codes.DeadlineExceeded) + test.Assert(t, recoveredStatus.Message() == statusMsg) + }) + t.Run("stream recv timeout status", func(t *testing.T) { + statusMsg := "stream recv timeout exceeded" + status := NewTimeoutStatus(codes.DeadlineExceeded, statusMsg, streaming_types.StreamRecvTimeout) + test.Assert(t, status != nil) + test.Assert(t, status.Code() == codes.DeadlineExceeded) + test.Assert(t, status.Message() == statusMsg) + test.Assert(t, status.TimeoutType() == streaming_types.StreamRecvTimeout) + + err := status.Err() + test.Assert(t, err != nil) + statusErr, ok := err.(*Error) + test.Assert(t, ok) + test.Assert(t, statusErr.timeoutType == streaming_types.StreamRecvTimeout) + }) + t.Run("stream send timeout status", func(t *testing.T) { + statusMsg := "stream send timeout exceeded" + status := NewTimeoutStatus(codes.DeadlineExceeded, statusMsg, streaming_types.StreamSendTimeout) + test.Assert(t, status != nil) + test.Assert(t, status.Code() == codes.DeadlineExceeded) + test.Assert(t, status.Message() == statusMsg) + test.Assert(t, status.TimeoutType() == streaming_types.StreamSendTimeout) + + err := status.Err() + test.Assert(t, err != nil) + statusErr, ok := err.(*Error) + test.Assert(t, ok) + test.Assert(t, statusErr.timeoutType == streaming_types.StreamSendTimeout) + }) + t.Run("nil status timeout type", func(t *testing.T) { + var nilStatus *Status + test.Assert(t, nilStatus.TimeoutType() == 0) + test.Assert(t, nilStatus.Code() == codes.OK) + test.Assert(t, nilStatus.Message() == "") + }) +} diff --git a/pkg/remote/trans/nphttp2/stream.go b/pkg/remote/trans/nphttp2/stream.go index 42ff8258f0..cb6c6fbf29 100644 --- a/pkg/remote/trans/nphttp2/stream.go +++ b/pkg/remote/trans/nphttp2/stream.go @@ -258,6 +258,10 @@ func (s *clientStream) Context() context.Context { 
return s.ctx } +func (s *clientStream) Cancel(err error) { + s.conn.Cancel(err) +} + func streamingHeaderToHTTP2MD(header streaming.Header) metadata.MD { md := metadata.MD{} for k, v := range header { diff --git a/pkg/remote/trans/ttstream/client_handler.go b/pkg/remote/trans/ttstream/client_handler.go index ec73d3ccab..fd7eff1673 100644 --- a/pkg/remote/trans/ttstream/client_handler.go +++ b/pkg/remote/trans/ttstream/client_handler.go @@ -18,10 +18,13 @@ package ttstream import ( "context" + "strconv" + "time" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/internal/utils/contextwatcher" "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -48,7 +51,6 @@ type clientTransHandler struct { // NewStream creates a client stream func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) (streaming.ClientStream, error) { - rconfig := ri.Config() invocation := ri.Invocation() method := invocation.MethodName() addr := ri.To().Address() @@ -68,6 +70,8 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( if intHeader == nil { intHeader = IntHeader{} } + tm := injectStreamTimeout(ctx, intHeader) + if strHeader == nil { strHeader = map[string]string{} } @@ -82,8 +86,10 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( // create new stream cs := newClientStream(ctx, trans, streamFrame{sid: genStreamID(), method: method}) // stream should be configured before WriteStream or there would be a race condition for metaFrameHandler - cs.setRecvTimeout(rconfig.StreamRecvTimeout()) cs.setMetaFrameHandler(c.metaHandler) + cs.setStreamTimeout(tm) + + contextwatcher.RegisterContext(ctx, cs.ctxDoneCallback) if err = trans.WriteStream(ctx, cs, intHeader, strHeader); err != nil { return nil, err @@ -91,3 +97,13 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( return cs, err } + +func injectStreamTimeout(ctx context.Context, hd IntHeader) time.Duration { + ddl, ok := ctx.Deadline() + if !ok { + return 0 + } + tm := time.Until(ddl) + hd[ttheader.StreamTimeout] = strconv.Itoa(int(tm.Milliseconds())) + return tm +} diff --git a/pkg/remote/trans/ttstream/client_handler_test.go b/pkg/remote/trans/ttstream/client_handler_test.go new file mode 100644 index 0000000000..e59b8e0d86 --- /dev/null +++ b/pkg/remote/trans/ttstream/client_handler_test.go @@ -0,0 +1,49 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package ttstream
+
+import (
+	"context"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/cloudwego/gopkg/protocol/ttheader"
+
+	"github.com/cloudwego/kitex/internal/test"
+)
+
+func Test_injectStreamTimeout(t *testing.T) {
+	// ctx without deadline
+	ctx := context.Background()
+	hd := make(IntHeader)
+	tm := injectStreamTimeout(ctx, hd)
+	test.Assert(t, tm == 0, tm)
+	test.Assert(t, hd[ttheader.StreamTimeout] == "", hd)
+
+	// ctx with timeout
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	hd = make(IntHeader)
+	tm = injectStreamTimeout(ctx, hd)
+	test.Assert(t, tm <= 1*time.Second, tm)
+	tmStr, ok := hd[ttheader.StreamTimeout]
+	test.Assert(t, ok)
+	tmMs, err := strconv.Atoi(tmStr)
+	test.Assert(t, err == nil, err)
+	test.Assert(t, time.Duration(tmMs)*time.Millisecond < 1*time.Second, tmMs)
+}
diff --git a/pkg/remote/trans/ttstream/exception.go b/pkg/remote/trans/ttstream/exception.go
index d942d019de..1f1f8a4212 100644
--- a/pkg/remote/trans/ttstream/exception.go
+++ b/pkg/remote/trans/ttstream/exception.go
@@ -40,6 +40,10 @@ var (
 	errInternalCancel         = newException("internal canceled", kerrors.ErrStreamingCanceled, 12011)
 	errBizHandlerReturnCancel = newException("canceled by business handler returning", kerrors.ErrStreamingCanceled, 12012)
 	errConnectionClosedCancel = newException("canceled by connection closed", kerrors.ErrStreamingCanceled, 12013)
+
+	errStreamTimeout     = newException("stream timeout", kerrors.ErrStreamingTimeout, 12014)
+	errStreamRecvTimeout = newException("stream Recv timeout", kerrors.ErrStreamingTimeout, 12015)
+	errStreamSendTimeout = newException("stream Send timeout", kerrors.ErrStreamingTimeout, 12016)
 )
 
 const (
@@ -164,7 +168,7 @@ func (e *Exception) TypeId() int32 {
 // appendCancelPath is a common util func to process cancelPath metadata in Rst Frame and Exception
 func appendCancelPath(oriCp, node string) string {
 	if len(oriCp) > 0 {
-		return strings.Join([]string{oriCp, node}, ",")
+		return oriCp + "," + node
 	}
 	return node
 }
diff --git a/pkg/remote/trans/ttstream/exception_builder.go b/pkg/remote/trans/ttstream/exception_builder.go
new file mode 100644
index 0000000000..81f5718ede
--- /dev/null
+++ b/pkg/remote/trans/ttstream/exception_builder.go
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ttstream
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/cloudwego/kitex/pkg/remote"
+	streaming_types "github.com/cloudwego/kitex/pkg/streaming/types"
+)
+
+// NewTimeoutException builds a ttstream timeout Exception based on
+// errStreamTimeout, errStreamRecvTimeout or errStreamSendTimeout.
+// If an undefined streaming_types.TimeoutType or remote.RPCRole is passed, nil is returned.
+func NewTimeoutException(tmType streaming_types.TimeoutType, role remote.RPCRole, tm time.Duration) *Exception { + var baseEx *Exception + switch tmType { + case streaming_types.StreamRecvTimeout: + baseEx = errStreamRecvTimeout + case streaming_types.StreamSendTimeout: + baseEx = errStreamSendTimeout + case streaming_types.StreamTimeout: + baseEx = errStreamTimeout + default: + return nil + } + + var side sideType + switch role { + case remote.Client: + side = clientSide + case remote.Server: + side = serverSide + default: + return nil + } + + return newTimeoutException(baseEx, side, tm) +} + +func newTimeoutException(baseEx *Exception, side sideType, tm time.Duration) *Exception { + return baseEx.newBuilder().withSide(side).withCause(fmt.Errorf("%s, timeout config=%v", baseEx.message, tm)) +} diff --git a/pkg/remote/trans/ttstream/exception_test.go b/pkg/remote/trans/ttstream/exception_test.go index f03a600a57..ca40fa7fb1 100644 --- a/pkg/remote/trans/ttstream/exception_test.go +++ b/pkg/remote/trans/ttstream/exception_test.go @@ -21,11 +21,14 @@ import ( "fmt" "strings" "testing" + "time" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/remote" + streaming_types "github.com/cloudwego/kitex/pkg/streaming/types" ) func TestErrors(t *testing.T) { @@ -69,6 +72,9 @@ func TestCommonParentKerror(t *testing.T) { for _, err := range errs { test.Assert(t, errors.Is(err, kerrors.ErrStreamingCanceled), err) } + + // timeout Exception + test.Assert(t, errors.Is(errStreamTimeout, kerrors.ErrStreamingTimeout)) } func TestGetTypeId(t *testing.T) { @@ -158,6 +164,62 @@ func TestCanceledException(t *testing.T) { }) } +func TestTimeoutException(t *testing.T) { + t.Run("stream timeout", func(t *testing.T) { + // [ttstream error, code=12014] [client-side stream] timeout config: 1s + clientEx := NewTimeoutException(streaming_types.StreamTimeout, remote.Client, time.Second) + t.Log(clientEx) + test.Assert(t, clientEx != nil) + test.Assert(t, errors.Is(clientEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(clientEx, errStreamTimeout)) + test.Assert(t, clientEx.TypeId() == 12014, clientEx.TypeId()) + + // [ttstream error, code=12014] [server-side stream] timeout config: 2s + serverEx := NewTimeoutException(streaming_types.StreamTimeout, remote.Server, 2*time.Second) + t.Log(serverEx) + test.Assert(t, serverEx != nil) + test.Assert(t, errors.Is(serverEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(serverEx, errStreamTimeout)) + test.Assert(t, serverEx.TypeId() == 12014, serverEx.TypeId()) + }) + + t.Run("stream recv timeout", func(t *testing.T) { + // [ttstream error, code=12015] [client-side stream] timeout config: 500ms + clientEx := NewTimeoutException(streaming_types.StreamRecvTimeout, remote.Client, 500*time.Millisecond) + t.Log(clientEx) + test.Assert(t, clientEx != nil) + test.Assert(t, errors.Is(clientEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(clientEx, errStreamRecvTimeout)) + test.Assert(t, clientEx.TypeId() == 12015, clientEx.TypeId()) + + // [ttstream error, code=12015] [server-side stream] timeout config: 1s + serverEx := NewTimeoutException(streaming_types.StreamRecvTimeout, remote.Server, time.Second) + t.Log(serverEx) + test.Assert(t, serverEx != nil) + test.Assert(t, errors.Is(serverEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(serverEx, errStreamRecvTimeout)) + test.Assert(t, serverEx.TypeId() == 12015, serverEx.TypeId()) + 
}) + + t.Run("stream send timeout", func(t *testing.T) { + // [ttstream error, code=12016] [client-side stream] timeout config: 800ms + clientEx := NewTimeoutException(streaming_types.StreamSendTimeout, remote.Client, 800*time.Millisecond) + t.Log(clientEx) + test.Assert(t, clientEx != nil) + test.Assert(t, errors.Is(clientEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(clientEx, errStreamSendTimeout)) + test.Assert(t, clientEx.TypeId() == 12016, clientEx.TypeId()) + + // [ttstream error, code=12016] [server-side stream] timeout config: 1.5s + serverEx := NewTimeoutException(streaming_types.StreamSendTimeout, remote.Server, 1500*time.Millisecond) + t.Log(serverEx) + test.Assert(t, serverEx != nil) + test.Assert(t, errors.Is(serverEx, kerrors.ErrStreamingTimeout)) + test.Assert(t, errors.Is(serverEx, errStreamSendTimeout)) + test.Assert(t, serverEx.TypeId() == 12016, serverEx.TypeId()) + }) +} + func Test_utilFuncs(t *testing.T) { // test formatCancelPath t.Run("formatCancelPath", func(t *testing.T) { diff --git a/pkg/remote/trans/ttstream/stream.go b/pkg/remote/trans/ttstream/stream.go index a6424b25b3..00ac0cb0b3 100644 --- a/pkg/remote/trans/ttstream/stream.go +++ b/pkg/remote/trans/ttstream/stream.go @@ -19,12 +19,12 @@ package ttstream import ( "context" "fmt" - "time" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/internal/utils/contextwatcher" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/streaming" ktransport "github.com/cloudwego/kitex/transport" @@ -84,7 +84,6 @@ type stream struct { wheader streaming.Header // wheader == nil means it already be sent wtrailer streaming.Trailer // wtrailer == nil means it already be sent - recvTimeout time.Duration closeCallback []func(error) } @@ -125,17 +124,11 @@ func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { } func (s *stream) RecvMsg(ctx context.Context, data any) error { - nctx := s.ctx - if s.recvTimeout > 0 { - var cancel context.CancelFunc - nctx, cancel = context.WithTimeout(nctx, s.recvTimeout) - defer cancel() - } - payload, err := s.reader.output(nctx) + payload, err := s.reader.output(s.ctx) if err != nil { return err } - err = DecodePayload(nctx, payload, data) + err = DecodePayload(s.ctx, payload, data) // payload will not be access after decode mcache.Free(payload) @@ -153,19 +146,13 @@ func (s *stream) RegisterCloseCallback(cb func(error)) { s.closeCallback = append(s.closeCallback, cb) } -func (s *stream) setRecvTimeout(timeout time.Duration) { - if timeout <= 0 { - return - } - s.recvTimeout = timeout -} - func (s *stream) runCloseCallback(exception error) { if len(s.closeCallback) > 0 { for _, cb := range s.closeCallback { cb(exception) } } + contextwatcher.DeregisterContext(s.ctx) _ = s.writer.CloseStream(s.sid) } diff --git a/pkg/remote/trans/ttstream/stream_client.go b/pkg/remote/trans/ttstream/stream_client.go index 555fbdcb18..bf3263b716 100644 --- a/pkg/remote/trans/ttstream/stream_client.go +++ b/pkg/remote/trans/ttstream/stream_client.go @@ -22,6 +22,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/gopkg/protocol/ttheader" @@ -62,6 +63,7 @@ type clientStream struct { // exception as Recv closeStreamException atomic.Value // type must be of *Exception storeExceptionOnce sync.Once + streamTimeout time.Duration // whole stream timeout // for Header()/Trailer() headerSig chan int32 
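With the per-Recv timeout gone from the basic stream, the caller's deadline is the single source of truth: injectStreamTimeout (client_handler.go above) serializes the remaining time into the header, and parseStreamTimeout (transport_server.go below) turns it back into a server-side context deadline. A self-contained sketch of that hand-off, where intHeader and streamTimeoutKey are hypothetical stand-ins for ttstream's IntHeader and the ttheader.StreamTimeout key:

package main

import (
	"context"
	"fmt"
	"strconv"
	"time"
)

// Hypothetical stand-ins for illustration; the real types are ttstream's
// IntHeader and the ttheader.StreamTimeout int key.
type intHeader map[uint16]string

const streamTimeoutKey uint16 = 0x01

func main() {
	// Client side (cf. injectStreamTimeout): remaining deadline -> header, in ms.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	hd := intHeader{}
	if ddl, ok := ctx.Deadline(); ok {
		hd[streamTimeoutKey] = strconv.Itoa(int(time.Until(ddl).Milliseconds()))
	}

	// Server side (cf. parseStreamTimeout): header -> context.WithTimeout,
	// so both ends of the stream count down the same deadline.
	if ms, err := strconv.Atoi(hd[streamTimeoutKey]); err == nil {
		sctx, scancel := context.WithTimeout(context.Background(), time.Duration(ms)*time.Millisecond)
		defer scancel()
		ddl, _ := sctx.Deadline()
		fmt.Println("server-side deadline:", ddl)
	}
}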
@@ -125,6 +127,11 @@ func (s *clientStream) Context() context.Context { return s.ctx } +func (s *clientStream) Cancel(err error) { + finalEx, cancelPath := s.parseCancelErr(err) + s.close(finalEx, true, cancelPath) +} + // ctxDoneCallback convert ctx.Err() to ttstream related err and close client-side stream. // it is invoked in container.Pipe func (s *clientStream) ctxDoneCallback(ctx context.Context) { @@ -133,24 +140,26 @@ func (s *clientStream) ctxDoneCallback(ctx context.Context) { s.close(finalEx, true, cancelPath) } -// parseCtxErr parses information in ctx.Err and returning ttstream Exception and cascading cancelPath -func (s *clientStream) parseCtxErr(ctx context.Context) (finalEx *Exception, cancelPath string) { +func (s *clientStream) parseCancelErr(err error) (finalEx *Exception, cancelPath string) { svcName := s.rpcInfo.From().ServiceName() - cErr := ctx.Err() - switch cErr { + switch err { // biz code invokes cancel() - case context.Canceled: + case nil, context.Canceled: finalEx = errBizCancel.newBuilder().withSide(clientSide) // the initial node sending rst, the original cancelPath is empty cancelPath = appendCancelPath("", svcName) + // stream timeout + case context.DeadlineExceeded: + finalEx = newTimeoutException(errStreamTimeout, clientSide, s.streamTimeout) + cancelPath = appendCancelPath("", svcName) default: - if tEx, ok := cErr.(*Exception); ok { + if tEx, ok := err.(*Exception); ok { // for cascading cancel case, we need to change the side from server to client finalEx = tEx.newBuilder().withSide(clientSide) cancelPath = appendCancelPath(tEx.cancelPath, svcName) } else { // ctx provided by other sources(e.g. gRPC handler has been canceled, cErr is gRPC error) - finalEx = errInternalCancel.newBuilder().withSide(clientSide).withCause(cErr) + finalEx = errInternalCancel.newBuilder().withSide(clientSide).withCause(err) // as upstream cascading path may have existed(e.g. 
gRPC service chains), using non-ttstream path
 			// as a unified placeholder enables quick identification of such scenarios
 			cancelPath = appendCancelPath("non-ttstream path", svcName)
@@ -159,6 +168,11 @@ func (s *clientStream) parseCtxErr(ctx context.Context) (finalEx *Exception, can
 	return
 }
 
+// parseCtxErr parses the information in ctx.Err and returns the ttstream Exception and cascading cancelPath
+func (s *clientStream) parseCtxErr(ctx context.Context) (finalEx *Exception, cancelPath string) {
+	return s.parseCancelErr(ctx.Err())
+}
+
 func (s *clientStream) close(exception error, sendRst bool, cancelPath string) {
 	if exception != nil {
 		// store exception before change clientStream state
@@ -204,6 +218,10 @@ func (s *clientStream) setMetaFrameHandler(metaHandler MetaFrameHandler) {
 	s.metaFrameHandler = metaHandler
}
 
+func (s *clientStream) setStreamTimeout(tm time.Duration) {
+	s.streamTimeout = tm
+}
+
 // === clientStream OnRead callback
 
 func (s *clientStream) onReadMetaFrame(fr *Frame) error {
diff --git a/pkg/remote/trans/ttstream/stream_server.go b/pkg/remote/trans/ttstream/stream_server.go
index be28ff49b5..92f6f28b53 100644
--- a/pkg/remote/trans/ttstream/stream_server.go
+++ b/pkg/remote/trans/ttstream/stream_server.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"fmt"
 	"sync/atomic"
+	"time"
 
 	"github.com/cloudwego/gopkg/protocol/thrift"
 	"github.com/cloudwego/gopkg/protocol/ttheader"
@@ -33,8 +34,9 @@ var _ ServerStreamMeta = (*serverStream)(nil)
 
 func newServerStream(ctx context.Context, writer streamWriter, smeta streamFrame) *serverStream {
 	s := newBasicStream(ctx, writer, smeta)
-	s.reader = newStreamReader()
-	return &serverStream{stream: s}
+	ss := &serverStream{stream: s}
+	s.reader = newStreamReaderWithCtxDoneCallback(ss.ctxDoneCallback)
+	return ss
 }
 
 // initial state: streamStateActive
@@ -51,6 +53,7 @@ type serverStream struct {
 	*stream
 	state      int32
 	cancelFunc cancelWithReason
+	timeout    time.Duration
 }
 
 func (s *serverStream) SetHeader(hd streaming.Header) error {
@@ -228,3 +231,26 @@ func (s *serverStream) closeTest(exception error, cancelPath string) error {
 	s.runCloseCallback(exception)
 	return nil
 }
+
+func (s *serverStream) ctxDoneCallback(ctx context.Context) {
+	finalEx := s.parseCtxErr(ctx)
+	s.close(finalEx)
+}
+
+func (s *serverStream) parseCtxErr(ctx context.Context) (finalEx *Exception) {
+	cErr := ctx.Err()
+	switch cErr {
+	// stream timeout
+	case context.DeadlineExceeded:
+		finalEx = newTimeoutException(errStreamTimeout, serverSide, s.timeout)
+	// other close-stream scenarios need no extra handling here
+	default:
+		if ex, ok := cErr.(*Exception); ok {
+			finalEx = ex
+		} else {
+			finalEx = errInternalCancel.newBuilder().withSide(serverSide).withCause(cErr)
+		}
+	}
+
+	return finalEx
+}
diff --git a/pkg/remote/trans/ttstream/stream_test.go b/pkg/remote/trans/ttstream/stream_test.go
index bbb0d2983f..79252ac6e4 100644
--- a/pkg/remote/trans/ttstream/stream_test.go
+++ b/pkg/remote/trans/ttstream/stream_test.go
@@ -20,9 +20,7 @@ package ttstream
 
 import (
 	"context"
-	"strings"
 	"testing"
-	"time"
 
 	"github.com/cloudwego/kitex/internal/test"
 )
@@ -57,26 +55,6 @@ func TestGenericStreaming(t *testing.T) {
 	// test.Assert(t, res.B == req.B)
 }
 
-// TestStreamRecvTimeout tests that RecvMsg correctly handles timeout scenarios
-func TestStreamRecvTimeout(t *testing.T) {
-	_, ss, err := newTestStreamPipe(testServiceInfo, "Bidi")
-	test.Assert(t, err == nil, err)
-
-	// Set a very short timeout for testing
-	ss.setRecvTimeout(time.Millisecond * 10)
-
-	// Create a context that won't expire
-	
ctx := context.Background() - - // Try to receive message - should timeout quickly - res := new(testResponse) - err = ss.RecvMsg(ctx, res) - test.Assert(t, err != nil, "RecvMsg should timeout") - test.Assert(t, strings.Contains(err.Error(), "timeout") || - strings.Contains(err.Error(), "deadline exceeded"), - "Error should be timeout related") -} - // TestStreamRecvWithCancellation tests that RecvMsg respects context cancellation func TestStreamRecvWithCancellation(t *testing.T) { cs, ss, err := newTestStreamPipe(testServiceInfo, "Bidi") diff --git a/pkg/remote/trans/ttstream/transport_server.go b/pkg/remote/trans/ttstream/transport_server.go index 4a67dd33bf..c9431b912b 100644 --- a/pkg/remote/trans/ttstream/transport_server.go +++ b/pkg/remote/trans/ttstream/transport_server.go @@ -21,12 +21,16 @@ import ( "errors" "io" "net" + "strconv" "sync" "sync/atomic" + "time" "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/netpoll" + "github.com/cloudwego/kitex/internal/utils/contextwatcher" "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/remote/trans/ttstream/container" @@ -152,11 +156,20 @@ func (t *serverTransport) readFrame(reader bufiox.Reader) error { var s *serverStream if fr.typ == headerFrameType { + var ctx context.Context + var cancel context.CancelFunc // server recv a header frame, we should create a new stream - ctx, cancel := context.WithCancel(context.Background()) + tm, ok := parseStreamTimeout(fr) + if ok { + ctx, cancel = context.WithTimeout(context.Background(), tm) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } ctx, cFunc := newContextWithCancelReason(ctx, cancel) s = newServerStream(ctx, t, fr.streamFrame) s.cancelFunc = cFunc + s.timeout = tm + contextwatcher.RegisterContext(s.ctx, s.ctxDoneCallback) t.storeStream(s) err = t.spipe.Write(context.Background(), s) } else { @@ -261,3 +274,16 @@ READ: t.scache = t.scache[:n] goto READ } + +func parseStreamTimeout(fr *Frame) (time.Duration, bool) { + tmStr, ok := fr.meta[ttheader.StreamTimeout] + if !ok { + return 0, false + } + tmMs, err := strconv.Atoi(tmStr) + if err != nil { + klog.Errorf("KITEX: ttstream parse ttheader IntKey StreamTimeout failed, got: %s", tmStr) + return 0, false + } + return time.Duration(tmMs) * time.Millisecond, true +} diff --git a/pkg/remote/trans/ttstream/transport_server_test.go b/pkg/remote/trans/ttstream/transport_server_test.go new file mode 100644 index 0000000000..cefa6fec6d --- /dev/null +++ b/pkg/remote/trans/ttstream/transport_server_test.go @@ -0,0 +1,65 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ttstream + +import ( + "testing" + "time" + + "github.com/cloudwego/gopkg/protocol/ttheader" + + "github.com/cloudwego/kitex/internal/test" +) + +func Test_parseStreamTimeout(t *testing.T) { + // frame without Stream Timeout + fr := &Frame{ + streamFrame: streamFrame{ + meta: IntHeader{}, + }, + typ: headerFrameType, + } + tm, ok := parseStreamTimeout(fr) + test.Assert(t, !ok) + test.Assert(t, tm == 0) + + // frame with valid Stream Timeout + fr = &Frame{ + streamFrame: streamFrame{ + meta: IntHeader{ + ttheader.StreamTimeout: "1000", + }, + }, + typ: headerFrameType, + } + tm, ok = parseStreamTimeout(fr) + test.Assert(t, ok) + test.Assert(t, tm == 1*time.Second, tm) + + // frame with invalid Stream Timeout + fr = &Frame{ + streamFrame: streamFrame{ + meta: IntHeader{ + ttheader.StreamTimeout: "invalid", + }, + }, + typ: headerFrameType, + } + tm, ok = parseStreamTimeout(fr) + test.Assert(t, !ok) + test.Assert(t, tm == 0) +} diff --git a/pkg/remote/trans/ttstream/transport_timeout_test.go b/pkg/remote/trans/ttstream/transport_timeout_test.go new file mode 100644 index 0000000000..cdb0f14078 --- /dev/null +++ b/pkg/remote/trans/ttstream/transport_timeout_test.go @@ -0,0 +1,738 @@ +//go:build !windows + +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ttstream + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/kerrors" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/streaming" +) + +func TestStreamTimeout(t *testing.T) { + cliSvcName := "clientSideService" + t.Run("ServerStreaming", func(t *testing.T) { + method := "ServerStreaming" + t.Run("configure timeout and finish in time", func(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 200 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + req := new(testRequest) + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + test.Assert(t, cErr == nil, cErr) + + cErr = cs.CloseSend(ctx) + test.Assert(t, cErr == nil, cErr) + + for { + res := new(testResponse) + cErr = cs.RecvMsg(ctx, res) + if cErr == io.EOF { + break + } + test.Assert(t, cErr == nil, cErr) + } + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + test.Assert(t, sErr == nil, sErr) + + for i := 0; i < 10; i++ { + res := new(testResponse) + res.B = req.B + sErr = ss.SendMsg(ss.ctx, res) + test.Assert(t, sErr == nil, sErr) + } + err = ss.CloseSend(nil) + test.Assert(t, err == nil, err) + wg.Wait() + }) + t.Run("configure timeout and server-side continue sending", func(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 50 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + 
// client
+			go func() {
+				defer wg.Done()
+				req := new(testRequest)
+				req.B = "hello"
+				cErr := cs.SendMsg(ctx, req)
+				test.Assert(t, cErr == nil, cErr)
+
+				cErr = cs.CloseSend(ctx)
+				test.Assert(t, cErr == nil, cErr)
+
+				// Try to receive, should timeout eventually
+				for {
+					res := new(testResponse)
+					cErr = cs.RecvMsg(ctx, res)
+					if cErr != nil {
+						t.Logf("client-side Stream Recv err: %v", cErr)
+						test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr)
+						break
+					}
+				}
+			}()
+
+			// server
+			checkServerSideStreamTimeout(t, ss, start, tm)
+
+			req := new(testRequest)
+			sErr := ss.RecvMsg(ss.ctx, req)
+			test.Assert(t, sErr == nil, sErr)
+
+			// Server keeps sending slowly, causing timeout
+			for {
+				res := new(testResponse)
+				res.B = req.B
+				sErr = ss.SendMsg(ss.ctx, res)
+				if sErr != nil {
+					t.Logf("server-side Stream Send err: %v", sErr)
+					test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr)
+					break
+				}
+				time.Sleep(30 * time.Millisecond)
+			}
+			wg.Wait()
+		})
+		t.Run("configure timeout and server-side spends a long time sending the first resp", func(t *testing.T) {
+			cfd, sfd := netpoll.GetSysFdPairs()
+			cconn, err := netpoll.NewFDConnection(cfd)
+			test.Assert(t, err == nil, err)
+			sconn, err := netpoll.NewFDConnection(sfd)
+			test.Assert(t, err == nil, err)
+
+			intHeader := make(IntHeader)
+			strHeader := make(streaming.Header)
+			ctrans := newClientTransport(cconn, nil)
+			defer func() {
+				ctrans.Close(nil)
+				ctrans.WaitClosed()
+			}()
+			start := time.Now()
+			tm := 50 * time.Millisecond
+			ctx, cancel := context.WithTimeout(context.Background(), tm)
+			defer cancel()
+			injectStreamTimeout(ctx, intHeader)
+			cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method})
+			cs.rpcInfo = rpcinfo.NewRPCInfo(
+				rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil)
+			cs.setStreamTimeout(tm)
+			err = ctrans.WriteStream(ctx, cs, intHeader, strHeader)
+			test.Assert(t, err == nil, err)
+			strans := newServerTransport(sconn)
+			defer func() {
+				strans.Close(nil)
+				strans.WaitClosed()
+			}()
+			ss, err := strans.ReadStream(context.Background())
+			test.Assert(t, err == nil, err)
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			// client
+			go func() {
+				defer wg.Done()
+				req := new(testRequest)
+				req.B = "hello"
+				cErr := cs.SendMsg(ctx, req)
+				test.Assert(t, cErr == nil, cErr)
+
+				cErr = cs.CloseSend(ctx)
+				test.Assert(t, cErr == nil, cErr)
+
+				// Try to receive first response, should timeout
+				res := new(testResponse)
+				cErr = cs.RecvMsg(ctx, res)
+				test.Assert(t, cErr != nil)
+				t.Logf("client-side Stream Recv err: %v", cErr)
+				test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr)
+			}()
+
+			// server
+			checkServerSideStreamTimeout(t, ss, start, tm)
+
+			req := new(testRequest)
+			sErr := ss.RecvMsg(ss.ctx, req)
+			test.Assert(t, sErr == nil, sErr)
+
+			// Server delays before sending first response
+			time.Sleep(150 * time.Millisecond)
+
+			res := new(testResponse)
+			res.B = req.B
+			sErr = ss.SendMsg(ss.ctx, res)
+			// Send will fail because the client already timed out
+			test.Assert(t, sErr != nil)
+			t.Logf("server-side Stream Send err: %v", sErr)
+			test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr)
+			wg.Wait()
+		})
+	})
+	t.Run("ClientStreaming", func(t *testing.T) {
+		method := "ClientStreaming"
+		t.Run("configure timeout and finish in time", func(t *testing.T) {
+			cfd, sfd := netpoll.GetSysFdPairs()
+			cconn, err := netpoll.NewFDConnection(cfd)
+			test.Assert(t, err == nil, err)
+			sconn, err := netpoll.NewFDConnection(sfd)
+			test.Assert(t, err 
== nil, err) + + intHeader := make(IntHeader) + strHeader := make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 200 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + // Send 10 messages quickly + for i := 0; i < 10; i++ { + req := new(testRequest) + req.A = int32(i) + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + test.Assert(t, cErr == nil, cErr) + } + + cErr := cs.CloseSend(ctx) + test.Assert(t, cErr == nil, cErr) + + // Receive final response + res := new(testResponse) + cErr = cs.RecvMsg(ctx, res) + test.Assert(t, cErr == nil, cErr) + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + for { + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + if sErr == io.EOF { + break + } + test.Assert(t, sErr == nil, sErr) + } + + // Send final response + res := new(testResponse) + res.B = "done" + sErr := ss.SendMsg(ss.ctx, res) + test.Assert(t, sErr == nil, sErr) + + err = ss.CloseSend(nil) + test.Assert(t, err == nil, err) + wg.Wait() + }) + t.Run("configure timeout and client-side continue sending", func(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 50 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + req := new(testRequest) + req.A = int32(i) + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + if cErr != nil { + // Send will fail after timeout + t.Logf("client-side Stream Send err: %v", cErr) + test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr) + break + } + time.Sleep(30 * time.Millisecond) // Slow sending + } + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + // Try to receive, should timeout (cascaded from client) + for { + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + if sErr != 
nil {
+					t.Logf("server-side Stream Recv err: %v", sErr)
+					test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr)
+					break
+				}
+			}
+			wg.Wait()
+		})
+		t.Run("configure timeout and server-side spends a long time returning the resp", func(t *testing.T) {
+			cfd, sfd := netpoll.GetSysFdPairs()
+			cconn, err := netpoll.NewFDConnection(cfd)
+			test.Assert(t, err == nil, err)
+			sconn, err := netpoll.NewFDConnection(sfd)
+			test.Assert(t, err == nil, err)
+
+			intHeader := make(IntHeader)
+			strHeader := make(streaming.Header)
+			ctrans := newClientTransport(cconn, nil)
+			defer func() {
+				ctrans.Close(nil)
+				ctrans.WaitClosed()
+			}()
+			start := time.Now()
+			tm := 50 * time.Millisecond
+			ctx, cancel := context.WithTimeout(context.Background(), tm)
+			defer cancel()
+			injectStreamTimeout(ctx, intHeader)
+			cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method})
+			cs.rpcInfo = rpcinfo.NewRPCInfo(
+				rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil)
+			cs.setStreamTimeout(tm)
+			err = ctrans.WriteStream(ctx, cs, intHeader, strHeader)
+			test.Assert(t, err == nil, err)
+			strans := newServerTransport(sconn)
+			defer func() {
+				strans.Close(nil)
+				strans.WaitClosed()
+			}()
+			ss, err := strans.ReadStream(context.Background())
+			test.Assert(t, err == nil, err)
+
+			var wg sync.WaitGroup
+			wg.Add(1)
+			// client
+			go func() {
+				defer wg.Done()
+				for i := 0; i < 10; i++ {
+					req := new(testRequest)
+					req.A = int32(i)
+					req.B = "hello"
+					cErr := cs.SendMsg(ctx, req)
+					test.Assert(t, cErr == nil, cErr)
+				}
+
+				cErr := cs.CloseSend(ctx)
+				test.Assert(t, cErr == nil, cErr)
+
+				// Wait for final response, should timeout
+				res := new(testResponse)
+				cErr = cs.RecvMsg(ctx, res)
+				test.Assert(t, cErr != nil)
+				t.Logf("client-side Stream Recv err: %v", cErr)
+				test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr)
+			}()
+
+			// server
+			checkServerSideStreamTimeout(t, ss, start, tm)
+
+			// Receive all client messages
+			for {
+				req := new(testRequest)
+				sErr := ss.RecvMsg(ss.ctx, req)
+				if sErr == io.EOF {
+					break
+				}
+				test.Assert(t, sErr == nil, sErr)
+			}
+
+			// Server delays a lot before sending final response
+			time.Sleep(100 * time.Millisecond)
+
+			res := new(testResponse)
+			res.B = "done"
+			sErr := ss.SendMsg(ss.ctx, res)
+			if sErr != nil {
+				// Send will fail because the client already timed out
+				t.Logf("server-side Stream Send err: %v", sErr)
+				test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr)
+			}
+			wg.Wait()
+		})
+	})
+	t.Run("BidiStreaming", func(t *testing.T) {
+		method := "BidiStreaming"
+		t.Run("configure timeout and finish in time", func(t *testing.T) {
+			cfd, sfd := netpoll.GetSysFdPairs()
+			cconn, err := netpoll.NewFDConnection(cfd)
+			test.Assert(t, err == nil, err)
+			sconn, err := netpoll.NewFDConnection(sfd)
+			test.Assert(t, err == nil, err)
+
+			intHeader := make(IntHeader)
+			strHeader := make(streaming.Header)
+			ctrans := newClientTransport(cconn, nil)
+			defer func() {
+				ctrans.Close(nil)
+				ctrans.WaitClosed()
+			}()
+			start := time.Now()
+			tm := 200 * time.Millisecond
+			ctx, cancel := context.WithTimeout(context.Background(), tm)
+			defer cancel()
+			injectStreamTimeout(ctx, intHeader)
+			cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method})
+			cs.rpcInfo = rpcinfo.NewRPCInfo(
+				rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil)
+			cs.setStreamTimeout(tm)
+			err = ctrans.WriteStream(ctx, cs, intHeader, strHeader)
+			test.Assert(t, err == nil, err)
+			strans := newServerTransport(sconn)
+			
defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + // Send and receive 10 messages + for i := 0; i < 10; i++ { + req := new(testRequest) + req.A = int32(i) + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + test.Assert(t, cErr == nil, cErr) + + res := new(testResponse) + cErr = cs.RecvMsg(ctx, res) + test.Assert(t, cErr == nil, cErr) + test.DeepEqual(t, req.A, res.A) + } + + cErr := cs.CloseSend(ctx) + test.Assert(t, cErr == nil, cErr) + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + for i := 0; i < 10; i++ { + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + test.Assert(t, sErr == nil, sErr) + + res := new(testResponse) + res.A = req.A + res.B = req.B + sErr = ss.SendMsg(ss.ctx, res) + test.Assert(t, sErr == nil, sErr) + } + + err = ss.CloseSend(nil) + test.Assert(t, err == nil, err) + wg.Wait() + }) + t.Run("configure timeout and server-side continue sending", func(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 50 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + // Send one request + req := new(testRequest) + req.A = 1 + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + test.Assert(t, cErr == nil, cErr) + + // Try to receive, should timeout + for { + res := new(testResponse) + cErr = cs.RecvMsg(ctx, res) + if cErr != nil { + t.Logf("client-side Stream Recv err: %v", cErr) + test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr) + break + } + } + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + test.Assert(t, sErr == nil, sErr) + + // Server keeps sending slowly + for i := 0; i < 10; i++ { + res := new(testResponse) + res.A = int32(i) + res.B = req.B + sErr = ss.SendMsg(ss.ctx, res) + if sErr != nil { + // Send will fail after client timeout + t.Logf("server-side Stream Send err: %v", sErr) + test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr) + break + } + time.Sleep(30 * time.Millisecond) + } + wg.Wait() + }) + t.Run("configure timeout and client-side continue sending", func(t *testing.T) { + cfd, sfd := netpoll.GetSysFdPairs() + cconn, err := netpoll.NewFDConnection(cfd) + test.Assert(t, err == nil, err) + sconn, err := netpoll.NewFDConnection(sfd) + test.Assert(t, err == nil, err) + + intHeader := make(IntHeader) + strHeader := 
make(streaming.Header) + ctrans := newClientTransport(cconn, nil) + defer func() { + ctrans.Close(nil) + ctrans.WaitClosed() + }() + start := time.Now() + tm := 50 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), tm) + defer cancel() + injectStreamTimeout(ctx, intHeader) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs.rpcInfo = rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo(cliSvcName, method, nil, nil), nil, nil, nil, nil) + cs.setStreamTimeout(tm) + err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) + test.Assert(t, err == nil, err) + strans := newServerTransport(sconn) + defer func() { + strans.Close(nil) + strans.WaitClosed() + }() + ss, err := strans.ReadStream(context.Background()) + test.Assert(t, err == nil, err) + + var wg sync.WaitGroup + wg.Add(1) + // client + go func() { + defer wg.Done() + // Client keeps sending slowly + for i := 0; i < 10; i++ { + req := new(testRequest) + req.A = int32(i) + req.B = "hello" + cErr := cs.SendMsg(ctx, req) + if cErr != nil { + // Send will fail after timeout + t.Logf("client-side Stream Send err: %v", cErr) + test.Assert(t, errors.Is(cErr, kerrors.ErrStreamingTimeout), cErr) + break + } + time.Sleep(30 * time.Millisecond) + } + }() + + // server + checkServerSideStreamTimeout(t, ss, start, tm) + + // Try to receive, should timeout (cascaded from client) + for { + req := new(testRequest) + sErr := ss.RecvMsg(ss.ctx, req) + if sErr != nil { + t.Logf("server-side Stream Recv err: %v", sErr) + test.Assert(t, errors.Is(sErr, kerrors.ErrStreamingTimeout), sErr) + break + } + // Echo back if no timeout yet + res := new(testResponse) + res.A = req.A + res.B = req.B + ss.SendMsg(ss.ctx, res) + } + wg.Wait() + }) + }) +} + +func checkServerSideStreamTimeout(t *testing.T, ss *serverStream, start time.Time, tm time.Duration) { + test.Assert(t, ss.ctx.Done() != nil, ss.ctx) + ddl, ok := ss.ctx.Deadline() + test.Assert(t, ok, ss.ctx) + test.Assert(t, ddl.Sub(start) <= tm, start, ddl) +} diff --git a/pkg/rpcinfo/interface.go b/pkg/rpcinfo/interface.go index 4438d4fb13..5e54827be7 100644 --- a/pkg/rpcinfo/interface.go +++ b/pkg/rpcinfo/interface.go @@ -75,6 +75,8 @@ type TimeoutProvider interface { type StreamConfig interface { StreamRecvTimeout() time.Duration + StreamSendTimeout() time.Duration + StreamTimeout() time.Duration } // RPCConfig contains configuration for RPC. diff --git a/pkg/rpcinfo/mocks_test.go b/pkg/rpcinfo/mocks_test.go index c0033bbd7e..3416f29906 100644 --- a/pkg/rpcinfo/mocks_test.go +++ b/pkg/rpcinfo/mocks_test.go @@ -94,6 +94,14 @@ func (m *MockRPCConfig) StreamRecvTimeout() time.Duration { return time.Duration(0) } +func (m *MockRPCConfig) StreamSendTimeout() time.Duration { + return time.Duration(0) +} + +func (m *MockRPCConfig) StreamTimeout() time.Duration { + return time.Duration(0) +} + type MockRPCStats struct{} func (m *MockRPCStats) Record(context.Context, stats.Event, stats.Status, string) {} diff --git a/pkg/rpcinfo/mutable.go b/pkg/rpcinfo/mutable.go index ee456fe50d..ff034c029c 100644 --- a/pkg/rpcinfo/mutable.go +++ b/pkg/rpcinfo/mutable.go @@ -54,6 +54,8 @@ type MutableRPCConfig interface { SetPayloadCodec(codec serviceinfo.PayloadCodec) SetStreamRecvTimeout(timeout time.Duration) + SetStreamSendTimeout(timeout time.Duration) + SetStreamTimeout(timeout time.Duration) } // MutableRPCStats is used to change the information in the RPCStats. 
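The StreamConfig getters and the MutableRPCConfig setters added above pair up one-to-one; rpcconfig.go below carries the implementation. A minimal sketch of setting and reading the three independent knobs through rpcinfo, mirroring the TestStreamConfig case further down (the main wrapper is illustration only):

package main

import (
	"fmt"
	"time"

	"github.com/cloudwego/kitex/pkg/rpcinfo"
)

func main() {
	cfg := rpcinfo.NewRPCConfig()
	mcfg := rpcinfo.AsMutableRPCConfig(cfg)

	// One deadline for the whole stream, plus per-operation Recv/Send timeouts.
	mcfg.SetStreamTimeout(time.Second)
	mcfg.SetStreamRecvTimeout(200 * time.Millisecond)
	mcfg.SetStreamSendTimeout(300 * time.Millisecond)

	fmt.Println(cfg.StreamTimeout(), cfg.StreamRecvTimeout(), cfg.StreamSendTimeout())
}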
diff --git a/pkg/rpcinfo/rpcconfig.go b/pkg/rpcinfo/rpcconfig.go
index 4159e8dae2..5a11e2ae4c 100644
--- a/pkg/rpcinfo/rpcconfig.go
+++ b/pkg/rpcinfo/rpcconfig.go
@@ -69,6 +69,8 @@ type rpcConfig struct {
 
 	// stream config
 	streamRecvTimeout time.Duration
+	streamSendTimeout time.Duration
+	streamTimeout     time.Duration
 }
 
 func init() {
@@ -204,6 +206,22 @@ func (r *rpcConfig) StreamRecvTimeout() time.Duration {
 	return r.streamRecvTimeout
 }
 
+func (r *rpcConfig) SetStreamSendTimeout(timeout time.Duration) {
+	r.streamSendTimeout = timeout
+}
+
+func (r *rpcConfig) StreamSendTimeout() time.Duration {
+	return r.streamSendTimeout
+}
+
+func (r *rpcConfig) SetStreamTimeout(timeout time.Duration) {
+	r.streamTimeout = timeout
+}
+
+func (r *rpcConfig) StreamTimeout() time.Duration {
+	return r.streamTimeout
+}
+
 // Clone returns a copy of the current rpcConfig.
 func (r *rpcConfig) Clone() MutableRPCConfig {
 	r2 := rpcConfigPool.Get().(*rpcConfig)
diff --git a/pkg/rpcinfo/rpcconfig_test.go b/pkg/rpcinfo/rpcconfig_test.go
index 622691d740..2fdb071c07 100644
--- a/pkg/rpcinfo/rpcconfig_test.go
+++ b/pkg/rpcinfo/rpcconfig_test.go
@@ -18,6 +18,7 @@ package rpcinfo_test
 
 import (
 	"testing"
+	"time"
 
 	"github.com/cloudwego/kitex/internal/test"
 	"github.com/cloudwego/kitex/pkg/rpcinfo"
@@ -32,6 +33,9 @@ func TestRPCConfig(t *testing.T) {
 	test.Assert(t, c.ReadWriteTimeout() != 0)
 	test.Assert(t, c.IOBufferSize() != 0)
 	test.Assert(t, c.TransportProtocol() == transport.PurePayload)
+	test.Assert(t, c.StreamTimeout() == 0)
+	test.Assert(t, c.StreamRecvTimeout() == 0)
+	test.Assert(t, c.StreamSendTimeout() == 0)
 }
 
 func TestSetTransportProtocol(t *testing.T) {
@@ -139,3 +143,17 @@ func TestSetTransportProtocol(t *testing.T) {
 		test.Assert(t, (c.TransportProtocol()&transport.TTHeader == transport.TTHeader) && (c.TransportProtocol()&transport.GRPC == transport.GRPC), c.TransportProtocol())
 	})
 }
+
+func TestStreamConfig(t *testing.T) {
+	cfg := rpcinfo.NewRPCConfig()
+	c := rpcinfo.AsMutableRPCConfig(cfg)
+	stTm := 1 * time.Second
+	recvTm := 2 * time.Second
+	sendTm := 3 * time.Second
+	c.SetStreamTimeout(stTm)
+	c.SetStreamRecvTimeout(recvTm)
+	c.SetStreamSendTimeout(sendTm)
+	test.Assert(t, cfg.StreamTimeout() == stTm, cfg)
+	test.Assert(t, cfg.StreamRecvTimeout() == recvTm, cfg)
+	test.Assert(t, cfg.StreamSendTimeout() == sendTm, cfg)
+}
diff --git a/pkg/streaming/streamx.go b/pkg/streaming/streamx.go
index 42f6ed6aff..eb3f1d8e1b 100644
--- a/pkg/streaming/streamx.go
+++ b/pkg/streaming/streamx.go
@@ -97,6 +97,9 @@ type ClientStream interface {
 	CloseSend(ctx context.Context) error
 	// Context the stream context.Context
 	Context() context.Context
+	// Cancel immediately terminates the entire lifecycle of the Stream and notifies the peer of the cancellation error.
+	// The passed-in err is also used for logging and reporting.
+	Cancel(err error)
 }
 
 // ServerStream define server stream APIs
diff --git a/pkg/streaming/types/doc.go b/pkg/streaming/types/doc.go
new file mode 100644
index 0000000000..da341a8d96
--- /dev/null
+++ b/pkg/streaming/types/doc.go
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Package types provides common type definitions for streaming to avoid circular dependencies.
+
+# Background
+
+This package was created to resolve circular dependency issues between packages.
+The circular dependency path was:
+  - pkg/streaming imports pkg/kerrors
+  - pkg/kerrors imports pkg/remote/trans/nphttp2/status
+  - pkg/remote/trans/nphttp2/status needs to use streaming-related types
+
+If common streaming types (such as TimeoutType) were defined in pkg/streaming,
+it would create a circular dependency: pkg/streaming ↔ pkg/remote/trans/nphttp2/status.
+
+# Solution
+
+By extracting shared streaming type definitions into pkg/streaming/types:
+  - pkg/streaming can import pkg/streaming/types
+  - pkg/remote/trans/* packages can import pkg/streaming/types
+  - The circular dependency is broken
+
+# Convention
+
+All future common streaming type definitions should be placed in this package
+to maintain clean dependency relationships and avoid circular dependencies.
+*/
+package types
diff --git a/pkg/streaming/types/timeout.go b/pkg/streaming/types/timeout.go
new file mode 100644
index 0000000000..df32f4758d
--- /dev/null
+++ b/pkg/streaming/types/timeout.go
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+// TimeoutType identifies a specific timeout type: Stream, Recv, or Send timeout.
+// Both TTHeader Streaming and gRPC Streaming support these timeout types.
+type TimeoutType uint8
+
+const (
+	StreamTimeout TimeoutType = iota + 1
+	StreamRecvTimeout
+	StreamSendTimeout
+)
+
+// IsStreamingTimeout reports whether tmType is a pre-defined streaming TimeoutType.
+func IsStreamingTimeout(tmType TimeoutType) bool {
+	if tmType < StreamTimeout || tmType > StreamSendTimeout {
+		return false
+	}
+	return true
+}
diff --git a/pkg/streaming/types/timeout_test.go b/pkg/streaming/types/timeout_test.go
new file mode 100644
index 0000000000..7cc4823bcb
--- /dev/null
+++ b/pkg/streaming/types/timeout_test.go
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2025 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package types
+
+import (
+	"math"
+	"strconv"
+	"testing"
+
+	"github.com/cloudwego/kitex/internal/test"
+)
+
+func TestIsStreamingTimeout(t *testing.T) {
+	testcases := []struct {
+		tmType    TimeoutType
+		expectRes bool
+	}{
+		{
+			tmType:    StreamTimeout,
+			expectRes: true,
+		},
+		{
+			tmType:    StreamRecvTimeout,
+			expectRes: true,
+		},
+		{
+			tmType:    StreamSendTimeout,
+			expectRes: true,
+		},
+		{
+			tmType:    TimeoutType(0),
+			expectRes: false,
+		},
+		{
+			tmType:    TimeoutType(math.MaxUint8),
+			expectRes: false,
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(strconv.Itoa(int(tc.tmType)), func(t *testing.T) {
+			res := IsStreamingTimeout(tc.tmType)
+			test.Assert(t, res == tc.expectRes, res)
+		})
+	}
+}
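For downstream consumers, pkg/streaming/types is importable on its own, which is the point of the extraction. A minimal sketch of how a caller might branch on the shared TimeoutType (describeTimeout is a hypothetical helper, not part of this patch):

package main

import (
	"fmt"

	"github.com/cloudwego/kitex/pkg/streaming/types"
)

// describeTimeout is a hypothetical helper showing how a caller can guard with
// IsStreamingTimeout before switching on the concrete timeout type.
func describeTimeout(tm types.TimeoutType) string {
	if !types.IsStreamingTimeout(tm) {
		return "not a streaming timeout"
	}
	switch tm {
	case types.StreamTimeout:
		return "whole-stream timeout"
	case types.StreamRecvTimeout:
		return "Recv timeout"
	case types.StreamSendTimeout:
		return "Send timeout"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(describeTimeout(types.StreamTimeout))  // whole-stream timeout
	fmt.Println(describeTimeout(types.TimeoutType(0))) // not a streaming timeout
}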