Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/remote/trans/ttstream/stream_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cloudwego/gopkg/protocol/ttheader"

"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/streaming"
"github.com/cloudwego/kitex/pkg/transmeta"
)
Expand Down Expand Up @@ -246,6 +247,9 @@ func (s *clientStream) onReadTrailerFrame(fr *Frame) error {
// todo: unify bizErr with Exception
// bizErr is independent of rpc exception handling
exception = bizErr
if setter, ok := s.rpcInfo.Invocation().(rpcinfo.InvocationSetter); ok {
setter.SetBizStatusErr(bizErr)
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/remote/trans/ttstream/stream_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ func (s *serverStream) SendMsg(ctx context.Context, res any) error {
// after CloseSend stream cannot be access again
func (s *serverStream) CloseSend(exception error) error {
s.close(errBizHandlerReturnCancel)
if s.wheader != nil {
if err := s.sendHeader(); err != nil {
return err
}
}
return s.sendTrailer(exception)
}

Expand Down
149 changes: 148 additions & 1 deletion pkg/remote/trans/ttstream/stream_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,17 @@ import (
"time"

"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/streaming"
)

func newTestServerStream() *serverStream {
return newTestServerStreamWithStreamWriter(mockStreamWriter{})
}

func newTestServerStreamWithStreamWriter(w streamWriter) *serverStream {
ctx, cancel := context.WithCancel(context.Background())
ctx, cancelFunc := newContextWithCancelReason(ctx, cancel)
srvSt := newServerStream(ctx, mockStreamWriter{}, streamFrame{})
srvSt := newServerStream(ctx, w, streamFrame{})
srvSt.cancelFunc = cancelFunc
return srvSt
}
Expand Down Expand Up @@ -90,3 +95,145 @@ func Test_serverStreamStateChange(t *testing.T) {
wg.Wait()
})
}

func Test_serverStreamReimburseHeaderFrame(t *testing.T) {
t.Run("CloseSend", func(t *testing.T) {
var frameNum int
srvSt := newTestServerStreamWithStreamWriter(mockStreamWriter{
writeFrameFunc: func(f *Frame) error {
switch f.typ {
case headerFrameType:
// first frame
test.Assert(t, frameNum == 0, f)
frameNum++
case trailerFrameType:
// second frame
test.Assert(t, frameNum == 1, f)
frameNum++
default:
t.Fatalf("should not send other frame, frame: %+v, frameNum: %d", f, frameNum)
}
return nil
},
})
err := srvSt.CloseSend(nil)
test.Assert(t, err == nil, err)
test.Assert(t, frameNum == 2, frameNum)
})
t.Run("Send -> CloseSend", func(t *testing.T) {
var frameNum int
srvSt := newTestServerStreamWithStreamWriter(mockStreamWriter{
writeFrameFunc: func(f *Frame) error {
switch f.typ {
case headerFrameType:
// first frame
test.Assert(t, frameNum == 0, f)
frameNum++
case dataFrameType:
// second frame
test.Assert(t, frameNum == 1, f)
frameNum++
case trailerFrameType:
// second frame
test.Assert(t, frameNum == 2, f)
frameNum++
default:
t.Fatalf("should not send other frame, frame: %+v, frameNum: %d", f, frameNum)
}
return nil
},
})
err := srvSt.SendMsg(context.Background(), new(testResponse))
test.Assert(t, err == nil, err)
err = srvSt.CloseSend(nil)
test.Assert(t, err == nil, err)
test.Assert(t, frameNum == 3, frameNum)
})
t.Run("Send -> Send -> CloseSend", func(t *testing.T) {
var frameNum int
srvSt := newTestServerStreamWithStreamWriter(mockStreamWriter{
writeFrameFunc: func(f *Frame) error {
switch f.typ {
case headerFrameType:
// first frame
test.Assert(t, frameNum == 0, f)
frameNum++
case dataFrameType:
// second frame
test.Assert(t, frameNum == 1 || frameNum == 2, f)
frameNum++
case trailerFrameType:
// second frame
test.Assert(t, frameNum == 3, f)
frameNum++
default:
t.Fatalf("should not send other frame, frame: %+v, frameNum: %d", f, frameNum)
}
return nil
},
})
err := srvSt.SendMsg(context.Background(), new(testResponse))
test.Assert(t, err == nil, err)
err = srvSt.SendMsg(context.Background(), new(testResponse))
test.Assert(t, err == nil, err)
err = srvSt.CloseSend(nil)
test.Assert(t, err == nil, err)
test.Assert(t, frameNum == 4, frameNum)
})
t.Run("SendHeader -> CloseSend", func(t *testing.T) {
var frameNum int
srvSt := newTestServerStreamWithStreamWriter(mockStreamWriter{
writeFrameFunc: func(f *Frame) error {
switch f.typ {
case headerFrameType:
// first frame
test.Assert(t, frameNum == 0, f)
frameNum++
case trailerFrameType:
// second frame
test.Assert(t, frameNum == 1, f)
frameNum++
default:
t.Fatalf("should not send other frame, frame: %+v, frameNum: %d", f, frameNum)
}
return nil
},
})
err := srvSt.SendHeader(streaming.Header{"testKey": "testVal"})
test.Assert(t, err == nil)
err = srvSt.CloseSend(nil)
test.Assert(t, err == nil, err)
test.Assert(t, frameNum == 2, frameNum)
})
t.Run("SendHeader -> Send -> CloseSend", func(t *testing.T) {
var frameNum int
srvSt := newTestServerStreamWithStreamWriter(mockStreamWriter{
writeFrameFunc: func(f *Frame) error {
switch f.typ {
case headerFrameType:
// first frame
test.Assert(t, frameNum == 0, f)
frameNum++
case dataFrameType:
// second frame
test.Assert(t, frameNum == 1, f)
frameNum++
case trailerFrameType:
// third frame
test.Assert(t, frameNum == 2, f)
frameNum++
default:
t.Fatalf("should not send other frame, frame: %+v, frameNum: %d", f, frameNum)
}
return nil
},
})
err := srvSt.SendHeader(streaming.Header{"testKey": "testVal"})
test.Assert(t, err == nil)
err = srvSt.SendMsg(context.Background(), new(testResponse))
test.Assert(t, err == nil, err)
err = srvSt.CloseSend(nil)
test.Assert(t, err == nil, err)
test.Assert(t, frameNum == 3, frameNum)
})
}
11 changes: 10 additions & 1 deletion pkg/remote/trans/ttstream/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,22 @@ import (
"github.com/cloudwego/kitex/pkg/streaming"
)

type mockStreamWriter struct{}
type mockStreamWriter struct {
writeFrameFunc func(f *Frame) error
closeStreamFunc func(sid int32) error
}

func (m mockStreamWriter) WriteFrame(f *Frame) error {
if m.writeFrameFunc != nil {
return m.writeFrameFunc(f)
}
return nil
}

func (m mockStreamWriter) CloseStream(sid int32) error {
if m.closeStreamFunc != nil {
return m.closeStreamFunc(sid)
}
return nil
}

Expand Down