diff --git a/sonora/asgi.py b/sonora/asgi.py index ad1837c..57e7483 100644 --- a/sonora/asgi.py +++ b/sonora/asgi.py @@ -211,11 +211,11 @@ async def _do_unary_response( else: message_data = b"" - trailers = [(b"grpc-status", str(context.code.value[0]).encode())] + trailers = [("grpc-status", str(context.code.value[0]))] if context.details: trailers.append( - (b"grpc-message", quote(context.details.encode("utf8")).encode("ascii")) + ("grpc-message", quote(context.details.encode("utf8"))) ) if context._trailing_metadata: diff --git a/tests/conftest.py b/tests/conftest.py index 2f29874..d3d6767 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,6 +85,10 @@ def HelloStreamMetadata(self, request, context): class AsyncGreeter(helloworld_pb2_grpc.GreeterServicer): async def SayHello(self, request, context): + if request.HasField("response_status"): + context.set_code(request.response_status.code) + context.set_details(request.response_status.message) + return helloworld_pb2.HelloReply(message=FORMAT_STRING.format(request=request)) async def SayHelloSlowly(self, request, context): @@ -93,6 +97,10 @@ async def SayHelloSlowly(self, request, context): for char in message: yield helloworld_pb2.HelloReply(message=char) + if request.HasField("response_status"): + context.set_code(request.response_status.code) + context.set_details(request.response_status.message) + async def Abort(self, request, context): await context.abort(grpc.StatusCode.ABORTED, "test aborting") diff --git a/tests/protos/tests/helloworld.proto b/tests/protos/tests/helloworld.proto index b4eb348..aca1ca0 100644 --- a/tests/protos/tests/helloworld.proto +++ b/tests/protos/tests/helloworld.proto @@ -23,9 +23,19 @@ service Greeter { rpc StreamTimeout(TimeoutRequest) returns (stream google.protobuf.Empty) {} } +// A protobuf representation for grpc status. This is used by test +// clients to specify a status that the server should attempt to return. +message EchoStatus { + int32 code = 1; + string message = 2; +} + // The request message containing the user's name. message HelloRequest { string name = 1; + + // Status to set at the end of the RPC. + EchoStatus response_status = 2; } // The response message containing the greetings diff --git a/tests/test_asgi_helloworld.py b/tests/test_asgi_helloworld.py index 39e045b..58a1151 100644 --- a/tests/test_asgi_helloworld.py +++ b/tests/test_asgi_helloworld.py @@ -96,6 +96,71 @@ async def test_helloworld_unary_metadata_binary(asgi_greeter): assert dict(trailing_metadata)["trailing-metadata-key-bin"] == repr(b"\0\1\2\3") +@pytest.mark.asyncio +async def test_unary_trailing_status_no_message(asgi_greeter): + for name in ("you", "world"): + request = helloworld_pb2.HelloRequest(name=name) + call = asgi_greeter.SayHello(request) + _response = await call + + trailers = dict(await call.trailing_metadata()) + assert "grpc-status" in trailers + assert trailers["grpc-status"] == "0" + assert "grpc-message" not in trailers + + +@pytest.mark.asyncio +async def test_unary_trailing_status_message(asgi_greeter): + print(helloworld_pb2.__file__) + for name in ("you", "world"): + request = helloworld_pb2.HelloRequest( + response_status=helloworld_pb2.EchoStatus( + code=grpc.StatusCode.OK.value[0], + message="OK", + ) + ) + call = asgi_greeter.SayHello(request) + _response = await call + + trailers = dict(await call.trailing_metadata()) + assert "grpc-status" in trailers + assert trailers["grpc-status"] == "0" + assert trailers["grpc-message"] == "OK" + + +@pytest.mark.asyncio +async def test_streaming_trailing_status_no_message(asgi_greeter): + for name in ("you", "world"): + request = helloworld_pb2.HelloRequest(name=name) + call = asgi_greeter.SayHelloSlowly(request) + async for _response in call: + pass + + trailers = dict(await call.trailing_metadata()) + assert "grpc-status" in trailers + assert trailers["grpc-status"] == "0" + assert "grpc-message" not in trailers + + +@pytest.mark.asyncio +async def test_streaming_trailing_status_message(asgi_greeter): + for name in ("you", "world"): + request = helloworld_pb2.HelloRequest( + response_status=helloworld_pb2.EchoStatus( + code=grpc.StatusCode.OK.value[0], + message="OK", + ) + ) + call = asgi_greeter.SayHelloSlowly(request) + async for _response in call: + pass + + trailers = dict(await call.trailing_metadata()) + assert "grpc-status" in trailers + assert trailers["grpc-status"] == "0" + assert trailers["grpc-message"] == "OK" + + @pytest.mark.asyncio async def test_helloworld_stream_metadata_ascii(asgi_greeter): request = helloworld_pb2.HelloRequest(name="metadata-key")