From 8d035d6bb6e390c00e931076b70c28a0589c369c Mon Sep 17 00:00:00 2001 From: Paul Glass Date: Wed, 18 Jun 2025 16:45:08 -0500 Subject: [PATCH] Implement multi-namespace search attribute translation --- interceptor/access_control.go | 3 +- interceptor/namespace_translator_test.go | 17 ++- interceptor/reflection.go | 107 +++++++++++++----- interceptor/reflection_test.go | 4 +- interceptor/search_attribute_translator.go | 27 ++--- .../search_attribute_translator_test.go | 21 +++- interceptor/translation_interceptor.go | 4 +- interceptor/translator.go | 18 ++- proxy/proxy.go | 3 - 9 files changed, 139 insertions(+), 65 deletions(-) diff --git a/interceptor/access_control.go b/interceptor/access_control.go index c29d6f7c..76e96c82 100644 --- a/interceptor/access_control.go +++ b/interceptor/access_control.go @@ -52,7 +52,8 @@ func createNamespaceAccessControl(access *auth.AccessControl) stringMatcher { } func isNamespaceAccessAllowed(obj any, access *auth.AccessControl) (bool, error) { - notAllowed, err := visitNamespace(obj, createNamespaceAccessControl(access)) + v := NewNamespaceVisitor(createNamespaceAccessControl(access)) + notAllowed, err := v.Visit(obj) if err != nil { return false, err } diff --git a/interceptor/namespace_translator_test.go b/interceptor/namespace_translator_test.go index 240b81c9..64154e6a 100644 --- a/interceptor/namespace_translator_test.go +++ b/interceptor/namespace_translator_test.go @@ -51,11 +51,19 @@ type ( ) func TestTranslateNamespaceName(t *testing.T) { - testTranslateObj(t, visitNamespace, generateNamespaceObjCases(), require.Equal) + testTranslateObj(t, generateNamespaceObjCases(), require.Equal, + func(m map[string]string) Visitor { + return NewNamespaceVisitor(createStringMatcher(m)) + }, + ) } func TestTranslateNamespaceReplicationMessages(t *testing.T) { - testTranslateObj(t, visitNamespace, generateNamespaceReplicationMessages(), require.EqualExportedValues) + testTranslateObj(t, generateNamespaceReplicationMessages(), require.EqualExportedValues, + func(m map[string]string) Visitor { + return NewNamespaceVisitor(createStringMatcher(m)) + }, + ) } func generateNamespaceObjCases() []objCase { @@ -497,9 +505,9 @@ func generateNamespaceReplicationMessages() []objCase { // handle pointer cycles. func testTranslateObj( t *testing.T, - visitor visitor, objCases []objCase, equalityAssertion func(t require.TestingT, exp, actual any, extra ...any), + createVisitor func(map[string]string) Visitor, ) { testcases := []struct { testName string @@ -540,7 +548,8 @@ func testTranslateObj( expOutput := c.makeType(ts.outputName) expChanged := ts.inputName != ts.outputName - changed, err := visitor(input, createStringMatcher(ts.mapping)) + visitor := createVisitor(ts.mapping) + changed, err := visitor.Visit(input) if len(c.expError) != 0 { require.ErrorContains(t, err, c.expError) } else { diff --git a/interceptor/reflection.go b/interceptor/reflection.go index dd7c3528..4ddc2fb2 100644 --- a/interceptor/reflection.go +++ b/interceptor/reflection.go @@ -10,6 +10,10 @@ import ( "go.temporal.io/server/common/persistence/serialization" ) +const ( + namespaceIDFieldName = "NamespaceId" +) + var ( serializer = serialization.NewSerializer() @@ -40,19 +44,47 @@ var ( } ) -// stringMatcher returns 2 values: -// 1. new name. If there is no change, new name equals to input name -// 2. whether or not the input name matches the defined rule(s). -type stringMatcher func(name string) (string, bool) +type ( + // Visitor will visits an object's fields recursively. It returns an + // implementation-specific bool and error, which typicall indicate if it + // matched anything and if it encountered an unrecoverable error. + Visitor interface { + Visit(any) (bool, error) + } + + // visitNamespace uses reflection to recursively visit all fields in the + // given object. When it finds namespace string fields, it invokes the match + // function. + nsVisitor struct { + match stringMatcher + } + + // saVisitor uses reflection to recursively visit search attribute fields in the given object. + // It translates search attribute fields according to per-namespace search attribute mappings. + // + // This is not concurrent safe. You must create a separate struct each time. + saVisitor struct { + getNamespaceSAMatcher getSAMatcher -// visitor visits each field in obj matching the matcher. -// It returns whether anything was matched and any error it encountered. -type visitor func(obj any, match stringMatcher) (bool, error) + // currentNamespaceId is internal-state to remember the namespace id set in some parent + // field as the visitor descends recursively into child fields. + currentNamespaceId string + } + + // stringMatcher returns 2 values: + // 1. new name. If there is no change, new name equals to input name + // 2. whether or not the input name matches the defined rule(s). + stringMatcher func(name string) (string, bool) + + // getSAMatcher returns a string matcher for a given namespace's search attribute mapping + getSAMatcher func(nsId string) stringMatcher +) + +func NewNamespaceVisitor(match stringMatcher) Visitor { + return &nsVisitor{match: match} +} -// visitNamespace uses reflection to recursively visit all fields -// in the given object. When it finds namespace string fields, it invokes -// the provided match function. -func visitNamespace(obj any, match stringMatcher) (bool, error) { +func (v *nsVisitor) Visit(obj any) (bool, error) { var matched bool // The visitor function can return Skip, Stop, or Continue to control recursion. @@ -65,7 +97,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) { if info, ok := vwp.Interface().(*namespace.NamespaceInfo); ok && info != nil { // Handle NamespaceInfo.Name in any message. - newName, ok := match(info.Name) + newName, ok := v.match(info.Name) if !ok { return visit.Continue, nil } @@ -74,7 +106,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) { } matched = matched || ok } else if dataBlobFieldNames[fieldType.Name] { - changed, err := visitDataBlobs(vwp, match, visitNamespace) + changed, err := visitDataBlobs(vwp, v) matched = matched || changed if err != nil { return visit.Stop, err @@ -84,7 +116,7 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) { if !ok { return visit.Continue, nil } - newName, ok := match(name) + newName, ok := v.match(name) if !ok { return visit.Continue, nil } @@ -101,10 +133,11 @@ func visitNamespace(obj any, match stringMatcher) (bool, error) { return matched, err } -// visitSearchAttributes uses reflection to recursively visit all fields -// in the given object. When it finds namespace string fields, it invokes -// the provided match function. -func visitSearchAttributes(obj any, match stringMatcher) (bool, error) { +func MakeSearchAttributeVisitor(getNsSearchAttr getSAMatcher) saVisitor { + return saVisitor{getNamespaceSAMatcher: getNsSearchAttr} +} + +func (v *saVisitor) Visit(obj any) (bool, error) { var matched bool // The visitor function can return Skip, Stop, or Continue to control recursion. @@ -115,13 +148,24 @@ func visitSearchAttributes(obj any, match stringMatcher) (bool, error) { return action, nil } + nsId := discoverNamespaceId(vwp) + if nsId != "" { + v.currentNamespaceId = nsId + } + if dataBlobFieldNames[fieldType.Name] { - changed, err := visitDataBlobs(vwp, match, visitSearchAttributes) + changed, err := visitDataBlobs(vwp, v) matched = matched || changed if err != nil { return visit.Stop, err } } else if searchAttributeFieldNames[fieldType.Name] { + // Get the per-namespace search attribute mapping + match := v.getNamespaceSAMatcher(v.currentNamespaceId) + if match == nil { + return visit.Continue, nil + } + // This could be *common.SearchAttributes, or it could be map[string]*common.Payload (indexed fields) var changed bool switch attrs := vwp.Interface().(type) { @@ -146,6 +190,17 @@ func visitSearchAttributes(obj any, match stringMatcher) (bool, error) { return matched, err } +func discoverNamespaceId(vwp visit.ValueWithParent) string { + parent := vwp.Parent + if parent.Kind() == reflect.Struct { + typ, ok := parent.Type().FieldByName(namespaceIDFieldName) + if ok && typ.Type.Kind() == reflect.String { + return parent.FieldByName(namespaceIDFieldName).String() + } + } + return "" +} + func translateIndexedFields(fields map[string]*common.Payload, match stringMatcher) (map[string]*common.Payload, bool) { if fields == nil { return fields, false @@ -176,10 +231,10 @@ func getParentFieldType(vwp visit.ValueWithParent) (result reflect.StructField, return fieldType, action } -func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visitor) (bool, error) { +func visitDataBlobs(vwp visit.ValueWithParent, v Visitor) (bool, error) { switch evt := vwp.Interface().(type) { case []*common.DataBlob: - newEvts, matched, err := translateDataBlobs(match, visitor, evt...) + newEvts, matched, err := translateDataBlobs(v, evt...) if err != nil { return matched, err } @@ -190,7 +245,7 @@ func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visi } return matched, nil case *common.DataBlob: - newEvt, matched, err := translateOneDataBlob(match, visitor, evt) + newEvt, matched, err := translateOneDataBlob(v, evt) if err != nil { return matched, err } @@ -205,10 +260,10 @@ func visitDataBlobs(vwp visit.ValueWithParent, match stringMatcher, visitor visi } } -func translateDataBlobs(match stringMatcher, visitor visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { +func translateDataBlobs(visitor Visitor, blobs ...*common.DataBlob) ([]*common.DataBlob, bool, error) { var anyChanged bool for i, blob := range blobs { - newBlob, changed, err := translateOneDataBlob(match, visitor, blob) + newBlob, changed, err := translateOneDataBlob(visitor, blob) anyChanged = anyChanged || changed if err != nil { return blobs, anyChanged, err @@ -218,7 +273,7 @@ func translateDataBlobs(match stringMatcher, visitor visitor, blobs ...*common.D return blobs, anyChanged, nil } -func translateOneDataBlob(match stringMatcher, visitor visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) { +func translateOneDataBlob(visitor Visitor, blob *common.DataBlob) (*common.DataBlob, bool, error) { if blob == nil || len(blob.Data) == 0 { return blob, false, nil @@ -228,7 +283,7 @@ func translateOneDataBlob(match stringMatcher, visitor visitor, blob *common.Dat return blob, false, err } - changed, err := visitor(evt, match) + changed, err := visitor.Visit(evt) if err != nil || !changed { return blob, changed, err } diff --git a/interceptor/reflection_test.go b/interceptor/reflection_test.go index 6e879ad3..a70c221f 100644 --- a/interceptor/reflection_test.go +++ b/interceptor/reflection_test.go @@ -26,14 +26,14 @@ func BenchmarkVisitNamespace(b *testing.B) { for _, c := range cases { b.Run(c.objName, func(b *testing.B) { for _, variant := range variants { - translator := createStringMatcher(variant.mapping) + visitor := NewNamespaceVisitor(createStringMatcher(variant.mapping)) b.Run(variant.testName, func(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() input := c.makeType(variant.inputNSName) b.StartTimer() - visitNamespace(input, translator) + visitor.Visit(input) } }) } diff --git a/interceptor/search_attribute_translator.go b/interceptor/search_attribute_translator.go index ef25543b..a6bf2431 100644 --- a/interceptor/search_attribute_translator.go +++ b/interceptor/search_attribute_translator.go @@ -3,6 +3,7 @@ package interceptor import ( "strings" + "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/common/api" ) @@ -31,27 +32,27 @@ func (s *saTranslator) MatchMethod(m string) bool { } func (s *saTranslator) TranslateRequest(req any) (bool, error) { - return visitSearchAttributes(req, s.getNamespaceReqMatcher("")) + v := MakeSearchAttributeVisitor(s.getNamespaceReqMatcher) + return v.Visit(req) } -func (s *saTranslator) TranslateResponse(resp any) (bool, error) { - return visitSearchAttributes(resp, s.getNamespaceRespMatcher("")) +func (s *saTranslator) TranslateResponse(req, resp any) (bool, error) { + // Detect namespace id in GetWorkflowExecutionRawHistoryV2Request. + // Use that namespace id to translate search attributes in the response type. + v := MakeSearchAttributeVisitor(s.getNamespaceRespMatcher) + switch val := req.(type) { + case *adminservice.GetWorkflowExecutionRawHistoryV2Request: + v.currentNamespaceId = val.NamespaceId + } + return v.Visit(resp) } func (s *saTranslator) getNamespaceReqMatcher(namespaceId string) stringMatcher { - // Placeholder: Just return the first one (only support one namespace mapping) - for _, matcher := range s.reqMap { - return matcher - } - return createStringMatcher(nil) + return s.reqMap[namespaceId] } func (s *saTranslator) getNamespaceRespMatcher(namespaceId string) stringMatcher { - // Placeholder: Just return the first one (only support one namespace mappping) - for _, matcher := range s.respMap { - return matcher - } - return createStringMatcher(nil) + return s.respMap[namespaceId] } func createStringMatchers(nsMappings map[string]map[string]string) map[string]stringMatcher { diff --git a/interceptor/search_attribute_translator_test.go b/interceptor/search_attribute_translator_test.go index 8d747dfe..68f6b06b 100644 --- a/interceptor/search_attribute_translator_test.go +++ b/interceptor/search_attribute_translator_test.go @@ -16,10 +16,23 @@ import ( ) func TestTranslateSearchAttribute(t *testing.T) { - testTranslateObj(t, visitSearchAttributes, generateSearchAttributeObjs(), require.EqualExportedValues) + namespaceId := "ns-1234" + testTranslateObj(t, generateSearchAttributeObjs(namespaceId), require.EqualExportedValues, + func(mapping map[string]string) Visitor { + v := MakeSearchAttributeVisitor( + func(nsId string) stringMatcher { + if nsId != namespaceId { + return nil + } + return createStringMatcher(mapping) + }, + ) + return &v + }, + ) } -func generateSearchAttributeObjs() []objCase { +func generateSearchAttributeObjs(nsId string) []objCase { return []objCase{ { objName: "HistoryTaskAttributes", @@ -32,7 +45,7 @@ func generateSearchAttributeObjs() []objCase { { Attributes: &replicationspb.ReplicationTask_HistoryTaskAttributes{ HistoryTaskAttributes: &replicationspb.HistoryTaskAttributes{ - NamespaceId: "some-ns-id", + NamespaceId: nsId, WorkflowId: "some-wf-id", RunId: "some-run-id", Events: makeHistoryEventsBlobWithSearchAttribute(name), @@ -62,7 +75,7 @@ func generateSearchAttributeObjs() []objCase { SyncWorkflowStateMutationAttributes: &replicationspb.SyncWorkflowStateMutationAttributes{ StateMutation: &persistence.WorkflowMutableStateMutation{ ExecutionInfo: &persistence.WorkflowExecutionInfo{ - NamespaceId: "some-ns", + NamespaceId: nsId, WorkflowId: "some-wf", SearchAttributes: makeTestIndexedFieldMap(name), Memo: map[string]*common.Payload{ diff --git a/interceptor/translation_interceptor.go b/interceptor/translation_interceptor.go index 2ffa74d6..a769702f 100644 --- a/interceptor/translation_interceptor.go +++ b/interceptor/translation_interceptor.go @@ -55,7 +55,7 @@ func (i *TranslationInterceptor) Intercept( for _, tr := range i.translators { if tr.MatchMethod(info.FullMethod) { - changed, trErr := tr.TranslateResponse(resp) + changed, trErr := tr.TranslateResponse(req, resp) logTranslateResult(i.logger, changed, trErr, methodName+"Response", resp) } } @@ -98,7 +98,7 @@ func (w *streamTranslator) RecvMsg(m any) error { func (w *streamTranslator) SendMsg(m any) error { w.logger.Debug("Intercept SendMsg", tag.NewStringTag("type", fmt.Sprintf("%T", m)), tag.NewAnyTag("message", m)) for _, tr := range w.translators { - changed, trErr := tr.TranslateResponse(m) + changed, trErr := tr.TranslateResponse(nil, m) logTranslateResult(w.logger, changed, trErr, "SendMsg", m) } return w.ServerStream.SendMsg(m) diff --git a/interceptor/translator.go b/interceptor/translator.go index 814b16d0..62feafda 100644 --- a/interceptor/translator.go +++ b/interceptor/translator.go @@ -4,23 +4,21 @@ type ( Translator interface { MatchMethod(string) bool TranslateRequest(any) (bool, error) - TranslateResponse(any) (bool, error) + TranslateResponse(any, any) (bool, error) } translatorImpl struct { matchMethod func(string) bool - matchReq stringMatcher - matchResp stringMatcher - visitor visitor + reqVisitor Visitor + respVisitor Visitor } ) func NewNamespaceNameTranslator(reqMap, respMap map[string]string) Translator { return &translatorImpl{ matchMethod: func(string) bool { return true }, - matchReq: createStringMatcher(reqMap), - matchResp: createStringMatcher(respMap), - visitor: visitNamespace, + reqVisitor: &nsVisitor{match: createStringMatcher(reqMap)}, + respVisitor: &nsVisitor{match: createStringMatcher(respMap)}, } } @@ -29,11 +27,11 @@ func (n *translatorImpl) MatchMethod(m string) bool { } func (n *translatorImpl) TranslateRequest(req any) (bool, error) { - return n.visitor(req, n.matchReq) + return n.reqVisitor.Visit(req) } -func (n *translatorImpl) TranslateResponse(resp any) (bool, error) { - return n.visitor(resp, n.matchResp) +func (n *translatorImpl) TranslateResponse(_, resp any) (bool, error) { + return n.respVisitor.Visit(resp) } func createStringMatcher(mapping map[string]string) stringMatcher { diff --git a/proxy/proxy.go b/proxy/proxy.go index be4c9d80..faefcfe5 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -62,9 +62,6 @@ func makeServerOptions( if tln := proxyOpts.Config.SearchAttributeTranslation; tln.IsEnabled() { logger.Info("search attribute translation enabled", tag.NewAnyTag("mappings", tln.NamespaceMappings)) - if len(tln.NamespaceMappings) > 1 { - panic("multiple namespace search attribute mappings are not supported") - } translators = append(translators, interceptor.NewSearchAttributeTranslator(tln.ToMaps(proxyOpts.IsInbound))) }