diff --git a/config/config.go b/config/config.go index 1623ad23..b09f4b4a 100644 --- a/config/config.go +++ b/config/config.go @@ -106,20 +106,21 @@ type ( } S2SProxyConfig struct { - Inbound *ProxyConfig `yaml:"inbound"` - Outbound *ProxyConfig `yaml:"outbound"` - MuxTransports []MuxTransportConfig `yaml:"mux"` - HealthCheck *HealthCheckConfig `yaml:"healthCheck"` - NamespaceNameTranslation NamespaceNameTranslationConfig `yaml:"namespaceNameTranslation"` - Metrics *MetricsConfig `yaml:"metrics"` - ProfilingConfig ProfilingConfig `yaml:"profiling"` + Inbound *ProxyConfig `yaml:"inbound"` + Outbound *ProxyConfig `yaml:"outbound"` + MuxTransports []MuxTransportConfig `yaml:"mux"` + HealthCheck *HealthCheckConfig `yaml:"healthCheck"` + NamespaceNameTranslation NameTranslationConfig `yaml:"namespaceNameTranslation"` + ClusterNameTranslation NameTranslationConfig `yaml:"clusterNameTranslation"` + Metrics *MetricsConfig `yaml:"metrics"` + ProfilingConfig ProfilingConfig `yaml:"profiling"` } ProfilingConfig struct { PProfHTTPAddress string `yaml:"pprofAddress"` } - NamespaceNameTranslationConfig struct { + NameTranslationConfig struct { Mappings []NameMappingConfig `yaml:"mappings"` } @@ -283,7 +284,7 @@ func (mc *MockConfigProvider) GetS2SProxyConfig() S2SProxyConfig { } // ToMaps returns request and response mappings. -func (n NamespaceNameTranslationConfig) ToMaps(inBound bool) (map[string]string, map[string]string) { +func (n NameTranslationConfig) ToMaps(inBound bool) (map[string]string, map[string]string) { reqMap := make(map[string]string) respMap := make(map[string]string) if inBound { diff --git a/interceptor/access_control_test.go b/interceptor/access_control_test.go index 1e79b6f5..c3ddcb5c 100644 --- a/interceptor/access_control_test.go +++ b/interceptor/access_control_test.go @@ -160,7 +160,7 @@ func testNamespaceAccessControl(t *testing.T, objCases []objCase) { require.ErrorContains(t, err, c.expError) } else { require.NoError(t, err) - if c.containsNamespace { + if c.containsObjName { require.Equal(t, ts.expAllowed, allowed) } else { require.True(t, allowed) diff --git a/interceptor/cluster_translator_test.go b/interceptor/cluster_translator_test.go new file mode 100644 index 00000000..0db3ad24 --- /dev/null +++ b/interceptor/cluster_translator_test.go @@ -0,0 +1,65 @@ +package interceptor + +import ( + "testing" + + "go.temporal.io/api/namespace/v1" + "go.temporal.io/api/replication/v1" + "go.temporal.io/server/api/adminservice/v1" +) + +func generateClusterNameObjs() []objCase { + return []objCase{ + { + objName: "nil", + makeType: func(clusterName string) any { + return nil + }, + }, + { + objName: "DescribeClusterResponse", + makeType: func(clusterName string) any { + return &adminservice.DescribeClusterResponse{ + ServerVersion: "1.2.3", + ClusterId: "abc123", + ClusterName: clusterName, + HistoryShardCount: 1234, + } + }, + }, + { + objName: "GetNamespaceResponse", + containsObjName: false, + makeType: func(name string) any { + return &adminservice.GetNamespaceResponse{ + Info: &namespace.NamespaceInfo{ + Name: "namespace-name", + }, + Config: &namespace.NamespaceConfig{}, + ReplicationConfig: &replication.NamespaceReplicationConfig{ + ActiveClusterName: name, + Clusters: []*replication.ClusterReplicationConfig{ + { + ClusterName: name, + }, + { + + ClusterName: "some-other-name", + }, + }, + }, + ConfigVersion: 1, + FailoverVersion: 2, + FailoverHistory: []*replication.FailoverStatus{}, + IsGlobalNamespace: true, + } + + }, + expError: "", + }, + } +} + +func TestTranslateClusterName(t *testing.T) { + testTranslateObjects(t, generateClusterNameObjs()) +} diff --git a/interceptor/namespace_translator_test.go b/interceptor/namespace_translator_test.go index f2e475a3..65674280 100644 --- a/interceptor/namespace_translator_test.go +++ b/interceptor/namespace_translator_test.go @@ -43,32 +43,32 @@ type ( } objCase struct { - objName string - containsNamespace bool - makeType func(ns string) any - expError string + objName string + containsObjName bool + makeType func(name string) any + expError string } ) func generateNamespaceObjCases() []objCase { return []objCase{ { - objName: "Namespace field", - containsNamespace: true, + objName: "Namespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithNamespaceField{Namespace: ns} }, }, { - objName: "WorkflowNamespace field", - containsNamespace: true, + objName: "WorkflowNamespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithWorkflowNamespaceField{WorkflowNamespace: ns} }, }, { - objName: "Nested Namespace field", - containsNamespace: true, + objName: "Nested Namespace field", + containsObjName: true, makeType: func(ns string) any { return &StructWithNestedNamespaceField{ Other: "do not change", @@ -79,8 +79,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "list of structs", - containsNamespace: true, + objName: "list of structs", + containsObjName: true, makeType: func(ns string) any { return &StructWithListOfNestedNamespaceField{ Other: "do not change", @@ -93,8 +93,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "list of ptrs", - containsNamespace: true, + objName: "list of ptrs", + containsObjName: true, makeType: func(ns string) any { return &StructWithListOfNestedPtrs{ Other: "do not change", @@ -107,8 +107,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "RespondWorkflowTaskCompletedRequest", - containsNamespace: true, + objName: "RespondWorkflowTaskCompletedRequest", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.RespondWorkflowTaskCompletedRequest{ TaskToken: []byte{}, @@ -138,8 +138,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "PollWorkflowTaskQueueResponse", - containsNamespace: true, + objName: "PollWorkflowTaskQueueResponse", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.PollWorkflowTaskQueueResponse{ TaskToken: []byte{}, @@ -168,8 +168,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "GetWorkflowExecutionRawHistoryV2Response", - containsNamespace: true, + objName: "GetWorkflowExecutionRawHistoryV2Response", + containsObjName: true, makeType: func(ns string) any { return &adminservice.GetWorkflowExecutionRawHistoryV2Response{ NextPageToken: []byte("some-token"), @@ -182,8 +182,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "DescribeNamespaceResponse", - containsNamespace: true, + objName: "DescribeNamespaceResponse", + containsObjName: true, makeType: func(ns string) any { return workflowservice.DescribeNamespaceResponse{ NamespaceInfo: &namespace.NamespaceInfo{ @@ -194,8 +194,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "UpdateNamespaceResponse", - containsNamespace: true, + objName: "UpdateNamespaceResponse", + containsObjName: true, makeType: func(ns string) any { return workflowservice.UpdateNamespaceResponse{ NamespaceInfo: &namespace.NamespaceInfo{ @@ -217,8 +217,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "ListNamespacesResponse", - containsNamespace: true, + objName: "ListNamespacesResponse", + containsObjName: true, makeType: func(ns string) any { return &workflowservice.ListNamespacesResponse{ Namespaces: []*workflowservice.DescribeNamespaceResponse{ @@ -232,8 +232,8 @@ func generateNamespaceObjCases() []objCase { expError: "", }, { - objName: "StreamWorkflowReplicationMessagesResponse", - containsNamespace: true, + objName: "StreamWorkflowReplicationMessagesResponse", + containsObjName: true, makeType: func(ns string) any { return &adminservice.StreamWorkflowReplicationMessagesResponse{ Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ @@ -331,8 +331,8 @@ func generateNamespaceObjCases() []objCase { }, }, { - objName: "circular pointer", - containsNamespace: true, + objName: "circular pointer", + containsObjName: true, makeType: func(ns string) any { a := &StructWithCircularPointer{ Namespace: ns, @@ -473,14 +473,14 @@ func generateNamespaceReplicationMessages() []objCase { }, }, { - objName: "full type", - makeType: makeFullType, - containsNamespace: true, + objName: "full type", + makeType: makeFullType, + containsObjName: true, }, } } -func testTranslateNamespace(t *testing.T, objCases []objCase) { +func testTranslateObjects(t *testing.T, objCases []objCase) { testcases := []struct { testName string inputNSName string @@ -525,7 +525,7 @@ func testTranslateNamespace(t *testing.T, objCases []objCase) { require.ErrorContains(t, err, c.expError) } else { require.NoError(t, err) - if c.containsNamespace { + if c.containsObjName { require.Equal(t, expOutput, input) require.Equal(t, expChanged, changed) } else { @@ -575,9 +575,9 @@ func makeHistoryEventsBlob(ns string) *common.DataBlob { } func TestTranslateNamespaceName(t *testing.T) { - testTranslateNamespace(t, generateNamespaceObjCases()) + testTranslateObjects(t, generateNamespaceObjCases()) } func TestTranslateNamespaceReplicationMessages(t *testing.T) { - testTranslateNamespace(t, generateNamespaceReplicationMessages()) + testTranslateObjects(t, generateNamespaceReplicationMessages()) } diff --git a/interceptor/reflection.go b/interceptor/reflection.go index c562de47..9395a823 100644 --- a/interceptor/reflection.go +++ b/interceptor/reflection.go @@ -16,6 +16,7 @@ var ( "WorkflowNamespace": true, // PollActivityTaskQueueResponse "ParentWorkflowNamespace": true, // WorkflowExecutionStartedEventAttributes } + dataBlobFieldNames = map[string]bool{ "Events": true, // HistoryTaskAttributes "NewRunEvents": true, // HistoryTaskAttributes @@ -24,6 +25,13 @@ var ( "EventsBatches": true, // HistoryTaskAttributes "HistoryBatches": true, // GetWorkflowExecutionRawHistoryV2 } + + clusterNameFields = map[string]bool{ + "ClusterName": true, // DescribeCluster, ListClusters, ReplicationTasks, GetNamespace (Clusters) + "SourceCluster": true, // HistoryDLQKey + "TargetCluster": true, // HistoryDLQKey + "ActiveClusterName": true, // GetNamespace + } ) // matcher returns 2 values: @@ -64,47 +72,58 @@ func visitNamespace(obj any, match matcher) (bool, error) { } matched = matched || ok } else if dataBlobFieldNames[fieldType.Name] { - switch evt := vwp.Interface().(type) { - case []*common.DataBlob: - newEvts, changed, err := translateDataBlobs(match, evt...) - if err != nil { - return visit.Stop, err - } - if changed { - if err := visit.Assign(vwp, reflect.ValueOf(newEvts)); err != nil { - return visit.Stop, err - } - } - matched = matched || changed - case *common.DataBlob: - newEvt, changed, err := translateOneDataBlob(match, evt) - if err != nil { - return visit.Stop, err - } - if changed { - if err := visit.Assign(vwp, reflect.ValueOf(newEvt)); err != nil { - return visit.Stop, err - } - } - matched = matched || changed - default: - return visit.Continue, nil + changed, err := visitDataBlobs(vwp, match, visitNamespace) + if err != nil { + return visit.Stop, err } + matched = matched || changed + return visit.Continue, nil } else if namespaceFieldNames[fieldType.Name] { - name, ok := vwp.Interface().(string) - if !ok { - return visit.Continue, nil + changed, err := visitStringField(vwp, match) + if err != nil { + return visit.Stop, err } - newName, ok := match(name) - if !ok { - return visit.Continue, nil + matched = matched || changed + return visit.Continue, nil + } + + return visit.Continue, nil + }) + return matched, err +} + +// visitClusterName uses reflection to recursively visit all fields +// in the given object. When it finds matching string fields, it invokes +// the provided match function. +func visitClusterName(obj any, match matcher) (bool, error) { + var matched bool + + // The visitor function can return Skip, Stop, or Continue to control recursion. + err := visit.Values(obj, func(vwp visit.ValueWithParent) (visit.Action, error) { + // Grab name of this struct field from the parent. + if vwp.Parent == nil || vwp.Parent.Kind() != reflect.Struct { + return visit.Continue, nil + } + fieldType := vwp.Parent.Type().Field(int(vwp.Index.Int())) + if !fieldType.IsExported() { + // Ignore unexported fields, particularly private gRPC message fields. + return visit.Skip, nil + } + + if dataBlobFieldNames[fieldType.Name] { + changed, err := visitDataBlobs(vwp, match, visitClusterName) + if err != nil { + return visit.Stop, err } - if name != newName { - if err := visit.Assign(vwp, reflect.ValueOf(newName)); err != nil { - return visit.Stop, err - } + matched = matched || changed + return visit.Continue, nil + } else if clusterNameFields[fieldType.Name] { + changed, err := visitStringField(vwp, match) + if err != nil { + return visit.Stop, err } - matched = matched || ok + matched = matched || changed + return visit.Continue, nil } return visit.Continue, nil @@ -112,46 +131,90 @@ func visitNamespace(obj any, match matcher) (bool, error) { return matched, err } -func translateOneDataBlob(match matcher, blob *common.DataBlob) (*common.DataBlob, bool, error) { +func visitDataBlobs(vwp visit.ValueWithParent, match matcher, visitor visitor) (bool, error) { + switch evt := vwp.Interface().(type) { + case []*common.DataBlob: + newEvts, matched, err := translateDataBlobs(match, visitor, evt...) + if err != nil { + return matched, err + } + if matched { + if err := visit.Assign(vwp, reflect.ValueOf(newEvts)); err != nil { + return matched, err + } + } + return matched, nil + case *common.DataBlob: + newEvt, matched, err := translateOneDataBlob(match, visitor, evt) + if err != nil { + return matched, err + } + if matched { + if err := visit.Assign(vwp, reflect.ValueOf(newEvt)); err != nil { + return matched, err + } + } + return matched, nil + default: + return false, nil + } +} + +func visitStringField(vwp visit.ValueWithParent, match matcher) (bool, error) { + name, ok := vwp.Interface().(string) + if !ok { + return false, nil + } + newName, matched := match(name) + if !matched || name == newName { + return matched, nil + } + if err := visit.Assign(vwp, reflect.ValueOf(newName)); err != nil { + return matched, err + } + return matched, nil +} + +func translateOneDataBlob(match matcher, visit visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) { if blob == nil || len(blob.Data) == 0 { return blob, false, nil } - blobs, changed, err := translateDataBlobs(match, blob) + blobs, matched, err := translateDataBlobs(match, visit, blob) if err != nil { - return nil, false, err + return nil, matched, err } if len(blobs) != 1 { - return nil, false, fmt.Errorf("failed to translate single data blob") + return nil, matched, fmt.Errorf("failed to translate single data blob") } - return blobs[0], changed, err + return blobs[0], matched, err } -func translateDataBlobs(match matcher, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { +func translateDataBlobs(match matcher, visit visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { if len(blobs) == 0 { return blobs, false, nil } s := serialization.NewSerializer() - var anyChanged bool + var anyMatched bool for i, blob := range blobs { evt, err := s.DeserializeEvents(blob) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } - changed, err := visitNamespace(evt, match) + matched, err := visit(evt, match) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } - anyChanged = anyChanged || changed + anyMatched = anyMatched || matched newBlob, err := s.SerializeEvents(evt, blob.EncodingType) if err != nil { - return blobs, anyChanged, err + return blobs, anyMatched, err } blobs[i] = newBlob } - return blobs, anyChanged, nil + return blobs, anyMatched, nil } diff --git a/interceptor/translator.go b/interceptor/translator.go index 6e916c00..029d5069 100644 --- a/interceptor/translator.go +++ b/interceptor/translator.go @@ -21,6 +21,14 @@ func NewNamespaceNameTranslator(reqMap, respMap map[string]string) Translator { } } +func NewClusterNameTranslator(reqMap, respMap map[string]string) Translator { + return &translatorImpl{ + matchReq: createNameTranslator(reqMap), + matchResp: createNameTranslator(respMap), + visitor: visitClusterName, + } +} + func (n *translatorImpl) TranslateRequest(req any) (bool, error) { return n.visitor(req, n.matchReq) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 05e0ab7c..c817b921 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -44,20 +44,32 @@ type ( } ) -func makeServerOptions(logger log.Logger, cfg config.ProxyConfig, isInbound bool, nameTranslations config.NamespaceNameTranslationConfig) ([]grpc.ServerOption, error) { +func makeServerOptions( + logger log.Logger, + cfg config.ProxyConfig, + isInbound bool, + namespaceTranslation config.NameTranslationConfig, + clusterTranslation config.NameTranslationConfig, +) ([]grpc.ServerOption, error) { unaryInterceptors := []grpc.UnaryServerInterceptor{} streamInterceptors := []grpc.StreamServerInterceptor{} var translators []interceptor.Translator - if len(nameTranslations.Mappings) > 0 { - // NamespaceNameTranslator needs to be called before namespace access control so that - // local name can be used in namespace allowed list. + if len(namespaceTranslation.Mappings) > 0 { + translators = append(translators, + interceptor.NewNamespaceNameTranslator(namespaceTranslation.ToMaps(isInbound)), + ) + } + + if len(clusterTranslation.Mappings) > 0 { translators = append(translators, - interceptor.NewNamespaceNameTranslator(nameTranslations.ToMaps(isInbound)), + interceptor.NewClusterNameTranslator(clusterTranslation.ToMaps(isInbound)), ) } if len(translators) > 0 { + // Translation needs to be called before namespace access control so that + // local name can be used in namespace allowed list. tr := interceptor.NewTranslationInterceptor(logger, translators) unaryInterceptors = append(unaryInterceptors, tr.Intercept) streamInterceptors = append(streamInterceptors, tr.InterceptStream) @@ -100,7 +112,7 @@ func (ps *ProxyServer) startServer( opts := ps.opts logger := ps.logger - serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound, opts.Config.NamespaceNameTranslation) + serverOpts, err := makeServerOptions(logger, cfg, opts.IsInbound, opts.Config.NamespaceNameTranslation, opts.Config.ClusterNameTranslation) if err != nil { return err }