From 75d27e96bec230066b378e545f007246414b69c1 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Tue, 2 Dec 2025 09:01:11 +0100 Subject: [PATCH 1/9] Add generic gRPC stream forwarding To be able to pass Bazel's build event stream (BES) to the same DNS name, without having to add an extra L7 router in front of the bb-storage frontend, add a configuration to forward specific gRPC methods to other backends. No authorization is possible on the passed through messages because Buildbarn has no knowledge about the semantics of the forwarded messages. The gRPC reflection service has also been extended to forward requests that cannot be resolved locally. --- MODULE.bazel | 1 + go.mod | 1 + go.sum | 2 + pkg/grpc/BUILD.bazel | 10 ++ pkg/grpc/reflection_relay.go | 83 +++++++++ pkg/grpc/routing_stream_forwarder.go | 36 ++++ pkg/grpc/server.go | 32 +++- pkg/grpc/server_transport_stream_context.go | 18 ++ pkg/grpc/simple_stream_forwarder.go | 88 +++++++++ pkg/proto/configuration/grpc/grpc.pb.go | 187 +++++++++++++------- pkg/proto/configuration/grpc/grpc.proto | 22 +++ 11 files changed, 417 insertions(+), 63 deletions(-) create mode 100644 pkg/grpc/reflection_relay.go create mode 100644 pkg/grpc/routing_stream_forwarder.go create mode 100644 pkg/grpc/server_transport_stream_context.go create mode 100644 pkg/grpc/simple_stream_forwarder.go diff --git a/MODULE.bazel b/MODULE.bazel index 5dfbb32b..594f7ca8 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -57,6 +57,7 @@ use_repo( "com_github_gorilla_mux", "com_github_grpc_ecosystem_go_grpc_middleware", "com_github_grpc_ecosystem_go_grpc_prometheus", + "com_github_jhump_protoreflect_v2", "com_github_jmespath_go_jmespath", "com_github_klauspost_compress", "com_github_lazybeaver_xorshift", diff --git a/go.mod b/go.mod index a56ac164..c2411c17 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/gorilla/mux v1.8.1 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 + github.com/jhump/protoreflect/v2 v2.0.0-beta.2 github.com/jmespath/go-jmespath v0.4.0 github.com/klauspost/compress v1.18.1 github.com/lazybeaver/xorshift v0.0.0-20170702203709-ce511d4823dd diff --git a/go.sum b/go.sum index a151f158..e48aaa30 100644 --- a/go.sum +++ b/go.sum @@ -198,6 +198,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92Bcuy github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/jhump/protoreflect/v2 v2.0.0-beta.2 h1:qZU+rEZUOYTz1Bnhi3xbwn+VxdXkLVeEpAeZzVXLY88= +github.com/jhump/protoreflect/v2 v2.0.0-beta.2/go.mod h1:4tnOYkB/mq7QTyS3YKtVtNrJv4Psqout8HA1U+hZtgM= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= diff --git a/pkg/grpc/BUILD.bazel b/pkg/grpc/BUILD.bazel index 600ea5ec..4035ad17 100644 --- a/pkg/grpc/BUILD.bazel +++ b/pkg/grpc/BUILD.bazel @@ -26,9 +26,13 @@ go_library( "peer_transport_credentials_linux.go", "proto_trace_attributes_extractor.go", "proxy_dialer.go", + "reflection_relay.go", "request_headers_authenticator.go", "request_metadata_tracing_interceptor.go", + "routing_stream_forwarder.go", "server.go", + "server_transport_stream_context.go", + "simple_stream_forwarder.go", "tls_client_certificate_authenticator.go", ], importpath = "github.com/buildbarn/bb-storage/pkg/grpc", @@ -49,6 +53,8 @@ go_library( "@bazel_remote_apis//build/bazel/remote/execution/v2:remote_execution_go_proto", "@com_github_grpc_ecosystem_go_grpc_middleware//:go-grpc-middleware", "@com_github_grpc_ecosystem_go_grpc_prometheus//:go-grpc-prometheus", + "@com_github_jhump_protoreflect_v2//grpcreflect", + "@com_github_jhump_protoreflect_v2//protoresolve", "@io_opentelemetry_go_contrib_instrumentation_google_golang_org_grpc_otelgrpc//:otelgrpc", "@io_opentelemetry_go_otel//attribute", "@io_opentelemetry_go_otel_trace//:trace", @@ -63,11 +69,15 @@ go_library( "@org_golang_google_grpc//metadata", "@org_golang_google_grpc//peer", "@org_golang_google_grpc//reflection", + "@org_golang_google_grpc//reflection/grpc_reflection_v1", + "@org_golang_google_grpc//reflection/grpc_reflection_v1alpha", "@org_golang_google_grpc//status", "@org_golang_google_grpc_security_advancedtls//:advancedtls", "@org_golang_google_protobuf//encoding/prototext", "@org_golang_google_protobuf//proto", "@org_golang_google_protobuf//reflect/protoreflect", + "@org_golang_google_protobuf//types/known/emptypb", + "@org_golang_x_sync//errgroup", "@org_golang_x_sync//semaphore", ] + select({ "@rules_go//go/platform:android": [ diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go new file mode 100644 index 00000000..623690a2 --- /dev/null +++ b/pkg/grpc/reflection_relay.go @@ -0,0 +1,83 @@ +package grpc + +import ( + "context" + "maps" + "strings" + + "github.com/buildbarn/bb-storage/pkg/program" + grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" + "github.com/buildbarn/bb-storage/pkg/util" + "github.com/jhump/protoreflect/v2/grpcreflect" + "github.com/jhump/protoreflect/v2/protoresolve" + v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1" + v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" +) + +type combinedServiceInfoProvider struct { + server reflection.ServiceInfoProvider + extraServices map[string]grpc.ServiceInfo +} + +var _ reflection.ServiceInfoProvider = (*combinedServiceInfoProvider)(nil) + +// GetServiceInfo returns the currently available services, which might have +// changed since the creation of this reflection server. +func (p *combinedServiceInfoProvider) GetServiceInfo() map[string]grpc.ServiceInfo { + services := make(map[string]grpc.ServiceInfo) + maps.Copy(services, p.extraServices) + maps.Copy(services, p.server.GetServiceInfo()) + return services +} + +// registerReflection registers the google.golang.org/grpc/reflection/ service +// on a grpc.Server and calls remote backends in case for relayed services. The +// connections to the backend will run with the backendCtx. +func registerReflection(backendCtx context.Context, s *grpc.Server, serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, group program.Group, grpcClientFactory ClientFactory) error { + // Accumulate all the service names. + relayServices := make(map[string]grpc.ServiceInfo) + for _, relay := range serverRelayConfiguration { + for _, serviceMethod := range relay.Methods { + if !strings.HasPrefix(serviceMethod, "/") { + return status.Errorf(codes.InvalidArgument, "Malformed service method name %q, should start with '/'", serviceMethod) + } + pos := strings.LastIndex(serviceMethod, "/") + if pos == -1 || pos == 0 { + return status.Errorf(codes.InvalidArgument, "Malformed name %q, expected '/' between service and method", serviceMethod) + } + serviceName := serviceMethod[1:pos] + // According to ServiceInfoProvider docs for ServerOptions.Services, + // the reflection service is only interested in the service names. + relayServices[serviceName] = grpc.ServiceInfo{} + } + } + + // Make a combined descriptor and extension resolver. + reflectionBackends := []protoresolve.Resolver{} + for relayIdx, relay := range serverRelayConfiguration { + grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.Endpoint, group) + if err != nil { + return util.StatusWrapf(err, "Failed to create relay RPC client %d", relayIdx+1) + } + resolver := grpcreflect.NewClientAuto(backendCtx, grpcClient).AsResolver() + reflectionBackends = append(reflectionBackends, resolver) + } + combinedRemoteResolver := protoresolve.Combine(reflectionBackends...) + + serverOptions := reflection.ServerOptions{ + Services: &combinedServiceInfoProvider{ + server: s, + extraServices: relayServices, + }, + DescriptorResolver: combinedRemoteResolver, + ExtensionResolver: protoresolve.TypesFromDescriptorPool(combinedRemoteResolver), + } + v1reflectiongrpc.RegisterServerReflectionServer(s, reflection.NewServerV1(serverOptions)) + v1alphareflectiongrpc.RegisterServerReflectionServer(s, reflection.NewServer(serverOptions)) + return nil +} diff --git a/pkg/grpc/routing_stream_forwarder.go b/pkg/grpc/routing_stream_forwarder.go new file mode 100644 index 00000000..899c9080 --- /dev/null +++ b/pkg/grpc/routing_stream_forwarder.go @@ -0,0 +1,36 @@ +package grpc + +import ( + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// RoutingStreamForwarder forwards gRPC streams to different backends depending +// on the method being invoked. +type RoutingStreamForwarder struct { + // RouteTable maps to the grpc.StreamHandler to be called. The key is the + // combined gRPC service and method name. + RouteTable map[string]grpc.StreamHandler +} + +// NewRoutingStreamForwarder creates a RoutingStreamForwarder which routes gRPC +// streams based on the invoked gRPC method name. +func NewRoutingStreamForwarder() *RoutingStreamForwarder { + return &RoutingStreamForwarder{ + RouteTable: make(map[string]grpc.StreamHandler), + } +} + +// HandleStream is the implementation of the grpc.StreamHandler interface to +// process a gRPC stream, forwarding it according to the RouteTable. +func (s *RoutingStreamForwarder) HandleStream(srv any, stream grpc.ServerStream) error { + method := MustStreamMethodFromContext(stream.Context()) + if streamHandler, ok := s.RouteTable[method]; ok { + return streamHandler(srv, stream) + } + errDesc := fmt.Sprintf("no route for method %v", method) + return status.Error(codes.Unimplemented, errDesc) +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 30f97086..9eafffa6 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -9,7 +9,7 @@ import ( configuration "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" "github.com/buildbarn/bb-storage/pkg/util" - "github.com/grpc-ecosystem/go-grpc-prometheus" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -18,7 +18,6 @@ import ( "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -147,6 +146,14 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC })) } + if len(configuration.Relays) != 0 { + handler, err := newStreamRoutingFromConfiguration(configuration.Relays, grpcClientFactory, group) + if err != nil { + return util.StatusWrap(err, "Failed to create authenticator RPC client") + } + serverOptions = append(serverOptions, grpc.UnknownServiceHandler(handler)) + } + // Create server. s := grpc.NewServer(serverOptions...) stopFunc := s.Stop @@ -162,7 +169,9 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC // Enable default services. grpc_prometheus.Register(s) - reflection.Register(s) + if err := registerReflection(context.Background(), s, configuration.Relays, group, grpcClientFactory); err != nil { + return util.StatusWrap(err, "Failed to create reflection service") + } h := health.NewServer() grpc_health_v1.RegisterHealthServer(s, h) // TODO: Construct an API for the caller to indicate @@ -208,3 +217,20 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC } return nil } + +func newStreamRoutingFromConfiguration(serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, grpcClientFactory ClientFactory, group program.Group) (grpc.StreamHandler, error) { + handler := NewRoutingStreamForwarder() + for _, relay := range serverRelayConfiguration { + grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.GetEndpoint(), group) + if err != nil { + return nil, util.StatusWrap(err, "Failed to create authenticator RPC client") + } + for _, method := range relay.GetMethods() { + if _, ok := handler.RouteTable[method]; ok { + return nil, status.Errorf(codes.InvalidArgument, "Duplicated relay for %v", method) + } + handler.RouteTable[method] = NewSimpleStreamForwarder(grpcClient) + } + } + return handler.HandleStream, nil +} diff --git a/pkg/grpc/server_transport_stream_context.go b/pkg/grpc/server_transport_stream_context.go new file mode 100644 index 00000000..c3347cf7 --- /dev/null +++ b/pkg/grpc/server_transport_stream_context.go @@ -0,0 +1,18 @@ +package grpc + +import ( + "context" + + "google.golang.org/grpc" +) + +// MustStreamMethodFromContext returns the service and method name for the ongoing gRPC stream. +// It will panic if the given context has no grpc.ServerTransportStream associated with it +// (which implies it is not an RPC invocation context). +func MustStreamMethodFromContext(ctx context.Context) string { + transportStream := grpc.ServerTransportStreamFromContext(ctx) + if transportStream == nil { + panic("No grpc.ServerTransportStream in context") + } + return transportStream.Method() +} diff --git a/pkg/grpc/simple_stream_forwarder.go b/pkg/grpc/simple_stream_forwarder.go new file mode 100644 index 00000000..5016d299 --- /dev/null +++ b/pkg/grpc/simple_stream_forwarder.go @@ -0,0 +1,88 @@ +package grpc + +import ( + "io" + + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +// NewSimpleStreamForwarder creates a grpc.StreamHandler that forwards gRPC +// calls to a grpc.ClientConnInterface backend. +func NewSimpleStreamForwarder(client grpc.ClientConnInterface) grpc.StreamHandler { + forwarder := &simpleStreamForwarder{ + backend: client, + } + return forwarder.HandleStream +} + +type simpleStreamForwarder struct { + backend grpc.ClientConnInterface +} + +// HandleStream creates a new stream to the backend. Requests from +// incomingStream are forwarded to the backend stream and responses from the +// backend stream are sent back in the incomingStream. +func (s *simpleStreamForwarder) HandleStream(srv any, incomingStream grpc.ServerStream) error { + method := MustStreamMethodFromContext(incomingStream.Context()) + desc := grpc.StreamDesc{ + // According to grpc.StreamDesc documentation, StreamName and Handler + // are only used when registering handlers on a server. + StreamName: "", + Handler: nil, + // Streaming behaviour is wanted, single message is treated the same on + // transport level, the application just closes the stream after the + // first message. + ServerStreams: true, + ClientStreams: true, + } + group, groupCtx := errgroup.WithContext(incomingStream.Context()) + outgoingStream, err := s.backend.NewStream(groupCtx, &desc, method) + if err != nil { + return err + } + go func() { + for { + msg := &emptypb.Empty{} + if err := incomingStream.RecvMsg(msg); err != nil { + if err == io.EOF { + // Let's to receive on outgoingStream, so don't cancel + // grouptCtx. + outgoingStream.CloseSend() + return + } + // Cancel groupCtx immediately. + group.Go(func() error { return err }) + return + } + if err := outgoingStream.SendMsg(msg); err != nil { + if err == io.EOF { + // The error will be returned by outgoingStream.RecvMsg(), + // no need to cancel groupCtx now. + return + } + // Cancel groupCtx immediately. + group.Go(func() error { return err }) + return + } + } + }() + group.Go(func() error { + for { + msg := &emptypb.Empty{} + if err := outgoingStream.RecvMsg(msg); err != nil { + if err == io.EOF { + return nil + } + return err + } + if err := incomingStream.SendMsg(msg); err != nil { + return err + } + } + }) + // group.Wait() may block a bit on incomingStream.SendMsg(), but that + // shouldn't be for too long. + return group.Wait() +} diff --git a/pkg/proto/configuration/grpc/grpc.pb.go b/pkg/proto/configuration/grpc/grpc.pb.go index d21c8a15..f66596eb 100644 --- a/pkg/proto/configuration/grpc/grpc.pb.go +++ b/pkg/proto/configuration/grpc/grpc.pb.go @@ -233,6 +233,7 @@ type ServerConfiguration struct { // *ServerConfiguration_Tls // *ServerConfiguration_Alts TransportSecurity isServerConfiguration_TransportSecurity `protobuf_oneof:"transport_security"` + Relays []*ServerRelayConfiguration `protobuf:"bytes,14,rep,name=relays,proto3" json:"relays,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -369,6 +370,13 @@ func (x *ServerConfiguration) GetAlts() *emptypb.Empty { return nil } +func (x *ServerConfiguration) GetRelays() []*ServerRelayConfiguration { + if x != nil { + return x.Relays + } + return nil +} + type isServerConfiguration_TransportSecurity interface { isServerConfiguration_TransportSecurity() } @@ -967,6 +975,58 @@ func (x *TracingMethodConfiguration) GetAttributesFromFirstResponseMessage() []s return nil } +type ServerRelayConfiguration struct { + state protoimpl.MessageState `protogen:"open.v1"` + Endpoint *ClientConfiguration `protobuf:"bytes,1,opt,name=endpoint,proto3" json:"endpoint,omitempty"` + Methods []string `protobuf:"bytes,2,rep,name=methods,proto3" json:"methods,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ServerRelayConfiguration) Reset() { + *x = ServerRelayConfiguration{} + mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ServerRelayConfiguration) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerRelayConfiguration) ProtoMessage() {} + +func (x *ServerRelayConfiguration) ProtoReflect() protoreflect.Message { + mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerRelayConfiguration.ProtoReflect.Descriptor instead. +func (*ServerRelayConfiguration) Descriptor() ([]byte, []int) { + return file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_rawDescGZIP(), []int{11} +} + +func (x *ServerRelayConfiguration) GetEndpoint() *ClientConfiguration { + if x != nil { + return x.Endpoint + } + return nil +} + +func (x *ServerRelayConfiguration) GetMethods() []string { + if x != nil { + return x.Methods + } + return nil +} + type ClientConfiguration_HeaderValues struct { state protoimpl.MessageState `protogen:"open.v1"` Header string `protobuf:"bytes,1,opt,name=header,proto3" json:"header,omitempty"` @@ -977,7 +1037,7 @@ type ClientConfiguration_HeaderValues struct { func (x *ClientConfiguration_HeaderValues) Reset() { *x = ClientConfiguration_HeaderValues{} - mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[11] + mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -989,7 +1049,7 @@ func (x *ClientConfiguration_HeaderValues) String() string { func (*ClientConfiguration_HeaderValues) ProtoMessage() {} func (x *ClientConfiguration_HeaderValues) ProtoReflect() protoreflect.Message { - mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[11] + mi := &file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1046,7 +1106,7 @@ const file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_pro "\x1cClientKeepaliveConfiguration\x12-\n" + "\x04time\x18\x01 \x01(\v2\x19.google.protobuf.DurationR\x04time\x123\n" + "\atimeout\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\atimeout\x122\n" + - "\x15permit_without_stream\x18\x03 \x01(\bR\x13permitWithoutStream\"\xbd\b\n" + + "\x15permit_without_stream\x18\x03 \x01(\bR\x13permitWithoutStream\"\x8d\t\n" + "\x13ServerConfiguration\x12)\n" + "\x10listen_addresses\x18\x01 \x03(\tR\x0flistenAddresses\x12!\n" + "\flisten_paths\x18\x02 \x03(\tR\vlistenPaths\x12g\n" + @@ -1061,7 +1121,8 @@ const file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_pro "\x14keepalive_parameters\x18\v \x01(\v27.buildbarn.configuration.grpc.ServerKeepaliveParametersR\x13keepaliveParameters\x12'\n" + "\x0fstop_gracefully\x18\f \x01(\bR\x0estopGracefully\x12D\n" + "\x03tls\x18\x03 \x01(\v20.buildbarn.configuration.tls.ServerConfigurationH\x00R\x03tls\x12,\n" + - "\x04alts\x18\r \x01(\v2\x16.google.protobuf.EmptyH\x00R\x04alts\x1at\n" + + "\x04alts\x18\r \x01(\v2\x16.google.protobuf.EmptyH\x00R\x04alts\x12N\n" + + "\x06relays\x18\x0e \x03(\v26.buildbarn.configuration.grpc.ServerRelayConfigurationR\x06relays\x1at\n" + "\fTracingEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12N\n" + "\x05value\x18\x02 \x01(\v28.buildbarn.configuration.grpc.TracingMethodConfigurationR\x05value:\x028\x01B\x14\n" + @@ -1101,7 +1162,10 @@ const file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_pro "\x18cache_replacement_policy\x18\x05 \x01(\x0e28.buildbarn.configuration.eviction.CacheReplacementPolicyR\x16cacheReplacementPolicy\"\xc2\x01\n" + "\x1aTracingMethodConfiguration\x12P\n" + "%attributes_from_first_request_message\x18\x01 \x03(\tR!attributesFromFirstRequestMessage\x12R\n" + - "&attributes_from_first_response_message\x18\x02 \x03(\tR\"attributesFromFirstResponseMessageB>ZZ buildbarn.configuration.tls.ClientConfiguration + 15, // 0: buildbarn.configuration.grpc.ClientConfiguration.tls:type_name -> buildbarn.configuration.tls.ClientConfiguration 1, // 1: buildbarn.configuration.grpc.ClientConfiguration.keepalive:type_name -> buildbarn.configuration.grpc.ClientKeepaliveConfiguration - 11, // 2: buildbarn.configuration.grpc.ClientConfiguration.add_metadata:type_name -> buildbarn.configuration.grpc.ClientConfiguration.HeaderValues - 15, // 3: buildbarn.configuration.grpc.ClientConfiguration.add_metadata_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression - 16, // 4: buildbarn.configuration.grpc.ClientConfiguration.oauth2:type_name -> buildbarn.configuration.http.client.OAuth2Configuration - 12, // 5: buildbarn.configuration.grpc.ClientConfiguration.tracing:type_name -> buildbarn.configuration.grpc.ClientConfiguration.TracingEntry - 17, // 6: buildbarn.configuration.grpc.ClientConfiguration.default_service_config:type_name -> google.protobuf.Struct - 18, // 7: buildbarn.configuration.grpc.ClientKeepaliveConfiguration.time:type_name -> google.protobuf.Duration - 18, // 8: buildbarn.configuration.grpc.ClientKeepaliveConfiguration.timeout:type_name -> google.protobuf.Duration + 12, // 2: buildbarn.configuration.grpc.ClientConfiguration.add_metadata:type_name -> buildbarn.configuration.grpc.ClientConfiguration.HeaderValues + 16, // 3: buildbarn.configuration.grpc.ClientConfiguration.add_metadata_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression + 17, // 4: buildbarn.configuration.grpc.ClientConfiguration.oauth2:type_name -> buildbarn.configuration.http.client.OAuth2Configuration + 13, // 5: buildbarn.configuration.grpc.ClientConfiguration.tracing:type_name -> buildbarn.configuration.grpc.ClientConfiguration.TracingEntry + 18, // 6: buildbarn.configuration.grpc.ClientConfiguration.default_service_config:type_name -> google.protobuf.Struct + 19, // 7: buildbarn.configuration.grpc.ClientKeepaliveConfiguration.time:type_name -> google.protobuf.Duration + 19, // 8: buildbarn.configuration.grpc.ClientKeepaliveConfiguration.timeout:type_name -> google.protobuf.Duration 5, // 9: buildbarn.configuration.grpc.ServerConfiguration.authentication_policy:type_name -> buildbarn.configuration.grpc.AuthenticationPolicy 3, // 10: buildbarn.configuration.grpc.ServerConfiguration.keepalive_enforcement_policy:type_name -> buildbarn.configuration.grpc.ServerKeepaliveEnforcementPolicy - 13, // 11: buildbarn.configuration.grpc.ServerConfiguration.tracing:type_name -> buildbarn.configuration.grpc.ServerConfiguration.TracingEntry + 14, // 11: buildbarn.configuration.grpc.ServerConfiguration.tracing:type_name -> buildbarn.configuration.grpc.ServerConfiguration.TracingEntry 4, // 12: buildbarn.configuration.grpc.ServerConfiguration.keepalive_parameters:type_name -> buildbarn.configuration.grpc.ServerKeepaliveParameters - 19, // 13: buildbarn.configuration.grpc.ServerConfiguration.tls:type_name -> buildbarn.configuration.tls.ServerConfiguration - 20, // 14: buildbarn.configuration.grpc.ServerConfiguration.alts:type_name -> google.protobuf.Empty - 18, // 15: buildbarn.configuration.grpc.ServerKeepaliveEnforcementPolicy.min_time:type_name -> google.protobuf.Duration - 18, // 16: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_idle:type_name -> google.protobuf.Duration - 18, // 17: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_age:type_name -> google.protobuf.Duration - 18, // 18: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_age_grace:type_name -> google.protobuf.Duration - 18, // 19: buildbarn.configuration.grpc.ServerKeepaliveParameters.time:type_name -> google.protobuf.Duration - 18, // 20: buildbarn.configuration.grpc.ServerKeepaliveParameters.timeout:type_name -> google.protobuf.Duration - 21, // 21: buildbarn.configuration.grpc.AuthenticationPolicy.allow:type_name -> buildbarn.auth.AuthenticationMetadata - 6, // 22: buildbarn.configuration.grpc.AuthenticationPolicy.any:type_name -> buildbarn.configuration.grpc.AnyAuthenticationPolicy - 7, // 23: buildbarn.configuration.grpc.AuthenticationPolicy.all:type_name -> buildbarn.configuration.grpc.AllAuthenticationPolicy - 22, // 24: buildbarn.configuration.grpc.AuthenticationPolicy.tls_client_certificate:type_name -> buildbarn.configuration.x509.ClientCertificateVerifierConfiguration - 23, // 25: buildbarn.configuration.grpc.AuthenticationPolicy.jwt:type_name -> buildbarn.configuration.jwt.AuthorizationHeaderParserConfiguration - 15, // 26: buildbarn.configuration.grpc.AuthenticationPolicy.peer_credentials_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression - 9, // 27: buildbarn.configuration.grpc.AuthenticationPolicy.remote:type_name -> buildbarn.configuration.grpc.RemoteAuthenticationPolicy - 5, // 28: buildbarn.configuration.grpc.AnyAuthenticationPolicy.policies:type_name -> buildbarn.configuration.grpc.AuthenticationPolicy - 5, // 29: buildbarn.configuration.grpc.AllAuthenticationPolicy.policies:type_name -> buildbarn.configuration.grpc.AuthenticationPolicy - 15, // 30: buildbarn.configuration.grpc.TLSClientCertificateAuthenticationPolicy.validation_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression - 15, // 31: buildbarn.configuration.grpc.TLSClientCertificateAuthenticationPolicy.metadata_extraction_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression - 0, // 32: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.endpoint:type_name -> buildbarn.configuration.grpc.ClientConfiguration - 24, // 33: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.scope:type_name -> google.protobuf.Value - 25, // 34: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.cache_replacement_policy:type_name -> buildbarn.configuration.eviction.CacheReplacementPolicy - 10, // 35: buildbarn.configuration.grpc.ClientConfiguration.TracingEntry.value:type_name -> buildbarn.configuration.grpc.TracingMethodConfiguration - 10, // 36: buildbarn.configuration.grpc.ServerConfiguration.TracingEntry.value:type_name -> buildbarn.configuration.grpc.TracingMethodConfiguration - 37, // [37:37] is the sub-list for method output_type - 37, // [37:37] is the sub-list for method input_type - 37, // [37:37] is the sub-list for extension type_name - 37, // [37:37] is the sub-list for extension extendee - 0, // [0:37] is the sub-list for field type_name + 20, // 13: buildbarn.configuration.grpc.ServerConfiguration.tls:type_name -> buildbarn.configuration.tls.ServerConfiguration + 21, // 14: buildbarn.configuration.grpc.ServerConfiguration.alts:type_name -> google.protobuf.Empty + 11, // 15: buildbarn.configuration.grpc.ServerConfiguration.relays:type_name -> buildbarn.configuration.grpc.ServerRelayConfiguration + 19, // 16: buildbarn.configuration.grpc.ServerKeepaliveEnforcementPolicy.min_time:type_name -> google.protobuf.Duration + 19, // 17: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_idle:type_name -> google.protobuf.Duration + 19, // 18: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_age:type_name -> google.protobuf.Duration + 19, // 19: buildbarn.configuration.grpc.ServerKeepaliveParameters.max_connection_age_grace:type_name -> google.protobuf.Duration + 19, // 20: buildbarn.configuration.grpc.ServerKeepaliveParameters.time:type_name -> google.protobuf.Duration + 19, // 21: buildbarn.configuration.grpc.ServerKeepaliveParameters.timeout:type_name -> google.protobuf.Duration + 22, // 22: buildbarn.configuration.grpc.AuthenticationPolicy.allow:type_name -> buildbarn.auth.AuthenticationMetadata + 6, // 23: buildbarn.configuration.grpc.AuthenticationPolicy.any:type_name -> buildbarn.configuration.grpc.AnyAuthenticationPolicy + 7, // 24: buildbarn.configuration.grpc.AuthenticationPolicy.all:type_name -> buildbarn.configuration.grpc.AllAuthenticationPolicy + 23, // 25: buildbarn.configuration.grpc.AuthenticationPolicy.tls_client_certificate:type_name -> buildbarn.configuration.x509.ClientCertificateVerifierConfiguration + 24, // 26: buildbarn.configuration.grpc.AuthenticationPolicy.jwt:type_name -> buildbarn.configuration.jwt.AuthorizationHeaderParserConfiguration + 16, // 27: buildbarn.configuration.grpc.AuthenticationPolicy.peer_credentials_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression + 9, // 28: buildbarn.configuration.grpc.AuthenticationPolicy.remote:type_name -> buildbarn.configuration.grpc.RemoteAuthenticationPolicy + 5, // 29: buildbarn.configuration.grpc.AnyAuthenticationPolicy.policies:type_name -> buildbarn.configuration.grpc.AuthenticationPolicy + 5, // 30: buildbarn.configuration.grpc.AllAuthenticationPolicy.policies:type_name -> buildbarn.configuration.grpc.AuthenticationPolicy + 16, // 31: buildbarn.configuration.grpc.TLSClientCertificateAuthenticationPolicy.validation_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression + 16, // 32: buildbarn.configuration.grpc.TLSClientCertificateAuthenticationPolicy.metadata_extraction_jmespath_expression:type_name -> buildbarn.configuration.jmespath.Expression + 0, // 33: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.endpoint:type_name -> buildbarn.configuration.grpc.ClientConfiguration + 25, // 34: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.scope:type_name -> google.protobuf.Value + 26, // 35: buildbarn.configuration.grpc.RemoteAuthenticationPolicy.cache_replacement_policy:type_name -> buildbarn.configuration.eviction.CacheReplacementPolicy + 0, // 36: buildbarn.configuration.grpc.ServerRelayConfiguration.endpoint:type_name -> buildbarn.configuration.grpc.ClientConfiguration + 10, // 37: buildbarn.configuration.grpc.ClientConfiguration.TracingEntry.value:type_name -> buildbarn.configuration.grpc.TracingMethodConfiguration + 10, // 38: buildbarn.configuration.grpc.ServerConfiguration.TracingEntry.value:type_name -> buildbarn.configuration.grpc.TracingMethodConfiguration + 39, // [39:39] is the sub-list for method output_type + 39, // [39:39] is the sub-list for method input_type + 39, // [39:39] is the sub-list for extension type_name + 39, // [39:39] is the sub-list for extension extendee + 0, // [0:39] is the sub-list for field type_name } func init() { file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_init() } @@ -1214,7 +1281,7 @@ func file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_prot GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_rawDesc), len(file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_proto_rawDesc)), NumEnums: 0, - NumMessages: 14, + NumMessages: 15, NumExtensions: 0, NumServices: 0, }, diff --git a/pkg/proto/configuration/grpc/grpc.proto b/pkg/proto/configuration/grpc/grpc.proto index 083355bf..ca915613 100644 --- a/pkg/proto/configuration/grpc/grpc.proto +++ b/pkg/proto/configuration/grpc/grpc.proto @@ -240,6 +240,16 @@ message ServerConfiguration { // https://docs.cloud.google.com/docs/security/encryption-in-transit/application-layer-transport-security google.protobuf.Empty alts = 13; } + + // Forward calls to certain named gRPC services to a different endpoint. + // + // One use case is to let the user connect to the same DNS name for extra + // services without having to use separate DNS names or setup another gRPC + // proxy. build.bazel.remote.execution.v2.Execution is a candidate for this, + // but note that build.bazel.remote.execution.v2.Capabilities might still need + // information from the scheduler. Another use case is Bazel's Build Event + // Streaming. + repeated ServerRelayConfiguration relays = 14; } message ServerKeepaliveEnforcementPolicy { @@ -501,3 +511,15 @@ message TracingMethodConfiguration { // 'attributes_from_first_request_message'. repeated string attributes_from_first_response_message = 2; } + +message ServerRelayConfiguration { + // The remote gRPC server to forward the gRPC calls to. + ClientConfiguration endpoint = 1; + + // The full gRPC service and method name to relay. + // Examples of valid names include: + // + // /build.bazel.remote.execution.v2.Execution/Execute + // /com.google.devtools.build.v1.PublishBuildEvent/PublishBuildEvent + repeated string methods = 2; +} From e1e9af38ae0d725daa82cb78a3887b9e6e1bf715 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Tue, 9 Dec 2025 15:23:06 +0100 Subject: [PATCH 2/9] Add tests and minor adjustments --- internal/mock/BUILD.bazel | 1 + pkg/grpc/BUILD.bazel | 2 + pkg/grpc/routing_stream_forwarder.go | 5 +- pkg/grpc/routing_stream_forwarder_test.go | 56 ++++ pkg/grpc/simple_stream_forwarder.go | 2 +- pkg/grpc/simple_stream_forwarder_test.go | 358 ++++++++++++++++++++++ pkg/testutil/testutil.go | 10 + 7 files changed, 429 insertions(+), 5 deletions(-) create mode 100644 pkg/grpc/routing_stream_forwarder_test.go create mode 100644 pkg/grpc/simple_stream_forwarder_test.go diff --git a/internal/mock/BUILD.bazel b/internal/mock/BUILD.bazel index e5f0a54d..255808e2 100644 --- a/internal/mock/BUILD.bazel +++ b/internal/mock/BUILD.bazel @@ -260,6 +260,7 @@ gomock( "ClientConnInterface", "ClientStream", "ServerStream", + "ServerTransportStream", "StreamHandler", "Streamer", "UnaryHandler", diff --git a/pkg/grpc/BUILD.bazel b/pkg/grpc/BUILD.bazel index 4035ad17..6c08c186 100644 --- a/pkg/grpc/BUILD.bazel +++ b/pkg/grpc/BUILD.bazel @@ -117,6 +117,8 @@ go_test( "proto_trace_attributes_extractor_test.go", "request_headers_authenticator_test.go", "request_metadata_tracing_interceptor_test.go", + "routing_stream_forwarder_test.go", + "simple_stream_forwarder_test.go", ] + select({ "@rules_go//go/platform:android": [ "peer_transport_credentials_test.go", diff --git a/pkg/grpc/routing_stream_forwarder.go b/pkg/grpc/routing_stream_forwarder.go index 899c9080..d597356d 100644 --- a/pkg/grpc/routing_stream_forwarder.go +++ b/pkg/grpc/routing_stream_forwarder.go @@ -1,8 +1,6 @@ package grpc import ( - "fmt" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -31,6 +29,5 @@ func (s *RoutingStreamForwarder) HandleStream(srv any, stream grpc.ServerStream) if streamHandler, ok := s.RouteTable[method]; ok { return streamHandler(srv, stream) } - errDesc := fmt.Sprintf("no route for method %v", method) - return status.Error(codes.Unimplemented, errDesc) + return status.Errorf(codes.Unimplemented, "No route for method %v", method) } diff --git a/pkg/grpc/routing_stream_forwarder_test.go b/pkg/grpc/routing_stream_forwarder_test.go new file mode 100644 index 00000000..e6910869 --- /dev/null +++ b/pkg/grpc/routing_stream_forwarder_test.go @@ -0,0 +1,56 @@ +package grpc_test + +import ( + "context" + "errors" + "testing" + + "github.com/buildbarn/bb-storage/internal/mock" + bb_grpc "github.com/buildbarn/bb-storage/pkg/grpc" + "github.com/buildbarn/bb-storage/pkg/testutil" + "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "go.uber.org/mock/gomock" +) + +func TestRoutingStreamForwarder(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + + someSrv := "server" + serverTransportStream := mock.NewMockServerTransportStream(ctrl) + streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + incomingStream.EXPECT().Context().Return(streamCtx).AnyTimes() + + // The test assumes that the incomingStream is forwarded straight through + // the RoutingStreamForwarder, even if the implementation is allowed to do + // some wrapping. + forwarder := bb_grpc.NewRoutingStreamForwarder() + forwarder.RouteTable["/serviceA/method1"] = func(srv any, stream grpc.ServerStream) error { + require.Equal(t, srv, someSrv) + require.Equal(t, stream, incomingStream) + return errors.New("A1") + } + forwarder.RouteTable["generic-service-method-name"] = func(srv any, stream grpc.ServerStream) error { + require.Equal(t, srv, someSrv) + require.Equal(t, stream, incomingStream) + return errors.New("generic") + } + + serverTransportStream.EXPECT().Method().Return("/serviceA/method1") + require.Error(t, forwarder.HandleStream(someSrv, incomingStream), "A1") + + serverTransportStream.EXPECT().Method().Return("generic-service-method-name") + require.Error(t, forwarder.HandleStream(someSrv, incomingStream), "generic") + + serverTransportStream.EXPECT().Method().Return("/non-existing-service/bad-method") + testutil.RequireEqualStatus( + t, + status.Error(codes.Unimplemented, "No route for method /non-existing-service/bad-method"), + forwarder.HandleStream(someSrv, incomingStream), + ) +} diff --git a/pkg/grpc/simple_stream_forwarder.go b/pkg/grpc/simple_stream_forwarder.go index 5016d299..50d419ee 100644 --- a/pkg/grpc/simple_stream_forwarder.go +++ b/pkg/grpc/simple_stream_forwarder.go @@ -43,13 +43,13 @@ func (s *simpleStreamForwarder) HandleStream(srv any, incomingStream grpc.Server return err } go func() { + defer outgoingStream.CloseSend() for { msg := &emptypb.Empty{} if err := incomingStream.RecvMsg(msg); err != nil { if err == io.EOF { // Let's to receive on outgoingStream, so don't cancel // grouptCtx. - outgoingStream.CloseSend() return } // Cancel groupCtx immediately. diff --git a/pkg/grpc/simple_stream_forwarder_test.go b/pkg/grpc/simple_stream_forwarder_test.go new file mode 100644 index 00000000..e4c7c6dc --- /dev/null +++ b/pkg/grpc/simple_stream_forwarder_test.go @@ -0,0 +1,358 @@ +package grpc_test + +import ( + "context" + "errors" + "io" + "testing" + "testing/synctest" + + "github.com/buildbarn/bb-storage/internal/mock" + bb_grpc "github.com/buildbarn/bb-storage/pkg/grpc" + "github.com/buildbarn/bb-storage/pkg/testutil" + "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + + "go.uber.org/mock/gomock" +) + +// simpleStreamForwarderStandardFixture contains channels to communicate with +// the RecvMsg and SendMsg mocks in the incoming and backend streams. +type simpleStreamForwarderStandardFixture struct { + forwarder grpc.StreamHandler + incomingStream *mock.MockServerStream + + backendNewStreamErrorChan chan<- error + backendNewStreamCtx <-chan context.Context + + // IncomingRecvErrorChan provides the return value for + // incomingStream.RecvMsg(). + IncomingRecvErrorChan chan<- error + // IncomingRecvValueChan provides the returned proto message for + // incomingStream.RecvMsg() if the error was nil. + IncomingRecvValueChan chan<- *structpb.Value + // BackendSendErrorChan provides the return value for + // backendStream.SendMsg(). + BackendSendErrorChan chan<- error + // BackendSendValueChan receives the proto message provided in the call to + // backendStream.SendMsg(). + BackendSendValueChan <-chan *structpb.Value + // BackendSendCloseChan receives an entry when backendStream.CloseSend() is + // called. + BackendSendCloseChan <-chan struct{} + // BackendRecvErrorChan provides the return value for + // backendStream.RecvMsg(). + BackendRecvErrorChan chan<- error + // BackendRecvValueChan provides the returned proto message for + // backendStream.RecvMsg() if the error was nil. + BackendRecvValueChan chan<- *structpb.Value + // IncomingSendErrorChan provides the return value for + // incomingStream.SendMsg(). + IncomingSendErrorChan chan<- error + // IncomingSendValueChan receives the proto message provided in the call to + // incomingStream.SendMsg(). + IncomingSendValueChan <-chan *structpb.Value +} + +func newSimpleStreamForwarderStandardFixture(ctx context.Context, ctrl *gomock.Controller, t *testing.T) *simpleStreamForwarderStandardFixture { + backend := mock.NewMockClientConnInterface(ctrl) + forwarder := bb_grpc.NewSimpleStreamForwarder(backend) + serverTransportStream := mock.NewMockServerTransportStream(ctrl) + serverTransportStream.EXPECT().Method().Return("/buildbarn.buildqueuestate.BuildQueueState/ListWorkers").AnyTimes() + streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + incomingStream.EXPECT().Context().Return(streamCtx).AnyTimes() + backendStream := mock.NewMockClientStream(ctrl) + + backendNewStreamErrorChan := make(chan error, 10) + backendNewStreamCtx := make(chan context.Context, 10) + + backend.EXPECT().NewStream( + gomock.Any(), + gomock.Any(), + "/buildbarn.buildqueuestate.BuildQueueState/ListWorkers", + ).DoAndReturn(func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + backendNewStreamCtx <- ctx + return backendStream, <-backendNewStreamErrorChan + }).AnyTimes() + + incomingRecvErrorChan := make(chan error, 10) + incomingRecvValueChan := make(chan *structpb.Value, 10) + backendSendErrorChan := make(chan error, 10) + backendSendValueChan := make(chan *structpb.Value, 10) + backendSendCloseChan := make(chan struct{}, 10) + backendRecvErrorChan := make(chan error, 10) + backendRecvValueChan := make(chan *structpb.Value, 10) + incomingSendErrorChan := make(chan error, 10) + incomingSendValueChan := make(chan *structpb.Value, 10) + + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + if err := <-incomingRecvErrorChan; err != nil { + return err + } + value := <-incomingRecvValueChan + bytes, err := proto.Marshal(value) + require.NoError(t, err) + require.NoError(t, proto.Unmarshal(bytes, msg.(proto.Message))) + return nil + }).AnyTimes() + backendStream.EXPECT().SendMsg(gomock.Any()).DoAndReturn(func(msg any) error { + if err := <-backendSendErrorChan; err != nil { + return err + } + bytes, err := proto.Marshal(msg.(proto.Message)) + require.NoError(t, err) + value := new(structpb.Value) + require.NoError(t, proto.Unmarshal(bytes, value)) + backendSendValueChan <- value + return nil + }).AnyTimes() + backendStream.EXPECT().CloseSend().DoAndReturn(func() error { + backendSendCloseChan <- struct{}{} + return nil + }).AnyTimes() + backendStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + if err := <-backendRecvErrorChan; err != nil { + return err + } + value := <-backendRecvValueChan + bytes, err := proto.Marshal(value) + require.NoError(t, err) + require.NoError(t, proto.Unmarshal(bytes, msg.(proto.Message))) + return nil + }).AnyTimes() + incomingStream.EXPECT().SendMsg(gomock.Any()).DoAndReturn(func(msg any) error { + if err := <-incomingSendErrorChan; err != nil { + return err + } + bytes, err := proto.Marshal(msg.(proto.Message)) + require.NoError(t, err) + value := new(structpb.Value) + require.NoError(t, proto.Unmarshal(bytes, value)) + incomingSendValueChan <- value + return nil + }).AnyTimes() + + return &simpleStreamForwarderStandardFixture{ + forwarder: forwarder, + incomingStream: incomingStream, + + backendNewStreamErrorChan: backendNewStreamErrorChan, + backendNewStreamCtx: backendNewStreamCtx, + + IncomingRecvErrorChan: incomingRecvErrorChan, + IncomingRecvValueChan: incomingRecvValueChan, + BackendSendErrorChan: backendSendErrorChan, + BackendSendValueChan: backendSendValueChan, + BackendSendCloseChan: backendSendCloseChan, + BackendRecvErrorChan: backendRecvErrorChan, + BackendRecvValueChan: backendRecvValueChan, + IncomingSendErrorChan: incomingSendErrorChan, + IncomingSendValueChan: incomingSendValueChan, + } +} + +func (f *simpleStreamForwarderStandardFixture) call(newStreamErr error) (context.Context, <-chan error) { + callResult := make(chan error, 1) + go func() { + defer close(callResult) + f.backendNewStreamErrorChan <- newStreamErr + callResult <- f.forwarder(nil, f.incomingStream) + }() + return <-f.backendNewStreamCtx, callResult +} + +func (f *simpleStreamForwarderStandardFixture) verifyEmptyChannels(t *testing.T) { + require.Len(t, f.backendNewStreamErrorChan, 0, "backendNewStreamErrorChan") + require.Len(t, f.backendNewStreamCtx, 0, "backendNewStreamCtx") + require.Len(t, f.IncomingRecvErrorChan, 0, "IncomingRecvErrorChan") + require.Len(t, f.IncomingRecvValueChan, 0, "IncomingRecvValueChan") + require.Len(t, f.BackendSendErrorChan, 0, "BackendSendErrorChan") + require.Len(t, f.BackendSendValueChan, 0, "BackendSendValueChan") + require.Len(t, f.BackendSendCloseChan, 0, "BackendSendCloseChan") + require.Len(t, f.BackendRecvErrorChan, 0, "BackendRecvErrorChan") + require.Len(t, f.BackendRecvValueChan, 0, "BackendRecvValueChan") + require.Len(t, f.IncomingSendErrorChan, 0, "IncomingSendErrorChan") + require.Len(t, f.IncomingSendValueChan, 0, "IncomingSendValueChan") +} + +func TestSimpleStreamForwarderRequestSuccess(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.IncomingRecvErrorChan <- nil + fixture.IncomingRecvValueChan <- structpb.NewStringValue("beep") + fixture.BackendSendErrorChan <- nil + testutil.RequireEqualProto(t, structpb.NewStringValue("beep"), <-fixture.BackendSendValueChan) + fixture.IncomingRecvErrorChan <- nil + fixture.IncomingRecvValueChan <- structpb.NewStringValue("boop") + fixture.BackendSendErrorChan <- nil + testutil.RequireEqualProto(t, structpb.NewStringValue("boop"), <-fixture.BackendSendValueChan) + + // Should still be forwarding requests to the backend. + synctest.Wait() + require.Len(t, fixture.BackendSendCloseChan, 0) + fixture.IncomingRecvErrorChan <- io.EOF + <-fixture.BackendSendCloseChan + + // Should still be receiving responses. + synctest.Wait() + testutil.VerifyChannelIsBlocking(t, backendCtx.Done()) + require.Len(t, forwardResultChan, 0) + fixture.BackendRecvErrorChan <- io.EOF + <-backendCtx.Done() + require.NoError(t, <-forwardResultChan) + + fixture.verifyEmptyChannels(t) + }) +} + +func TestSimpleStreamForwarderRequestRecvError(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.IncomingRecvErrorChan <- errors.New("incoming recv") + // Optional to call backend.SendClose(), but if called it should be done + // by now. + synctest.Wait() + select { + case <-fixture.BackendSendCloseChan: + default: + } + + // In error state, so the backend context should be canceled. + <-backendCtx.Done() + // Emulate that backend.RecvMsg() to returns. + fixture.BackendRecvErrorChan <- context.Canceled + + require.EqualError(t, <-forwardResultChan, "incoming recv") + + fixture.verifyEmptyChannels(t) + }) +} + +func TestSimpleStreamForwarderRequestSendError(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.IncomingRecvErrorChan <- nil + fixture.IncomingRecvValueChan <- structpb.NewStringValue("beep") + fixture.BackendSendErrorChan <- errors.New("backend send") + // Optional to call backend.SendClose(), but if called it should be done + // by now. + synctest.Wait() + select { + case <-fixture.BackendSendCloseChan: + default: + } + + // In error state, so the backend context should be canceled. + <-backendCtx.Done() + // Emulate that backend.RecvMsg() to returns. + fixture.BackendRecvErrorChan <- context.Canceled + + require.EqualError(t, <-forwardResultChan, "backend send") + + fixture.verifyEmptyChannels(t) + }) +} + +func TestSimpleStreamForwarderResponseSuccess(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.BackendRecvErrorChan <- nil + fixture.BackendRecvValueChan <- structpb.NewStringValue("beep") + fixture.IncomingSendErrorChan <- nil + testutil.RequireEqualProto(t, structpb.NewStringValue("beep"), <-fixture.IncomingSendValueChan) + fixture.BackendRecvErrorChan <- nil + fixture.BackendRecvValueChan <- structpb.NewStringValue("boop") + fixture.IncomingSendErrorChan <- nil + testutil.RequireEqualProto(t, structpb.NewStringValue("boop"), <-fixture.IncomingSendValueChan) + + // Should still be forwarding requests to the backend. + synctest.Wait() + require.Len(t, fixture.BackendSendCloseChan, 0) + // Should still be receiving responses. + testutil.VerifyChannelIsBlocking(t, backendCtx.Done()) + require.Len(t, forwardResultChan, 0) + fixture.BackendRecvErrorChan <- io.EOF + // Will not receive any more from the backend, so its context should + // be canceled. + <-backendCtx.Done() + require.NoError(t, <-forwardResultChan) + fixture.IncomingRecvErrorChan <- context.Canceled + <-fixture.BackendSendCloseChan + + fixture.verifyEmptyChannels(t) + }) +} + +func TestSimpleStreamForwarderResponseRecvError(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.BackendRecvErrorChan <- errors.New("backend recv") + // Optional to call backend.SendClose(), but if called it should be done + // by now. + synctest.Wait() + select { + case <-fixture.BackendSendCloseChan: + default: + } + + // In error state, so the backend context should be canceled. + <-backendCtx.Done() + require.EqualError(t, <-forwardResultChan, "backend recv") + + // Emulate that incoming.RecvMsg() returns now when the whole stream + // handler has returned. + fixture.IncomingRecvErrorChan <- context.Canceled + + fixture.verifyEmptyChannels(t) + }) +} + +func TestSimpleStreamForwarderResponseSendError(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) + backendCtx, forwardResultChan := fixture.call(nil) + + fixture.BackendRecvErrorChan <- nil + fixture.BackendRecvValueChan <- structpb.NewStringValue("beep") + fixture.IncomingSendErrorChan <- errors.New("incoming send") + // Optional to call backend.SendClose(), but if called it should be done + // by now. + synctest.Wait() + select { + case <-fixture.BackendSendCloseChan: + default: + } + + // In error state, so the backend context should be canceled. + <-backendCtx.Done() + + // The forwarder should return, even if the backend.RecvMsg() is slow. + require.EqualError(t, <-forwardResultChan, "incoming send") + + // Emulate that incoming.RecvMsg() returns now when the whole stream + // handler has returned. + fixture.IncomingRecvErrorChan <- context.Canceled + + fixture.verifyEmptyChannels(t) + }) +} diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index 6fc3172b..dbe01af7 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -138,3 +138,13 @@ func mustMarshalToString(t *testing.T, proto proto.Message) string { } return string(s) } + +// VerifyChannelIsBlocking checks that no value can be received from the +// channel. If a value is ready or the channel has been closed, the test fails. +func VerifyChannelIsBlocking[Value any](t *testing.T, channel <-chan Value) { + select { + case <-channel: + t.Error("Channel is not blocking") + default: + } +} From 1011bb36da031b4901d9f51ddb0a35dd3b6827f7 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Wed, 10 Dec 2025 14:41:34 +0100 Subject: [PATCH 3/9] Fix review comments apart from the tests --- pkg/grpc/BUILD.bazel | 8 +- pkg/grpc/forwarding_stream_handler.go | 91 +++++++++++++++++++ ...t.go => forwarding_stream_handler_test.go} | 31 +------ pkg/grpc/reflection_relay.go | 15 +-- pkg/grpc/routing_stream_forwarder.go | 33 ------- pkg/grpc/routing_stream_forwarder_test.go | 56 ------------ pkg/grpc/routing_stream_handler.go | 36 ++++++++ pkg/grpc/routing_stream_handler_test.go | 87 ++++++++++++++++++ pkg/grpc/server.go | 23 ++--- pkg/grpc/simple_stream_forwarder.go | 88 ------------------ pkg/proto/configuration/grpc/grpc.pb.go | 12 +-- pkg/proto/configuration/grpc/grpc.proto | 8 +- 12 files changed, 243 insertions(+), 245 deletions(-) create mode 100644 pkg/grpc/forwarding_stream_handler.go rename pkg/grpc/{simple_stream_forwarder_test.go => forwarding_stream_handler_test.go} (94%) delete mode 100644 pkg/grpc/routing_stream_forwarder.go delete mode 100644 pkg/grpc/routing_stream_forwarder_test.go create mode 100644 pkg/grpc/routing_stream_handler.go create mode 100644 pkg/grpc/routing_stream_handler_test.go delete mode 100644 pkg/grpc/simple_stream_forwarder.go diff --git a/pkg/grpc/BUILD.bazel b/pkg/grpc/BUILD.bazel index 6c08c186..593f8791 100644 --- a/pkg/grpc/BUILD.bazel +++ b/pkg/grpc/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "client_factory.go", "deduplicating_client_factory.go", "deny_authenticator.go", + "forwarding_stream_handler.go", "jmespath_extractor.go", "lazy_client_dialer.go", "metadata_adding_interceptor.go", @@ -29,10 +30,9 @@ go_library( "reflection_relay.go", "request_headers_authenticator.go", "request_metadata_tracing_interceptor.go", - "routing_stream_forwarder.go", + "routing_stream_handler.go", "server.go", "server_transport_stream_context.go", - "simple_stream_forwarder.go", "tls_client_certificate_authenticator.go", ], importpath = "github.com/buildbarn/bb-storage/pkg/grpc", @@ -108,6 +108,7 @@ go_test( "authenticating_interceptor_test.go", "deduplicating_client_factory_test.go", "deny_authenticator_test.go", + "forwarding_stream_handler_test.go", "jmespath_extractor_test.go", "lazy_client_dialer_test.go", "metadata_adding_interceptor_test.go", @@ -117,8 +118,7 @@ go_test( "proto_trace_attributes_extractor_test.go", "request_headers_authenticator_test.go", "request_metadata_tracing_interceptor_test.go", - "routing_stream_forwarder_test.go", - "simple_stream_forwarder_test.go", + "routing_stream_handler_test.go", ] + select({ "@rules_go//go/platform:android": [ "peer_transport_credentials_test.go", diff --git a/pkg/grpc/forwarding_stream_handler.go b/pkg/grpc/forwarding_stream_handler.go new file mode 100644 index 00000000..99720a4a --- /dev/null +++ b/pkg/grpc/forwarding_stream_handler.go @@ -0,0 +1,91 @@ +package grpc + +import ( + "io" + + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +// NewForwardingStreamHandler creates a grpc.StreamHandler that forwards gRPC +// calls to a grpc.ClientConnInterface backend. +func NewForwardingStreamHandler(client grpc.ClientConnInterface) grpc.StreamHandler { + forwarder := &forwardingStreamHandler{ + backend: client, + } + return forwarder.HandleStream +} + +type forwardingStreamHandler struct { + backend grpc.ClientConnInterface +} + +// HandleStream creates a new stream to the backend. Requests from +// incomingStream are forwarded to the backend stream and responses from the +// backend stream are sent back in the incomingStream. +func (s *forwardingStreamHandler) HandleStream(srv any, incomingStream grpc.ServerStream) error { + method := MustStreamMethodFromContext(incomingStream.Context()) + desc := grpc.StreamDesc{ + // According to grpc.StreamDesc documentation, StreamName and Handler + // are only used when registering handlers on a server. + StreamName: "", + Handler: nil, + // Streaming behaviour is wanted, single message is treated the same on + // transport level, the application just closes the stream after the + // first message. + ServerStreams: true, + ClientStreams: true, + } + group, groupCtx := errgroup.WithContext(incomingStream.Context()) + group.Go(func() error { + // groupCtx is guaranteed to be canceled before returning from this method, so outgoingStream will not leak resources. + outgoingStream, err := s.backend.NewStream(groupCtx, &desc, method) + if err != nil { + return err + } + // Avoid group.Go because incomingStream.RecvMsg might block returning + // an error from the outgoingStream and getting the context for + // incomingStream canceled. + go func() { + for { + msg := &emptypb.Empty{} + if err := incomingStream.RecvMsg(msg); err != nil { + if err == io.EOF { + // Let's continue to receive on outgoingStream, so don't + // cancel grouptCtx. + outgoingStream.CloseSend() + return + } + // Cancel groupCtx immediately. + group.Go(func() error { return err }) + return + } + if err := outgoingStream.SendMsg(msg); err != nil { + if err == io.EOF { + // The error will be returned by outgoingStream.RecvMsg(), + // no need to cancel groupCtx now. + return + } + // Cancel groupCtx immediately. + group.Go(func() error { return err }) + return + } + } + }() + + for { + msg := &emptypb.Empty{} + if err := outgoingStream.RecvMsg(msg); err != nil { + if err == io.EOF { + return nil + } + return err + } + if err := incomingStream.SendMsg(msg); err != nil { + return err + } + } + }) + return group.Wait() +} diff --git a/pkg/grpc/simple_stream_forwarder_test.go b/pkg/grpc/forwarding_stream_handler_test.go similarity index 94% rename from pkg/grpc/simple_stream_forwarder_test.go rename to pkg/grpc/forwarding_stream_handler_test.go index e4c7c6dc..055f8ec1 100644 --- a/pkg/grpc/simple_stream_forwarder_test.go +++ b/pkg/grpc/forwarding_stream_handler_test.go @@ -59,7 +59,7 @@ type simpleStreamForwarderStandardFixture struct { func newSimpleStreamForwarderStandardFixture(ctx context.Context, ctrl *gomock.Controller, t *testing.T) *simpleStreamForwarderStandardFixture { backend := mock.NewMockClientConnInterface(ctrl) - forwarder := bb_grpc.NewSimpleStreamForwarder(backend) + forwarder := bb_grpc.NewForwardingStreamHandler(backend) serverTransportStream := mock.NewMockServerTransportStream(ctrl) serverTransportStream.EXPECT().Method().Return("/buildbarn.buildqueuestate.BuildQueueState/ListWorkers").AnyTimes() streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) @@ -219,13 +219,6 @@ func TestSimpleStreamForwarderRequestRecvError(t *testing.T) { backendCtx, forwardResultChan := fixture.call(nil) fixture.IncomingRecvErrorChan <- errors.New("incoming recv") - // Optional to call backend.SendClose(), but if called it should be done - // by now. - synctest.Wait() - select { - case <-fixture.BackendSendCloseChan: - default: - } // In error state, so the backend context should be canceled. <-backendCtx.Done() @@ -247,13 +240,6 @@ func TestSimpleStreamForwarderRequestSendError(t *testing.T) { fixture.IncomingRecvErrorChan <- nil fixture.IncomingRecvValueChan <- structpb.NewStringValue("beep") fixture.BackendSendErrorChan <- errors.New("backend send") - // Optional to call backend.SendClose(), but if called it should be done - // by now. - synctest.Wait() - select { - case <-fixture.BackendSendCloseChan: - default: - } // In error state, so the backend context should be canceled. <-backendCtx.Done() @@ -293,7 +279,6 @@ func TestSimpleStreamForwarderResponseSuccess(t *testing.T) { <-backendCtx.Done() require.NoError(t, <-forwardResultChan) fixture.IncomingRecvErrorChan <- context.Canceled - <-fixture.BackendSendCloseChan fixture.verifyEmptyChannels(t) }) @@ -306,13 +291,6 @@ func TestSimpleStreamForwarderResponseRecvError(t *testing.T) { backendCtx, forwardResultChan := fixture.call(nil) fixture.BackendRecvErrorChan <- errors.New("backend recv") - // Optional to call backend.SendClose(), but if called it should be done - // by now. - synctest.Wait() - select { - case <-fixture.BackendSendCloseChan: - default: - } // In error state, so the backend context should be canceled. <-backendCtx.Done() @@ -335,13 +313,6 @@ func TestSimpleStreamForwarderResponseSendError(t *testing.T) { fixture.BackendRecvErrorChan <- nil fixture.BackendRecvValueChan <- structpb.NewStringValue("beep") fixture.IncomingSendErrorChan <- errors.New("incoming send") - // Optional to call backend.SendClose(), but if called it should be done - // by now. - synctest.Wait() - select { - case <-fixture.BackendSendCloseChan: - default: - } // In error state, so the backend context should be canceled. <-backendCtx.Done() diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go index 623690a2..b1eceec5 100644 --- a/pkg/grpc/reflection_relay.go +++ b/pkg/grpc/reflection_relay.go @@ -3,7 +3,6 @@ package grpc import ( "context" "maps" - "strings" "github.com/buildbarn/bb-storage/pkg/program" grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" @@ -14,9 +13,7 @@ import ( v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/reflection" - "google.golang.org/grpc/status" ) type combinedServiceInfoProvider struct { @@ -42,18 +39,10 @@ func registerReflection(backendCtx context.Context, s *grpc.Server, serverRelayC // Accumulate all the service names. relayServices := make(map[string]grpc.ServiceInfo) for _, relay := range serverRelayConfiguration { - for _, serviceMethod := range relay.Methods { - if !strings.HasPrefix(serviceMethod, "/") { - return status.Errorf(codes.InvalidArgument, "Malformed service method name %q, should start with '/'", serviceMethod) - } - pos := strings.LastIndex(serviceMethod, "/") - if pos == -1 || pos == 0 { - return status.Errorf(codes.InvalidArgument, "Malformed name %q, expected '/' between service and method", serviceMethod) - } - serviceName := serviceMethod[1:pos] + for _, service := range relay.GetServices() { // According to ServiceInfoProvider docs for ServerOptions.Services, // the reflection service is only interested in the service names. - relayServices[serviceName] = grpc.ServiceInfo{} + relayServices[service] = grpc.ServiceInfo{} } } diff --git a/pkg/grpc/routing_stream_forwarder.go b/pkg/grpc/routing_stream_forwarder.go deleted file mode 100644 index d597356d..00000000 --- a/pkg/grpc/routing_stream_forwarder.go +++ /dev/null @@ -1,33 +0,0 @@ -package grpc - -import ( - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// RoutingStreamForwarder forwards gRPC streams to different backends depending -// on the method being invoked. -type RoutingStreamForwarder struct { - // RouteTable maps to the grpc.StreamHandler to be called. The key is the - // combined gRPC service and method name. - RouteTable map[string]grpc.StreamHandler -} - -// NewRoutingStreamForwarder creates a RoutingStreamForwarder which routes gRPC -// streams based on the invoked gRPC method name. -func NewRoutingStreamForwarder() *RoutingStreamForwarder { - return &RoutingStreamForwarder{ - RouteTable: make(map[string]grpc.StreamHandler), - } -} - -// HandleStream is the implementation of the grpc.StreamHandler interface to -// process a gRPC stream, forwarding it according to the RouteTable. -func (s *RoutingStreamForwarder) HandleStream(srv any, stream grpc.ServerStream) error { - method := MustStreamMethodFromContext(stream.Context()) - if streamHandler, ok := s.RouteTable[method]; ok { - return streamHandler(srv, stream) - } - return status.Errorf(codes.Unimplemented, "No route for method %v", method) -} diff --git a/pkg/grpc/routing_stream_forwarder_test.go b/pkg/grpc/routing_stream_forwarder_test.go deleted file mode 100644 index e6910869..00000000 --- a/pkg/grpc/routing_stream_forwarder_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package grpc_test - -import ( - "context" - "errors" - "testing" - - "github.com/buildbarn/bb-storage/internal/mock" - bb_grpc "github.com/buildbarn/bb-storage/pkg/grpc" - "github.com/buildbarn/bb-storage/pkg/testutil" - "github.com/stretchr/testify/require" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "go.uber.org/mock/gomock" -) - -func TestRoutingStreamForwarder(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - - someSrv := "server" - serverTransportStream := mock.NewMockServerTransportStream(ctrl) - streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) - incomingStream := mock.NewMockServerStream(ctrl) - incomingStream.EXPECT().Context().Return(streamCtx).AnyTimes() - - // The test assumes that the incomingStream is forwarded straight through - // the RoutingStreamForwarder, even if the implementation is allowed to do - // some wrapping. - forwarder := bb_grpc.NewRoutingStreamForwarder() - forwarder.RouteTable["/serviceA/method1"] = func(srv any, stream grpc.ServerStream) error { - require.Equal(t, srv, someSrv) - require.Equal(t, stream, incomingStream) - return errors.New("A1") - } - forwarder.RouteTable["generic-service-method-name"] = func(srv any, stream grpc.ServerStream) error { - require.Equal(t, srv, someSrv) - require.Equal(t, stream, incomingStream) - return errors.New("generic") - } - - serverTransportStream.EXPECT().Method().Return("/serviceA/method1") - require.Error(t, forwarder.HandleStream(someSrv, incomingStream), "A1") - - serverTransportStream.EXPECT().Method().Return("generic-service-method-name") - require.Error(t, forwarder.HandleStream(someSrv, incomingStream), "generic") - - serverTransportStream.EXPECT().Method().Return("/non-existing-service/bad-method") - testutil.RequireEqualStatus( - t, - status.Error(codes.Unimplemented, "No route for method /non-existing-service/bad-method"), - forwarder.HandleStream(someSrv, incomingStream), - ) -} diff --git a/pkg/grpc/routing_stream_handler.go b/pkg/grpc/routing_stream_handler.go new file mode 100644 index 00000000..759cf2f0 --- /dev/null +++ b/pkg/grpc/routing_stream_handler.go @@ -0,0 +1,36 @@ +package grpc + +import ( + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// NewRoutingStreamHandler creates a RoutingStreamForwarder which routes gRPC +// streams based on the invoked gRPC method name. The keys in the routeTable map +// are gRPC service names, for example: +// +// build.bazel.remote.execution.v2.Execution +// com.google.devtools.build.v1.PublishBuildEvent +func NewRoutingStreamHandler(routeTable map[string]grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) error { + serviceMethod := MustStreamMethodFromContext(stream.Context()) + // Service and method name parsing based on grpc.Server.handleStream(). + startIdx := 0 + if serviceMethod != "" && serviceMethod[0] == '/' { + startIdx = 1 + } + endIdx := strings.LastIndex(serviceMethod, "/") + if endIdx <= startIdx { + return status.Errorf(codes.InvalidArgument, "Malformed method name %v", serviceMethod) + } + service := serviceMethod[startIdx:endIdx] + + if streamHandler, ok := routeTable[service]; ok { + return streamHandler(srv, stream) + } + return status.Errorf(codes.Unimplemented, "No route for service %v", service) + } +} diff --git a/pkg/grpc/routing_stream_handler_test.go b/pkg/grpc/routing_stream_handler_test.go new file mode 100644 index 00000000..652f0b49 --- /dev/null +++ b/pkg/grpc/routing_stream_handler_test.go @@ -0,0 +1,87 @@ +package grpc_test + +import ( + "context" + "errors" + "testing" + + "github.com/buildbarn/bb-storage/internal/mock" + bb_grpc "github.com/buildbarn/bb-storage/pkg/grpc" + "github.com/buildbarn/bb-storage/pkg/testutil" + "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "go.uber.org/mock/gomock" +) + +func TestRoutingStreamForwarder(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + + someSrv := "server" + serverTransportStream := mock.NewMockServerTransportStream(ctrl) + streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + incomingStream.EXPECT().Context().Return(streamCtx).AnyTimes() + + // The test assumes that the incomingStream is forwarded straight through + // the RoutingStreamForwarder, even if the implementation is allowed to do + // some wrapping. + streamHandler := mock.NewMockStreamHandler(ctrl) + + forwarder := bb_grpc.NewRoutingStreamHandler(map[string]grpc.StreamHandler{ + "serviceA": streamHandler.Call, + "/serviceB": streamHandler.Call, + }) + + serverTransportStream.EXPECT().Method().Return("/serviceA/method1") + streamHandler.EXPECT().Call(someSrv, incomingStream).Return(errors.New("called")) + require.Error(t, forwarder(someSrv, incomingStream), "called") + + serverTransportStream.EXPECT().Method().Return("/serviceB/method2") + testutil.RequireEqualStatus( + t, + status.Error(codes.Unimplemented, "No route for service serviceB"), + forwarder(someSrv, incomingStream), + ) + + serverTransportStream.EXPECT().Method().Return("/non.existing/service/bad-method") + testutil.RequireEqualStatus( + t, + status.Error(codes.Unimplemented, "No route for service non.existing/service"), + forwarder(someSrv, incomingStream), + ) + serverTransportStream.EXPECT().Method().Return("non.existing/service/bad-method") + testutil.RequireEqualStatus( + t, + status.Error(codes.Unimplemented, "No route for service non.existing/service"), + forwarder(someSrv, incomingStream), + ) + + serverTransportStream.EXPECT().Method().Return("/service.only") + testutil.RequireEqualStatus( + t, + status.Error(codes.InvalidArgument, "Malformed method name /service.only"), + forwarder(someSrv, incomingStream), + ) + serverTransportStream.EXPECT().Method().Return("service.only") + testutil.RequireEqualStatus( + t, + status.Error(codes.InvalidArgument, "Malformed method name service.only"), + forwarder(someSrv, incomingStream), + ) + serverTransportStream.EXPECT().Method().Return("/") + testutil.RequireEqualStatus( + t, + status.Error(codes.InvalidArgument, "Malformed method name /"), + forwarder(someSrv, incomingStream), + ) + serverTransportStream.EXPECT().Method().Return("") + testutil.RequireEqualStatus( + t, + status.Error(codes.InvalidArgument, "Malformed method name "), + forwarder(someSrv, incomingStream), + ) +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 9eafffa6..95db8896 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -147,9 +147,9 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC } if len(configuration.Relays) != 0 { - handler, err := newStreamRoutingFromConfiguration(configuration.Relays, grpcClientFactory, group) + handler, err := newRoutingStreamHandlerFromConfiguration(configuration.Relays, grpcClientFactory, group) if err != nil { - return util.StatusWrap(err, "Failed to create authenticator RPC client") + return err } serverOptions = append(serverOptions, grpc.UnknownServiceHandler(handler)) } @@ -218,19 +218,20 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC return nil } -func newStreamRoutingFromConfiguration(serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, grpcClientFactory ClientFactory, group program.Group) (grpc.StreamHandler, error) { - handler := NewRoutingStreamForwarder() - for _, relay := range serverRelayConfiguration { +func newRoutingStreamHandlerFromConfiguration(serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, grpcClientFactory ClientFactory, group program.Group) (grpc.StreamHandler, error) { + routeTable := make(map[string]grpc.StreamHandler) + for i, relay := range serverRelayConfiguration { grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.GetEndpoint(), group) if err != nil { - return nil, util.StatusWrap(err, "Failed to create authenticator RPC client") + return nil, util.StatusWrapf(err, "Failed to create gRPC relay RPC client at index %d", i) } - for _, method := range relay.GetMethods() { - if _, ok := handler.RouteTable[method]; ok { - return nil, status.Errorf(codes.InvalidArgument, "Duplicated relay for %v", method) + handler := NewForwardingStreamHandler(grpcClient) + for _, service := range relay.GetServices() { + if _, ok := routeTable[service]; ok { + return nil, status.Errorf(codes.InvalidArgument, "Duplicated gRPC relay for %v", service) } - handler.RouteTable[method] = NewSimpleStreamForwarder(grpcClient) + routeTable[service] = handler } } - return handler.HandleStream, nil + return NewRoutingStreamHandler(routeTable), nil } diff --git a/pkg/grpc/simple_stream_forwarder.go b/pkg/grpc/simple_stream_forwarder.go deleted file mode 100644 index 50d419ee..00000000 --- a/pkg/grpc/simple_stream_forwarder.go +++ /dev/null @@ -1,88 +0,0 @@ -package grpc - -import ( - "io" - - "golang.org/x/sync/errgroup" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/emptypb" -) - -// NewSimpleStreamForwarder creates a grpc.StreamHandler that forwards gRPC -// calls to a grpc.ClientConnInterface backend. -func NewSimpleStreamForwarder(client grpc.ClientConnInterface) grpc.StreamHandler { - forwarder := &simpleStreamForwarder{ - backend: client, - } - return forwarder.HandleStream -} - -type simpleStreamForwarder struct { - backend grpc.ClientConnInterface -} - -// HandleStream creates a new stream to the backend. Requests from -// incomingStream are forwarded to the backend stream and responses from the -// backend stream are sent back in the incomingStream. -func (s *simpleStreamForwarder) HandleStream(srv any, incomingStream grpc.ServerStream) error { - method := MustStreamMethodFromContext(incomingStream.Context()) - desc := grpc.StreamDesc{ - // According to grpc.StreamDesc documentation, StreamName and Handler - // are only used when registering handlers on a server. - StreamName: "", - Handler: nil, - // Streaming behaviour is wanted, single message is treated the same on - // transport level, the application just closes the stream after the - // first message. - ServerStreams: true, - ClientStreams: true, - } - group, groupCtx := errgroup.WithContext(incomingStream.Context()) - outgoingStream, err := s.backend.NewStream(groupCtx, &desc, method) - if err != nil { - return err - } - go func() { - defer outgoingStream.CloseSend() - for { - msg := &emptypb.Empty{} - if err := incomingStream.RecvMsg(msg); err != nil { - if err == io.EOF { - // Let's to receive on outgoingStream, so don't cancel - // grouptCtx. - return - } - // Cancel groupCtx immediately. - group.Go(func() error { return err }) - return - } - if err := outgoingStream.SendMsg(msg); err != nil { - if err == io.EOF { - // The error will be returned by outgoingStream.RecvMsg(), - // no need to cancel groupCtx now. - return - } - // Cancel groupCtx immediately. - group.Go(func() error { return err }) - return - } - } - }() - group.Go(func() error { - for { - msg := &emptypb.Empty{} - if err := outgoingStream.RecvMsg(msg); err != nil { - if err == io.EOF { - return nil - } - return err - } - if err := incomingStream.SendMsg(msg); err != nil { - return err - } - } - }) - // group.Wait() may block a bit on incomingStream.SendMsg(), but that - // shouldn't be for too long. - return group.Wait() -} diff --git a/pkg/proto/configuration/grpc/grpc.pb.go b/pkg/proto/configuration/grpc/grpc.pb.go index f66596eb..b79d1a79 100644 --- a/pkg/proto/configuration/grpc/grpc.pb.go +++ b/pkg/proto/configuration/grpc/grpc.pb.go @@ -978,7 +978,7 @@ func (x *TracingMethodConfiguration) GetAttributesFromFirstResponseMessage() []s type ServerRelayConfiguration struct { state protoimpl.MessageState `protogen:"open.v1"` Endpoint *ClientConfiguration `protobuf:"bytes,1,opt,name=endpoint,proto3" json:"endpoint,omitempty"` - Methods []string `protobuf:"bytes,2,rep,name=methods,proto3" json:"methods,omitempty"` + Services []string `protobuf:"bytes,2,rep,name=services,proto3" json:"services,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1020,9 +1020,9 @@ func (x *ServerRelayConfiguration) GetEndpoint() *ClientConfiguration { return nil } -func (x *ServerRelayConfiguration) GetMethods() []string { +func (x *ServerRelayConfiguration) GetServices() []string { if x != nil { - return x.Methods + return x.Services } return nil } @@ -1162,10 +1162,10 @@ const file_github_com_buildbarn_bb_storage_pkg_proto_configuration_grpc_grpc_pro "\x18cache_replacement_policy\x18\x05 \x01(\x0e28.buildbarn.configuration.eviction.CacheReplacementPolicyR\x16cacheReplacementPolicy\"\xc2\x01\n" + "\x1aTracingMethodConfiguration\x12P\n" + "%attributes_from_first_request_message\x18\x01 \x03(\tR!attributesFromFirstRequestMessage\x12R\n" + - "&attributes_from_first_response_message\x18\x02 \x03(\tR\"attributesFromFirstResponseMessage\"\x83\x01\n" + + "&attributes_from_first_response_message\x18\x02 \x03(\tR\"attributesFromFirstResponseMessage\"\x85\x01\n" + "\x18ServerRelayConfiguration\x12M\n" + - "\bendpoint\x18\x01 \x01(\v21.buildbarn.configuration.grpc.ClientConfigurationR\bendpoint\x12\x18\n" + - "\amethods\x18\x02 \x03(\tR\amethodsB>ZZ Date: Wed, 10 Dec 2025 20:59:33 +0100 Subject: [PATCH 4/9] Fix tests with gomock --- pkg/grpc/forwarding_stream_handler_test.go | 522 ++++++++++----------- pkg/testutil/testutil.go | 1 + 2 files changed, 237 insertions(+), 286 deletions(-) diff --git a/pkg/grpc/forwarding_stream_handler_test.go b/pkg/grpc/forwarding_stream_handler_test.go index 055f8ec1..0f300d77 100644 --- a/pkg/grpc/forwarding_stream_handler_test.go +++ b/pkg/grpc/forwarding_stream_handler_test.go @@ -19,311 +19,261 @@ import ( "go.uber.org/mock/gomock" ) -// simpleStreamForwarderStandardFixture contains channels to communicate with -// the RecvMsg and SendMsg mocks in the incoming and backend streams. -type simpleStreamForwarderStandardFixture struct { - forwarder grpc.StreamHandler - incomingStream *mock.MockServerStream - - backendNewStreamErrorChan chan<- error - backendNewStreamCtx <-chan context.Context - - // IncomingRecvErrorChan provides the return value for - // incomingStream.RecvMsg(). - IncomingRecvErrorChan chan<- error - // IncomingRecvValueChan provides the returned proto message for - // incomingStream.RecvMsg() if the error was nil. - IncomingRecvValueChan chan<- *structpb.Value - // BackendSendErrorChan provides the return value for - // backendStream.SendMsg(). - BackendSendErrorChan chan<- error - // BackendSendValueChan receives the proto message provided in the call to - // backendStream.SendMsg(). - BackendSendValueChan <-chan *structpb.Value - // BackendSendCloseChan receives an entry when backendStream.CloseSend() is - // called. - BackendSendCloseChan <-chan struct{} - // BackendRecvErrorChan provides the return value for - // backendStream.RecvMsg(). - BackendRecvErrorChan chan<- error - // BackendRecvValueChan provides the returned proto message for - // backendStream.RecvMsg() if the error was nil. - BackendRecvValueChan chan<- *structpb.Value - // IncomingSendErrorChan provides the return value for - // incomingStream.SendMsg(). - IncomingSendErrorChan chan<- error - // IncomingSendValueChan receives the proto message provided in the call to - // incomingStream.SendMsg(). - IncomingSendValueChan <-chan *structpb.Value +type eqProtoStringValueMatcher struct { + gomock.Matcher } -func newSimpleStreamForwarderStandardFixture(ctx context.Context, ctrl *gomock.Controller, t *testing.T) *simpleStreamForwarderStandardFixture { - backend := mock.NewMockClientConnInterface(ctrl) - forwarder := bb_grpc.NewForwardingStreamHandler(backend) - serverTransportStream := mock.NewMockServerTransportStream(ctrl) - serverTransportStream.EXPECT().Method().Return("/buildbarn.buildqueuestate.BuildQueueState/ListWorkers").AnyTimes() - streamCtx := grpc.NewContextWithServerTransportStream(ctx, serverTransportStream) - incomingStream := mock.NewMockServerStream(ctrl) - incomingStream.EXPECT().Context().Return(streamCtx).AnyTimes() - backendStream := mock.NewMockClientStream(ctrl) - - backendNewStreamErrorChan := make(chan error, 10) - backendNewStreamCtx := make(chan context.Context, 10) - - backend.EXPECT().NewStream( - gomock.Any(), - gomock.Any(), - "/buildbarn.buildqueuestate.BuildQueueState/ListWorkers", - ).DoAndReturn(func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - backendNewStreamCtx <- ctx - return backendStream, <-backendNewStreamErrorChan - }).AnyTimes() - - incomingRecvErrorChan := make(chan error, 10) - incomingRecvValueChan := make(chan *structpb.Value, 10) - backendSendErrorChan := make(chan error, 10) - backendSendValueChan := make(chan *structpb.Value, 10) - backendSendCloseChan := make(chan struct{}, 10) - backendRecvErrorChan := make(chan error, 10) - backendRecvValueChan := make(chan *structpb.Value, 10) - incomingSendErrorChan := make(chan error, 10) - incomingSendValueChan := make(chan *structpb.Value, 10) - - incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { - if err := <-incomingRecvErrorChan; err != nil { - return err - } - value := <-incomingRecvValueChan - bytes, err := proto.Marshal(value) - require.NoError(t, err) - require.NoError(t, proto.Unmarshal(bytes, msg.(proto.Message))) - return nil - }).AnyTimes() - backendStream.EXPECT().SendMsg(gomock.Any()).DoAndReturn(func(msg any) error { - if err := <-backendSendErrorChan; err != nil { - return err - } - bytes, err := proto.Marshal(msg.(proto.Message)) - require.NoError(t, err) - value := new(structpb.Value) - require.NoError(t, proto.Unmarshal(bytes, value)) - backendSendValueChan <- value - return nil - }).AnyTimes() - backendStream.EXPECT().CloseSend().DoAndReturn(func() error { - backendSendCloseChan <- struct{}{} - return nil - }).AnyTimes() - backendStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { - if err := <-backendRecvErrorChan; err != nil { - return err - } - value := <-backendRecvValueChan - bytes, err := proto.Marshal(value) - require.NoError(t, err) - require.NoError(t, proto.Unmarshal(bytes, msg.(proto.Message))) - return nil - }).AnyTimes() - incomingStream.EXPECT().SendMsg(gomock.Any()).DoAndReturn(func(msg any) error { - if err := <-incomingSendErrorChan; err != nil { - return err - } - bytes, err := proto.Marshal(msg.(proto.Message)) - require.NoError(t, err) - value := new(structpb.Value) - require.NoError(t, proto.Unmarshal(bytes, value)) - incomingSendValueChan <- value - return nil - }).AnyTimes() - - return &simpleStreamForwarderStandardFixture{ - forwarder: forwarder, - incomingStream: incomingStream, - - backendNewStreamErrorChan: backendNewStreamErrorChan, - backendNewStreamCtx: backendNewStreamCtx, - - IncomingRecvErrorChan: incomingRecvErrorChan, - IncomingRecvValueChan: incomingRecvValueChan, - BackendSendErrorChan: backendSendErrorChan, - BackendSendValueChan: backendSendValueChan, - BackendSendCloseChan: backendSendCloseChan, - BackendRecvErrorChan: backendRecvErrorChan, - BackendRecvValueChan: backendRecvValueChan, - IncomingSendErrorChan: incomingSendErrorChan, - IncomingSendValueChan: incomingSendValueChan, +// newEqProtoStringValueMatcher is a gomock matcher for proto equality after +// converting the proto.Message to structpb.Value. +func newEqProtoStringValueMatcher(t *testing.T, v string) gomock.Matcher { + proto := structpb.NewStringValue(v) + return &eqProtoStringValueMatcher{ + Matcher: testutil.EqProto(t, proto), } } -func (f *simpleStreamForwarderStandardFixture) call(newStreamErr error) (context.Context, <-chan error) { - callResult := make(chan error, 1) - go func() { - defer close(callResult) - f.backendNewStreamErrorChan <- newStreamErr - callResult <- f.forwarder(nil, f.incomingStream) - }() - return <-f.backendNewStreamCtx, callResult +func (m *eqProtoStringValueMatcher) Matches(other interface{}) bool { + otherProto, ok := other.(proto.Message) + if !ok { + return false + } + bytes, err := proto.Marshal(otherProto) + if err != nil { + return false + } + value := new(structpb.Value) + if proto.Unmarshal(bytes, value) != nil { + return false + } + return m.Matcher.Matches(value) } -func (f *simpleStreamForwarderStandardFixture) verifyEmptyChannels(t *testing.T) { - require.Len(t, f.backendNewStreamErrorChan, 0, "backendNewStreamErrorChan") - require.Len(t, f.backendNewStreamCtx, 0, "backendNewStreamCtx") - require.Len(t, f.IncomingRecvErrorChan, 0, "IncomingRecvErrorChan") - require.Len(t, f.IncomingRecvValueChan, 0, "IncomingRecvValueChan") - require.Len(t, f.BackendSendErrorChan, 0, "BackendSendErrorChan") - require.Len(t, f.BackendSendValueChan, 0, "BackendSendValueChan") - require.Len(t, f.BackendSendCloseChan, 0, "BackendSendCloseChan") - require.Len(t, f.BackendRecvErrorChan, 0, "BackendRecvErrorChan") - require.Len(t, f.BackendRecvValueChan, 0, "BackendRecvValueChan") - require.Len(t, f.IncomingSendErrorChan, 0, "IncomingSendErrorChan") - require.Len(t, f.IncomingSendValueChan, 0, "IncomingSendValueChan") +func newForwardingStreamRecvMsgStub(v string) func(msg any) error { + src := structpb.NewStringValue(v) + bytes, err := proto.Marshal(src) + return func(dst any) error { + if err != nil { + return err + } + return proto.Unmarshal(bytes, dst.(proto.Message)) + } } -func TestSimpleStreamForwarderRequestSuccess(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.IncomingRecvErrorChan <- nil - fixture.IncomingRecvValueChan <- structpb.NewStringValue("beep") - fixture.BackendSendErrorChan <- nil - testutil.RequireEqualProto(t, structpb.NewStringValue("beep"), <-fixture.BackendSendValueChan) - fixture.IncomingRecvErrorChan <- nil - fixture.IncomingRecvValueChan <- structpb.NewStringValue("boop") - fixture.BackendSendErrorChan <- nil - testutil.RequireEqualProto(t, structpb.NewStringValue("boop"), <-fixture.BackendSendValueChan) - - // Should still be forwarding requests to the backend. - synctest.Wait() - require.Len(t, fixture.BackendSendCloseChan, 0) - fixture.IncomingRecvErrorChan <- io.EOF - <-fixture.BackendSendCloseChan - - // Should still be receiving responses. - synctest.Wait() - testutil.VerifyChannelIsBlocking(t, backendCtx.Done()) - require.Len(t, forwardResultChan, 0) - fixture.BackendRecvErrorChan <- io.EOF - <-backendCtx.Done() - require.NoError(t, <-forwardResultChan) +func TestSimpleStreamForwarder(t *testing.T) { + ctrl, _ := gomock.WithContext(context.Background(), t) - fixture.verifyEmptyChannels(t) + backend := mock.NewMockClientConnInterface(ctrl) + forwarder := bb_grpc.NewForwardingStreamHandler(backend) + serverTransportStream := mock.NewMockServerTransportStream(ctrl) + serverTransportStream.EXPECT().Method().Return("/serviceA/method1").AnyTimes() + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) + + t.Run("RequestSuccess", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + outgoingRecvBarrier := make(chan struct{}) + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("beep")), + outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(nil), + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("boop")), + outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "boop")).Return(nil), + incomingStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF), + outgoingStream.EXPECT().CloseSend().DoAndReturn(func() error { + close(outgoingRecvBarrier) + return nil + }), + ) + gomock.InOrder( + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + <-outgoingRecvBarrier + testutil.VerifyChannelIsBlocking(t, outgoingStreamCtx.Done()) + return io.EOF + }), + ) + + require.NoError(t, forwarder(nil, incomingStream)) + <-outgoingStreamCtx.Done() + }) }) -} - -func TestSimpleStreamForwarderRequestRecvError(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.IncomingRecvErrorChan <- errors.New("incoming recv") - - // In error state, so the backend context should be canceled. - <-backendCtx.Done() - // Emulate that backend.RecvMsg() to returns. - fixture.BackendRecvErrorChan <- context.Canceled - require.EqualError(t, <-forwardResultChan, "incoming recv") - - fixture.verifyEmptyChannels(t) + t.Run("RequestRecvError", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).Return(errors.New("incoming recv")), + ) + gomock.InOrder( + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + // When incomingStream.RecvMsg returns, the backend context + // should be canceled due to the error. + <-outgoingStreamCtx.Done() + return context.Canceled + }), + ) + + require.EqualError(t, forwarder(nil, incomingStream), "incoming recv") + }) }) -} - -func TestSimpleStreamForwarderRequestSendError(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.IncomingRecvErrorChan <- nil - fixture.IncomingRecvValueChan <- structpb.NewStringValue("beep") - fixture.BackendSendErrorChan <- errors.New("backend send") - // In error state, so the backend context should be canceled. - <-backendCtx.Done() - // Emulate that backend.RecvMsg() to returns. - fixture.BackendRecvErrorChan <- context.Canceled - - require.EqualError(t, <-forwardResultChan, "backend send") - - fixture.verifyEmptyChannels(t) + t.Run("RequestSendError", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("beep")), + outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(errors.New("outgoing send")), + ) + gomock.InOrder( + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + // When outgoingStream.SendMsg returns, the outgoing context + // should be canceled due to the error. + <-outgoingStreamCtx.Done() + return context.Canceled + }), + ) + + require.EqualError(t, forwarder(nil, incomingStream), "outgoing send") + }) }) -} -func TestSimpleStreamForwarderResponseSuccess(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.BackendRecvErrorChan <- nil - fixture.BackendRecvValueChan <- structpb.NewStringValue("beep") - fixture.IncomingSendErrorChan <- nil - testutil.RequireEqualProto(t, structpb.NewStringValue("beep"), <-fixture.IncomingSendValueChan) - fixture.BackendRecvErrorChan <- nil - fixture.BackendRecvValueChan <- structpb.NewStringValue("boop") - fixture.IncomingSendErrorChan <- nil - testutil.RequireEqualProto(t, structpb.NewStringValue("boop"), <-fixture.IncomingSendValueChan) - - // Should still be forwarding requests to the backend. - synctest.Wait() - require.Len(t, fixture.BackendSendCloseChan, 0) - // Should still be receiving responses. - testutil.VerifyChannelIsBlocking(t, backendCtx.Done()) - require.Len(t, forwardResultChan, 0) - fixture.BackendRecvErrorChan <- io.EOF - // Will not receive any more from the backend, so its context should - // be canceled. - <-backendCtx.Done() - require.NoError(t, <-forwardResultChan) - fixture.IncomingRecvErrorChan <- context.Canceled - - fixture.verifyEmptyChannels(t) + t.Run("ResponseSuccess", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + incomingRecvBarrier := make(chan struct{}) + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("beep")), + incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(nil), + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("boop")), + incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "boop")).Return(nil), + outgoingStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF), + ) + gomock.InOrder( + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + <-incomingRecvBarrier + return context.Canceled + }), + ) + + require.NoError(t, forwarder(nil, incomingStream)) + <-outgoingStreamCtx.Done() + + // incomingStream.Recv() is still blocking. + close(incomingRecvBarrier) + }) }) -} - -func TestSimpleStreamForwarderResponseRecvError(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.BackendRecvErrorChan <- errors.New("backend recv") - - // In error state, so the backend context should be canceled. - <-backendCtx.Done() - require.EqualError(t, <-forwardResultChan, "backend recv") - // Emulate that incoming.RecvMsg() returns now when the whole stream - // handler has returned. - fixture.IncomingRecvErrorChan <- context.Canceled - - fixture.verifyEmptyChannels(t) + t.Run("ResponseRecvError", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + incomingRecvBarrier := make(chan struct{}) + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).Return(errors.New("outgoing recv")), + ) + gomock.InOrder( + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + <-incomingRecvBarrier + return context.Canceled + }), + ) + + require.EqualError(t, forwarder(nil, incomingStream), "outgoing recv") + <-outgoingStreamCtx.Done() + + // incomingStream.Recv() is still blocking. + close(incomingRecvBarrier) + }) }) -} - -func TestSimpleStreamForwarderResponseSendError(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - ctrl, ctx := gomock.WithContext(context.Background(), t) - fixture := newSimpleStreamForwarderStandardFixture(ctx, ctrl, t) - backendCtx, forwardResultChan := fixture.call(nil) - - fixture.BackendRecvErrorChan <- nil - fixture.BackendRecvValueChan <- structpb.NewStringValue("beep") - fixture.IncomingSendErrorChan <- errors.New("incoming send") - - // In error state, so the backend context should be canceled. - <-backendCtx.Done() - - // The forwarder should return, even if the backend.RecvMsg() is slow. - require.EqualError(t, <-forwardResultChan, "incoming send") - - // Emulate that incoming.RecvMsg() returns now when the whole stream - // handler has returned. - fixture.IncomingRecvErrorChan <- context.Canceled - fixture.verifyEmptyChannels(t) + t.Run("ResponseSendError", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + var outgoingStreamCtx context.Context + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( + func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + outgoingStreamCtx = ctx + return outgoingStream, nil + }, + ) + gomock.InOrder( + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), + newStreamCall, + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( + newForwardingStreamRecvMsgStub("beep")), + incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(errors.New("incoming send")), + ) + gomock.InOrder( + newStreamCall, + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { + // When incomingStream.SendMsg returns, the outgoing context + // should be canceled due to the error. + <-outgoingStreamCtx.Done() + return context.Canceled + }), + ) + + require.EqualError(t, forwarder(nil, incomingStream), "incoming send") + }) }) } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index dbe01af7..bb2bfe3f 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -144,6 +144,7 @@ func mustMarshalToString(t *testing.T, proto proto.Message) string { func VerifyChannelIsBlocking[Value any](t *testing.T, channel <-chan Value) { select { case <-channel: + t.Helper() t.Error("Channel is not blocking") default: } From 68d8e54c8a39a7591c28e5f650d21599f4ee9398 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Thu, 11 Dec 2025 12:01:28 +0100 Subject: [PATCH 5/9] Minor test adjustments --- pkg/grpc/forwarding_stream_handler_test.go | 49 +++++++++++++--------- pkg/testutil/testutil.go | 11 ----- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pkg/grpc/forwarding_stream_handler_test.go b/pkg/grpc/forwarding_stream_handler_test.go index 0f300d77..42927f93 100644 --- a/pkg/grpc/forwarding_stream_handler_test.go +++ b/pkg/grpc/forwarding_stream_handler_test.go @@ -82,14 +82,12 @@ func TestSimpleStreamForwarder(t *testing.T) { }, ) + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, - incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("beep")), + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("beep")), outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(nil), - incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("boop")), + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("boop")), outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "boop")).Return(nil), incomingStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF), outgoingStream.EXPECT().CloseSend().DoAndReturn(func() error { @@ -101,7 +99,8 @@ func TestSimpleStreamForwarder(t *testing.T) { newStreamCall, outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(func(msg any) error { <-outgoingRecvBarrier - testutil.VerifyChannelIsBlocking(t, outgoingStreamCtx.Done()) + synctest.Wait() + require.NoError(t, outgoingStreamCtx.Err()) return io.EOF }), ) @@ -122,8 +121,9 @@ func TestSimpleStreamForwarder(t *testing.T) { return outgoingStream, nil }, ) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, incomingStream.EXPECT().RecvMsg(gomock.Any()).Return(errors.New("incoming recv")), ) @@ -152,11 +152,11 @@ func TestSimpleStreamForwarder(t *testing.T) { return outgoingStream, nil }, ) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, - incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("beep")), + incomingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("beep")), outgoingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(errors.New("outgoing send")), ) gomock.InOrder( @@ -185,14 +185,13 @@ func TestSimpleStreamForwarder(t *testing.T) { return outgoingStream, nil }, ) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, - outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("beep")), + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("beep")), incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(nil), - outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("boop")), + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("boop")), incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "boop")).Return(nil), outgoingStream.EXPECT().RecvMsg(gomock.Any()).Return(io.EOF), ) @@ -224,8 +223,9 @@ func TestSimpleStreamForwarder(t *testing.T) { return outgoingStream, nil }, ) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, outgoingStream.EXPECT().RecvMsg(gomock.Any()).Return(errors.New("outgoing recv")), ) @@ -256,11 +256,11 @@ func TestSimpleStreamForwarder(t *testing.T) { return outgoingStream, nil }, ) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() gomock.InOrder( - incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes(), newStreamCall, - outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn( - newForwardingStreamRecvMsgStub("beep")), + outgoingStream.EXPECT().RecvMsg(gomock.Any()).DoAndReturn(newForwardingStreamRecvMsgStub("beep")), incomingStream.EXPECT().SendMsg(newEqProtoStringValueMatcher(t, "beep")).Return(errors.New("incoming send")), ) gomock.InOrder( @@ -276,4 +276,15 @@ func TestSimpleStreamForwarder(t *testing.T) { require.EqualError(t, forwarder(nil, incomingStream), "incoming send") }) }) + + t.Run("NewStreamError", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + + incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() + backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").Return(nil, errors.New("no stream")) + + require.EqualError(t, forwarder(nil, incomingStream), "no stream") + }) + }) } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go index bb2bfe3f..6fc3172b 100644 --- a/pkg/testutil/testutil.go +++ b/pkg/testutil/testutil.go @@ -138,14 +138,3 @@ func mustMarshalToString(t *testing.T, proto proto.Message) string { } return string(s) } - -// VerifyChannelIsBlocking checks that no value can be received from the -// channel. If a value is ready or the channel has been closed, the test fails. -func VerifyChannelIsBlocking[Value any](t *testing.T, channel <-chan Value) { - select { - case <-channel: - t.Helper() - t.Error("Channel is not blocking") - default: - } -} From e8598d6cc2ad2b99c4e48fd4cedf73008cb49baf Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Mon, 15 Dec 2025 22:41:01 +0100 Subject: [PATCH 6/9] Review feedback --- pkg/grpc/BUILD.bazel | 2 -- pkg/grpc/forwarding_stream_handler.go | 3 ++- pkg/grpc/forwarding_stream_handler_test.go | 15 +++++++++++++-- pkg/grpc/reflection_relay.go | 6 ++---- pkg/grpc/routing_stream_handler.go | 3 ++- pkg/grpc/server_transport_stream_context.go | 18 ------------------ 6 files changed, 19 insertions(+), 28 deletions(-) delete mode 100644 pkg/grpc/server_transport_stream_context.go diff --git a/pkg/grpc/BUILD.bazel b/pkg/grpc/BUILD.bazel index 593f8791..fe7c95aa 100644 --- a/pkg/grpc/BUILD.bazel +++ b/pkg/grpc/BUILD.bazel @@ -32,7 +32,6 @@ go_library( "request_metadata_tracing_interceptor.go", "routing_stream_handler.go", "server.go", - "server_transport_stream_context.go", "tls_client_certificate_authenticator.go", ], importpath = "github.com/buildbarn/bb-storage/pkg/grpc", @@ -70,7 +69,6 @@ go_library( "@org_golang_google_grpc//peer", "@org_golang_google_grpc//reflection", "@org_golang_google_grpc//reflection/grpc_reflection_v1", - "@org_golang_google_grpc//reflection/grpc_reflection_v1alpha", "@org_golang_google_grpc//status", "@org_golang_google_grpc_security_advancedtls//:advancedtls", "@org_golang_google_protobuf//encoding/prototext", diff --git a/pkg/grpc/forwarding_stream_handler.go b/pkg/grpc/forwarding_stream_handler.go index 99720a4a..4f0364c3 100644 --- a/pkg/grpc/forwarding_stream_handler.go +++ b/pkg/grpc/forwarding_stream_handler.go @@ -25,7 +25,8 @@ type forwardingStreamHandler struct { // incomingStream are forwarded to the backend stream and responses from the // backend stream are sent back in the incomingStream. func (s *forwardingStreamHandler) HandleStream(srv any, incomingStream grpc.ServerStream) error { - method := MustStreamMethodFromContext(incomingStream.Context()) + // All gRPC invocations has a grpc.ServerTransportStream context. + method := grpc.ServerTransportStreamFromContext(incomingStream.Context()).Method() desc := grpc.StreamDesc{ // According to grpc.StreamDesc documentation, StreamName and Handler // are only used when registering handlers on a server. diff --git a/pkg/grpc/forwarding_stream_handler_test.go b/pkg/grpc/forwarding_stream_handler_test.go index 42927f93..c5eb2096 100644 --- a/pkg/grpc/forwarding_stream_handler_test.go +++ b/pkg/grpc/forwarding_stream_handler_test.go @@ -66,14 +66,14 @@ func TestSimpleStreamForwarder(t *testing.T) { forwarder := bb_grpc.NewForwardingStreamHandler(backend) serverTransportStream := mock.NewMockServerTransportStream(ctrl) serverTransportStream.EXPECT().Method().Return("/serviceA/method1").AnyTimes() - incomingStream := mock.NewMockServerStream(ctrl) - outgoingStream := mock.NewMockClientStream(ctrl) t.Run("RequestSuccess", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { var outgoingStreamCtx context.Context outgoingRecvBarrier := make(chan struct{}) incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -114,6 +114,8 @@ func TestSimpleStreamForwarder(t *testing.T) { synctest.Test(t, func(t *testing.T) { var outgoingStreamCtx context.Context incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -145,6 +147,8 @@ func TestSimpleStreamForwarder(t *testing.T) { synctest.Test(t, func(t *testing.T) { var outgoingStreamCtx context.Context incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -178,6 +182,8 @@ func TestSimpleStreamForwarder(t *testing.T) { var outgoingStreamCtx context.Context incomingRecvBarrier := make(chan struct{}) incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -216,6 +222,8 @@ func TestSimpleStreamForwarder(t *testing.T) { var outgoingStreamCtx context.Context incomingRecvBarrier := make(chan struct{}) incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -249,6 +257,8 @@ func TestSimpleStreamForwarder(t *testing.T) { synctest.Test(t, func(t *testing.T) { var outgoingStreamCtx context.Context incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) + outgoingStream := mock.NewMockClientStream(ctrl) newStreamCall := backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").DoAndReturn( func(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -280,6 +290,7 @@ func TestSimpleStreamForwarder(t *testing.T) { t.Run("NewStreamError", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { incomingStreamCtx := grpc.NewContextWithServerTransportStream(context.Background(), serverTransportStream) + incomingStream := mock.NewMockServerStream(ctrl) incomingStream.EXPECT().Context().Return(incomingStreamCtx).AnyTimes() backend.EXPECT().NewStream(gomock.Any(), gomock.Any(), "/serviceA/method1").Return(nil, errors.New("no stream")) diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go index b1eceec5..345870e6 100644 --- a/pkg/grpc/reflection_relay.go +++ b/pkg/grpc/reflection_relay.go @@ -9,8 +9,7 @@ import ( "github.com/buildbarn/bb-storage/pkg/util" "github.com/jhump/protoreflect/v2/grpcreflect" "github.com/jhump/protoreflect/v2/protoresolve" - v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1" - v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" + "google.golang.org/grpc/reflection/grpc_reflection_v1" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -66,7 +65,6 @@ func registerReflection(backendCtx context.Context, s *grpc.Server, serverRelayC DescriptorResolver: combinedRemoteResolver, ExtensionResolver: protoresolve.TypesFromDescriptorPool(combinedRemoteResolver), } - v1reflectiongrpc.RegisterServerReflectionServer(s, reflection.NewServerV1(serverOptions)) - v1alphareflectiongrpc.RegisterServerReflectionServer(s, reflection.NewServer(serverOptions)) + grpc_reflection_v1.RegisterServerReflectionServer(s, reflection.NewServerV1(serverOptions)) return nil } diff --git a/pkg/grpc/routing_stream_handler.go b/pkg/grpc/routing_stream_handler.go index 759cf2f0..42107f1d 100644 --- a/pkg/grpc/routing_stream_handler.go +++ b/pkg/grpc/routing_stream_handler.go @@ -16,7 +16,8 @@ import ( // com.google.devtools.build.v1.PublishBuildEvent func NewRoutingStreamHandler(routeTable map[string]grpc.StreamHandler) grpc.StreamHandler { return func(srv any, stream grpc.ServerStream) error { - serviceMethod := MustStreamMethodFromContext(stream.Context()) + // All gRPC invocations has a grpc.ServerTransportStream context. + serviceMethod := grpc.ServerTransportStreamFromContext(stream.Context()).Method() // Service and method name parsing based on grpc.Server.handleStream(). startIdx := 0 if serviceMethod != "" && serviceMethod[0] == '/' { diff --git a/pkg/grpc/server_transport_stream_context.go b/pkg/grpc/server_transport_stream_context.go deleted file mode 100644 index c3347cf7..00000000 --- a/pkg/grpc/server_transport_stream_context.go +++ /dev/null @@ -1,18 +0,0 @@ -package grpc - -import ( - "context" - - "google.golang.org/grpc" -) - -// MustStreamMethodFromContext returns the service and method name for the ongoing gRPC stream. -// It will panic if the given context has no grpc.ServerTransportStream associated with it -// (which implies it is not an RPC invocation context). -func MustStreamMethodFromContext(ctx context.Context) string { - transportStream := grpc.ServerTransportStreamFromContext(ctx) - if transportStream == nil { - panic("No grpc.ServerTransportStream in context") - } - return transportStream.Method() -} From 080d955b5013f087c767376be424602f005eae56 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Tue, 13 Jan 2026 16:19:53 +0100 Subject: [PATCH 7/9] Review feedback 2 --- pkg/grpc/reflection_relay.go | 32 +++++++++++++++--------------- pkg/grpc/routing_stream_handler.go | 13 +++++------- pkg/grpc/server.go | 27 +++++++++++++++---------- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go index 345870e6..614b7800 100644 --- a/pkg/grpc/reflection_relay.go +++ b/pkg/grpc/reflection_relay.go @@ -4,9 +4,7 @@ import ( "context" "maps" - "github.com/buildbarn/bb-storage/pkg/program" grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" - "github.com/buildbarn/bb-storage/pkg/util" "github.com/jhump/protoreflect/v2/grpcreflect" "github.com/jhump/protoreflect/v2/protoresolve" "google.golang.org/grpc/reflection/grpc_reflection_v1" @@ -25,20 +23,26 @@ var _ reflection.ServiceInfoProvider = (*combinedServiceInfoProvider)(nil) // GetServiceInfo returns the currently available services, which might have // changed since the creation of this reflection server. func (p *combinedServiceInfoProvider) GetServiceInfo() map[string]grpc.ServiceInfo { - services := make(map[string]grpc.ServiceInfo) + serverServiceInfo := p.server.GetServiceInfo() + services := make(map[string]grpc.ServiceInfo, len(p.extraServices)+len(serverServiceInfo)) maps.Copy(services, p.extraServices) - maps.Copy(services, p.server.GetServiceInfo()) + maps.Copy(services, serverServiceInfo) return services } -// registerReflection registers the google.golang.org/grpc/reflection/ service -// on a grpc.Server and calls remote backends in case for relayed services. The -// connections to the backend will run with the backendCtx. -func registerReflection(backendCtx context.Context, s *grpc.Server, serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, group program.Group, grpcClientFactory ClientFactory) error { +type serverRelayConfigWithGrpcClient struct { + config *grpcpb.ServerRelayConfiguration + grpcClient grpc.ClientConnInterface +} + +// registerReflectionServer registers the google.golang.org/grpc/reflection/ +// service on a grpc.Server and calls remote backends in case for relayed +// services. The connections to the backend will run with the backendCtx. +func registerReflectionServer(backendCtx context.Context, s *grpc.Server, serverRelayConfigurations []serverRelayConfigWithGrpcClient) error { // Accumulate all the service names. relayServices := make(map[string]grpc.ServiceInfo) - for _, relay := range serverRelayConfiguration { - for _, service := range relay.GetServices() { + for _, relay := range serverRelayConfigurations { + for _, service := range relay.config.GetServices() { // According to ServiceInfoProvider docs for ServerOptions.Services, // the reflection service is only interested in the service names. relayServices[service] = grpc.ServiceInfo{} @@ -47,12 +51,8 @@ func registerReflection(backendCtx context.Context, s *grpc.Server, serverRelayC // Make a combined descriptor and extension resolver. reflectionBackends := []protoresolve.Resolver{} - for relayIdx, relay := range serverRelayConfiguration { - grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.Endpoint, group) - if err != nil { - return util.StatusWrapf(err, "Failed to create relay RPC client %d", relayIdx+1) - } - resolver := grpcreflect.NewClientAuto(backendCtx, grpcClient).AsResolver() + for _, relay := range serverRelayConfigurations { + resolver := grpcreflect.NewClientAuto(backendCtx, relay.grpcClient).AsResolver() reflectionBackends = append(reflectionBackends, resolver) } combinedRemoteResolver := protoresolve.Combine(reflectionBackends...) diff --git a/pkg/grpc/routing_stream_handler.go b/pkg/grpc/routing_stream_handler.go index 42107f1d..8c995acc 100644 --- a/pkg/grpc/routing_stream_handler.go +++ b/pkg/grpc/routing_stream_handler.go @@ -17,17 +17,14 @@ import ( func NewRoutingStreamHandler(routeTable map[string]grpc.StreamHandler) grpc.StreamHandler { return func(srv any, stream grpc.ServerStream) error { // All gRPC invocations has a grpc.ServerTransportStream context. - serviceMethod := grpc.ServerTransportStreamFromContext(stream.Context()).Method() + orgServiceMethod := grpc.ServerTransportStreamFromContext(stream.Context()).Method() // Service and method name parsing based on grpc.Server.handleStream(). - startIdx := 0 - if serviceMethod != "" && serviceMethod[0] == '/' { - startIdx = 1 - } + serviceMethod := strings.TrimPrefix(orgServiceMethod, "/") endIdx := strings.LastIndex(serviceMethod, "/") - if endIdx <= startIdx { - return status.Errorf(codes.InvalidArgument, "Malformed method name %v", serviceMethod) + if endIdx == -1 { + return status.Errorf(codes.InvalidArgument, "Malformed method name %v", orgServiceMethod) } - service := serviceMethod[startIdx:endIdx] + service := serviceMethod[:endIdx] if streamHandler, ok := routeTable[service]; ok { return streamHandler(srv, stream) diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 95db8896..62b235f9 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -146,8 +146,19 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC })) } + relayConfigWithGrpcClients := make([]serverRelayConfigWithGrpcClient, len(configuration.Relays)) + for relayIdx, relay := range configuration.Relays { + grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.Endpoint, group) + if err != nil { + return util.StatusWrapf(err, "Failed to create relay RPC client %d", relayIdx+1) + } + relayConfigWithGrpcClients[relayIdx] = serverRelayConfigWithGrpcClient{ + config: relay, + grpcClient: grpcClient, + } + } if len(configuration.Relays) != 0 { - handler, err := newRoutingStreamHandlerFromConfiguration(configuration.Relays, grpcClientFactory, group) + handler, err := newRoutingStreamHandlerFromConfiguration(relayConfigWithGrpcClients) if err != nil { return err } @@ -169,7 +180,7 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC // Enable default services. grpc_prometheus.Register(s) - if err := registerReflection(context.Background(), s, configuration.Relays, group, grpcClientFactory); err != nil { + if err := registerReflectionServer(context.Background(), s, relayConfigWithGrpcClients); err != nil { return util.StatusWrap(err, "Failed to create reflection service") } h := health.NewServer() @@ -218,15 +229,11 @@ func NewServersFromConfigurationAndServe(configurations []*configuration.ServerC return nil } -func newRoutingStreamHandlerFromConfiguration(serverRelayConfiguration []*grpcpb.ServerRelayConfiguration, grpcClientFactory ClientFactory, group program.Group) (grpc.StreamHandler, error) { +func newRoutingStreamHandlerFromConfiguration(serverRelayConfigurations []serverRelayConfigWithGrpcClient) (grpc.StreamHandler, error) { routeTable := make(map[string]grpc.StreamHandler) - for i, relay := range serverRelayConfiguration { - grpcClient, err := grpcClientFactory.NewClientFromConfiguration(relay.GetEndpoint(), group) - if err != nil { - return nil, util.StatusWrapf(err, "Failed to create gRPC relay RPC client at index %d", i) - } - handler := NewForwardingStreamHandler(grpcClient) - for _, service := range relay.GetServices() { + for _, relay := range serverRelayConfigurations { + handler := NewForwardingStreamHandler(relay.grpcClient) + for _, service := range relay.config.GetServices() { if _, ok := routeTable[service]; ok { return nil, status.Errorf(codes.InvalidArgument, "Duplicated gRPC relay for %v", service) } From 9c87967982e32991ef6c9dcfba66d037a8c80993 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Wed, 14 Jan 2026 00:07:48 +0100 Subject: [PATCH 8/9] Review feedback 3 --- pkg/grpc/reflection_relay.go | 9 ++------- pkg/grpc/server.go | 9 ++++++++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go index 614b7800..83c1a470 100644 --- a/pkg/grpc/reflection_relay.go +++ b/pkg/grpc/reflection_relay.go @@ -7,10 +7,10 @@ import ( grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" "github.com/jhump/protoreflect/v2/grpcreflect" "github.com/jhump/protoreflect/v2/protoresolve" - "google.golang.org/grpc/reflection/grpc_reflection_v1" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + "google.golang.org/grpc/reflection/grpc_reflection_v1" ) type combinedServiceInfoProvider struct { @@ -30,11 +30,6 @@ func (p *combinedServiceInfoProvider) GetServiceInfo() map[string]grpc.ServiceIn return services } -type serverRelayConfigWithGrpcClient struct { - config *grpcpb.ServerRelayConfiguration - grpcClient grpc.ClientConnInterface -} - // registerReflectionServer registers the google.golang.org/grpc/reflection/ // service on a grpc.Server and calls remote backends in case for relayed // services. The connections to the backend will run with the backendCtx. @@ -42,7 +37,7 @@ func registerReflectionServer(backendCtx context.Context, s *grpc.Server, server // Accumulate all the service names. relayServices := make(map[string]grpc.ServiceInfo) for _, relay := range serverRelayConfigurations { - for _, service := range relay.config.GetServices() { + for _, service := range relay.config.Services { // According to ServiceInfoProvider docs for ServerOptions.Services, // the reflection service is only interested in the service names. relayServices[service] = grpc.ServiceInfo{} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 62b235f9..e54cbb47 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -30,6 +30,13 @@ func init() { util.DecimalExponentialBuckets(-3, 6, 2))) } +type serverRelayConfigWithGrpcClient struct { + // config is never nil. + config *grpcpb.ServerRelayConfiguration + // grpcClient is a client created according to config.Endpoint. + grpcClient grpc.ClientConnInterface +} + // NewServersFromConfigurationAndServe creates a series of gRPC servers // based on a configuration stored in a list of Protobuf messages. It // then lets all of these gRPC servers listen on the network addresses @@ -233,7 +240,7 @@ func newRoutingStreamHandlerFromConfiguration(serverRelayConfigurations []server routeTable := make(map[string]grpc.StreamHandler) for _, relay := range serverRelayConfigurations { handler := NewForwardingStreamHandler(relay.grpcClient) - for _, service := range relay.config.GetServices() { + for _, service := range relay.config.Services { if _, ok := routeTable[service]; ok { return nil, status.Errorf(codes.InvalidArgument, "Duplicated gRPC relay for %v", service) } From c10dc4bab6bad9b2a995bc88ab0753168867c309 Mon Sep 17 00:00:00 2001 From: Fredrik Medley Date: Wed, 14 Jan 2026 00:19:31 +0100 Subject: [PATCH 9/9] Minor adjustments --- pkg/grpc/forwarding_stream_handler.go | 2 +- pkg/grpc/reflection_relay.go | 1 - pkg/grpc/routing_stream_handler.go | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/grpc/forwarding_stream_handler.go b/pkg/grpc/forwarding_stream_handler.go index 4f0364c3..be95d913 100644 --- a/pkg/grpc/forwarding_stream_handler.go +++ b/pkg/grpc/forwarding_stream_handler.go @@ -26,7 +26,7 @@ type forwardingStreamHandler struct { // backend stream are sent back in the incomingStream. func (s *forwardingStreamHandler) HandleStream(srv any, incomingStream grpc.ServerStream) error { // All gRPC invocations has a grpc.ServerTransportStream context. - method := grpc.ServerTransportStreamFromContext(incomingStream.Context()).Method() + method, _ := grpc.Method(incomingStream.Context()) desc := grpc.StreamDesc{ // According to grpc.StreamDesc documentation, StreamName and Handler // are only used when registering handlers on a server. diff --git a/pkg/grpc/reflection_relay.go b/pkg/grpc/reflection_relay.go index 83c1a470..143f996c 100644 --- a/pkg/grpc/reflection_relay.go +++ b/pkg/grpc/reflection_relay.go @@ -4,7 +4,6 @@ import ( "context" "maps" - grpcpb "github.com/buildbarn/bb-storage/pkg/proto/configuration/grpc" "github.com/jhump/protoreflect/v2/grpcreflect" "github.com/jhump/protoreflect/v2/protoresolve" diff --git a/pkg/grpc/routing_stream_handler.go b/pkg/grpc/routing_stream_handler.go index 8c995acc..1b897611 100644 --- a/pkg/grpc/routing_stream_handler.go +++ b/pkg/grpc/routing_stream_handler.go @@ -17,7 +17,7 @@ import ( func NewRoutingStreamHandler(routeTable map[string]grpc.StreamHandler) grpc.StreamHandler { return func(srv any, stream grpc.ServerStream) error { // All gRPC invocations has a grpc.ServerTransportStream context. - orgServiceMethod := grpc.ServerTransportStreamFromContext(stream.Context()).Method() + orgServiceMethod, _ := grpc.Method(stream.Context()) // Service and method name parsing based on grpc.Server.handleStream(). serviceMethod := strings.TrimPrefix(orgServiceMethod, "/") endIdx := strings.LastIndex(serviceMethod, "/")