diff --git a/.github/workflows/pull-request.yml b/.github/workflows/pull-request.yml index dda69249..8a81cf33 100644 --- a/.github/workflows/pull-request.yml +++ b/.github/workflows/pull-request.yml @@ -59,7 +59,7 @@ jobs: cache: ${{ github.ref == 'refs/heads/main' }} # only update the cache in main. - name: Run go build - run: go build ./... + run: make bins - name: Run go unittest run: make test @@ -67,7 +67,7 @@ jobs: - name: Install helm uses: azure/setup-helm@v4.3.0 with: - version: v3.17.3 + version: v3.19.4 - name: Install helm-unittest run: helm plugin install https://github.com/helm-unittest/helm-unittest.git diff --git a/Makefile b/Makefile index 4dbc3f41..be35b5a2 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ GO_GET_TOOL = go get -tool -modfile=$(TOOLS_MOD_FILE) # Disable cgo by default. CGO_ENABLED ?= 0 -TEST_ARG ?= -race -timeout=5m -tags test_dep +TEST_ARG ?= -race -timeout=15m -tags test_dep -count=1 BENCH_ARG ?= -benchtime=5000x ALL_SRC := $(shell find . -name "*.go") diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 41b31daf..1212854e 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -63,12 +63,17 @@ func buildCLIOptions() *cli.App { return app } -func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig) { +func startPProfHTTPServer(logger log.Logger, c config.ProfilingConfig, proxyInstance *proxy.Proxy) { addr := c.PProfHTTPAddress if len(addr) == 0 { return } + // Add debug endpoint handler + http.HandleFunc("/debug/connections", func(w http.ResponseWriter, r *http.Request) { + proxy.HandleDebugInfo(w, r, proxyInstance, logger) + }) + go func() { logger.Info("Start pprof http server", tag.NewStringTag("address", addr)) if err := http.ListenAndServe(addr, nil); err != nil { @@ -101,7 +106,7 @@ func startProxy(c *cli.Context) error { } cfg := proxyParams.ConfigProvider.GetS2SProxyConfig() - startPProfHTTPServer(proxyParams.Logger, cfg.ProfilingConfig) + startPProfHTTPServer(proxyParams.Logger, cfg.ProfilingConfig, proxyParams.Proxy) if err := proxyParams.Proxy.Start(); err != nil { return err diff --git a/common/intra_headers.go b/common/intra_headers.go new file mode 100644 index 00000000..7bc2731f --- /dev/null +++ b/common/intra_headers.go @@ -0,0 +1,37 @@ +package common + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +const ( + // Intra-proxy identification and tracing headers + IntraProxyHeaderKey = "x-s2s-intra-proxy" + IntraProxyHeaderValue = "1" + IntraProxyOriginProxyIDHeader = "x-s2s-origin-proxy-id" + IntraProxyHopCountHeader = "x-s2s-hop-count" + IntraProxyTraceIDHeader = "x-s2s-trace-id" +) + +// IsIntraProxy checks incoming context metadata for intra-proxy marker. +func IsIntraProxy(ctx context.Context) bool { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if vals := md.Get(IntraProxyHeaderKey); len(vals) > 0 && vals[0] == IntraProxyHeaderValue { + return true + } + } + return false +} + +// WithIntraProxyHeaders returns a new outgoing context with intra-proxy headers set. 
+func WithIntraProxyHeaders(ctx context.Context, headers map[string]string) context.Context { + md, _ := metadata.FromOutgoingContext(ctx) + md = md.Copy() + md.Set(IntraProxyHeaderKey, IntraProxyHeaderValue) + for k, v := range headers { + md.Set(k, v) + } + return metadata.NewOutgoingContext(ctx, md) +} diff --git a/config/cluster_conn_config.go b/config/cluster_conn_config.go index 636e83f4..2bcf1388 100644 --- a/config/cluster_conn_config.go +++ b/config/cluster_conn_config.go @@ -16,6 +16,7 @@ type ( OutboundHealthCheck HealthCheckConfig `yaml:"outboundHealthCheck"` InboundHealthCheck HealthCheckConfig `yaml:"inboundHealthCheck"` ShardCountConfig ShardCountConfig `yaml:"shardCount"` + MemberlistConfig *MemberlistConfig `yaml:"memberlist"` } StringTranslator struct { Mappings []StringMapping `yaml:"mappings"` diff --git a/config/config.go b/config/config.go index 53e85fbc..d1b28116 100644 --- a/config/config.go +++ b/config/config.go @@ -40,6 +40,7 @@ type ShardCountMode string const ( ShardCountDefault ShardCountMode = "" ShardCountLCM ShardCountMode = "lcm" + ShardCountRouting ShardCountMode = "routing" ) type HealthCheckProtocol string @@ -153,7 +154,11 @@ type ( // TODO: Soon to be deprecated! Create an item in ClusterConnections instead HealthCheck *HealthCheckConfig `yaml:"healthCheck"` // TODO: Soon to be deprecated! Create an item in ClusterConnections instead - OutboundHealthCheck *HealthCheckConfig `yaml:"outboundHealthCheck"` + OutboundHealthCheck *HealthCheckConfig `yaml:"outboundHealthCheck"` + // TODO: Soon to be deprecated! Create an item in ClusterConnections instead + ShardCountConfig ShardCountConfig `yaml:"shardCount"` + // TODO: Soon to be deprecated! Create an item in ClusterConnections instead + MemberlistConfig *MemberlistConfig `yaml:"memberlist"` NamespaceNameTranslation NameTranslationConfig `yaml:"namespaceNameTranslation"` SearchAttributeTranslation SATranslationConfig `yaml:"searchAttributeTranslation"` Metrics *MetricsConfig `yaml:"metrics"` @@ -217,6 +222,31 @@ type ( LoggingConfig struct { ThrottleMaxRPS float64 `yaml:"throttleMaxRPS"` } + + MemberlistConfig struct { + // Enable distributed shard management using memberlist + Enabled bool `yaml:"enabled"` + // Node name for this proxy instance in the cluster + NodeName string `yaml:"nodeName"` + // Bind address for memberlist cluster communication + BindAddr string `yaml:"bindAddr"` + // Bind port for memberlist cluster communication + BindPort int `yaml:"bindPort"` + // List of existing cluster members to join + JoinAddrs []string `yaml:"joinAddrs"` + // Shard assignment strategy (deprecated - now uses actual ownership tracking) + ShardStrategy string `yaml:"shardStrategy"` + // Map of node names to their proxy service addresses for forwarding + ProxyAddresses map[string]string `yaml:"proxyAddresses"` + // Use TCP-only transport (disables UDP) for restricted networks + TCPOnly bool `yaml:"tcpOnly"` + // Disable TCP pings when using TCP-only mode + DisableTCPPings bool `yaml:"disableTCPPings"` + // Probe timeout for memberlist health checks + ProbeTimeoutMs int `yaml:"probeTimeoutMs"` + // Probe interval for memberlist health checks + ProbeIntervalMs int `yaml:"probeIntervalMs"` + } ) func FromServerTLSConfig(cfg ServerTLSConfig) encryption.TLSConfig { diff --git a/config/converter.go b/config/converter.go index 0a07f50a..93412caf 100644 --- a/config/converter.go +++ b/config/converter.go @@ -41,6 +41,8 @@ func ToClusterConnConfig(config S2SProxyConfig) S2SProxyConfig { SearchAttributeTranslation: 
config.SearchAttributeTranslation, OutboundHealthCheck: flattenNilHealthCheck(config.OutboundHealthCheck), InboundHealthCheck: flattenNilHealthCheck(config.HealthCheck), + ShardCountConfig: config.ShardCountConfig, + MemberlistConfig: config.MemberlistConfig, }, }, Metrics: config.Metrics, diff --git a/develop/config/cluster-a-mux-client-proxy-1.yaml b/develop/config/cluster-a-mux-client-proxy-1.yaml new file mode 100644 index 00000000..4855af00 --- /dev/null +++ b/develop/config/cluster-a-mux-client-proxy-1.yaml @@ -0,0 +1,23 @@ +inbound: + name: "a-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:7233" +outbound: + name: "a-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6133" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "client" + client: + serverAddress: "localhost:7003" +profiling: + pprofAddress: "localhost:6060" \ No newline at end of file diff --git a/develop/config/cluster-a-mux-client-proxy-2.yaml b/develop/config/cluster-a-mux-client-proxy-2.yaml new file mode 100644 index 00000000..85aa9bf6 --- /dev/null +++ b/develop/config/cluster-a-mux-client-proxy-2.yaml @@ -0,0 +1,23 @@ +inbound: + name: "a-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:7233" +outbound: + name: "a-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6233" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "client" + client: + serverAddress: "localhost:7003" +profiling: + pprofAddress: "localhost:6061" \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-1.yaml b/develop/config/cluster-b-mux-server-proxy-1.yaml new file mode 100644 index 00000000..e204309d --- /dev/null +++ b/develop/config/cluster-b-mux-server-proxy-1.yaml @@ -0,0 +1,46 @@ +inbound: + name: "b-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:8233" +outbound: + name: "b-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6333" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "server" + server: + listenAddress: "0.0.0.0:6334" +# shardCount: +# mode: "lcm" +# localShardCount: 3 +# remoteShardCount: 2 +shardCount: + mode: "routing" + localShardCount: 3 + remoteShardCount: 2 +profiling: + pprofAddress: "localhost:6070" +memberlist: + enabled: true + nodeName: "proxy-node-b-1" + bindAddr: "127.0.0.1" + bindPort: 6335 + # joinAddrs: + # - "localhost:6435" + proxyAddresses: + "proxy-node-b-1": "localhost:6333" + "proxy-node-b-2": "localhost:6433" + # # TCP-only configuration for restricted networks + tcpOnly: true # Use TCP transport only, disable UDP + # disableTCPPings: true # Disable TCP pings for faster convergence + # probeTimeoutMs: 1000 # Longer timeout for network latency + # probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/cluster-b-mux-server-proxy-2.yaml b/develop/config/cluster-b-mux-server-proxy-2.yaml new file mode 100644 index 00000000..e37b5006 --- /dev/null +++ b/develop/config/cluster-b-mux-server-proxy-2.yaml @@ -0,0 +1,46 @@ +inbound: + name: "b-inbound-server" + server: + type: "mux" + mux: "muxed" + client: + tcp: + serverAddress: "localhost:8233" +outbound: + name: "b-outbound-server" + server: + tcp: + listenAddress: "0.0.0.0:6433" + client: + type: "mux" + mux: "muxed" +mux: + - name: "muxed" + mode: "server" + server: + listenAddress: "0.0.0.0:6434" +# shardCount: +# mode: 
"lcm" +# localShardCount: 3 +# remoteShardCount: 2 +shardCount: + mode: "routing" + localShardCount: 3 + remoteShardCount: 2 +profiling: + pprofAddress: "localhost:6071" +memberlist: + enabled: true + nodeName: "proxy-node-b-2" + bindAddr: "127.0.0.1" + bindPort: 6435 + joinAddrs: + - "localhost:6335" + proxyAddresses: + "proxy-node-b-1": "localhost:6333" + "proxy-node-b-2": "localhost:6433" + # # TCP-only configuration for restricted networks + # tcpOnly: true # Use TCP transport only, disable UDP + # disableTCPPings: true # Disable TCP pings for faster convergence + # probeTimeoutMs: 1000 # Longer timeout for network latency + # probeIntervalMs: 2000 # Less frequent probes to reduce network noise \ No newline at end of file diff --git a/develop/config/dynamic-config.yaml b/develop/config/dynamic-config.yaml index 95de71bc..f95073c9 100644 --- a/develop/config/dynamic-config.yaml +++ b/develop/config/dynamic-config.yaml @@ -20,10 +20,14 @@ history.ReplicationEnableUpdateWithNewTaskMerge: history.enableWorkflowExecutionTimeoutTimer: - value: true history.EnableReplicationTaskTieredProcessing: - - value: true + - value: false history.persistenceMaxQPS: - value: 100000 constraints: {} frontend.persistenceMaxQPS: - value: 100000 - constraints: {} \ No newline at end of file + constraints: {} +history.shardUpdateMinInterval: + - value: 1s +history.ReplicationStreamSendEmptyTaskDuration: + - value: 10s \ No newline at end of file diff --git a/go.mod b/go.mod index 95719fa8..25be7132 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/gogo/status v1.1.1 github.com/golang/mock v1.7.0-rc.1 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 + github.com/hashicorp/memberlist v0.5.1 github.com/hashicorp/yamux v0.1.2 github.com/keilerkonzept/visit v1.1.1 github.com/pkg/errors v0.9.1 @@ -40,6 +41,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect github.com/apache/thrift v0.21.0 // indirect + github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da // indirect github.com/aws/aws-sdk-go v1.55.6 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -66,6 +68,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v1.0.0 // indirect + github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect @@ -76,6 +79,12 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-immutable-radix v1.0.0 // indirect + github.com/hashicorp/go-msgpack/v2 v2.1.1 // indirect + github.com/hashicorp/go-multierror v1.0.0 // indirect + github.com/hashicorp/go-sockaddr v1.0.0 // indirect + github.com/hashicorp/golang-lru v0.5.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/iancoleman/strcase v0.3.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -89,6 +98,7 @@ require ( github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-isatty 
v0.0.20 // indirect + github.com/miekg/dns v1.1.26 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -106,6 +116,7 @@ require ( github.com/robfig/cron v1.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sony/gobreaker v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect diff --git a/go.sum b/go.sum index 602451ec..487fdb62 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3 github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk= github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/benbjohnson/clock v0.0.0-20160125162948-a620c1cc9866/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= @@ -141,6 +143,8 @@ github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6 github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -171,8 +175,24 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5uk github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-immutable-radix v1.0.0 h1:AKDB1HM5PWEA7i4nhcpwOrO2byshxBjXVn/J/3+z5/0= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack/v2 v2.1.1 h1:xQEY9yB2wnHitoSzk/B9UjXWRQ67QKu5AOm8aFp8N3I= +github.com/hashicorp/go-msgpack/v2 v2.1.1/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror 
v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-sockaddr v1.0.0 h1:GeH6tui99pF4NJgfnhp+L6+FfobzVW3Ah46sLo0ICXs= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru v0.5.0 h1:CL2msUPvZTLb5O648aiLNJw3hnBxN2+1Jq8rCOH9wdo= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/memberlist v0.5.1 h1:mk5dRuzeDNis2bi6LLoQIXfMH7JQvAzt3mQD0vNZZUo= +github.com/hashicorp/memberlist v0.5.1/go.mod h1:zGDXV6AqbDTKTM6yxW0I4+JtFzZAJVoIPvss4hV8F24= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= @@ -221,6 +241,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU= +github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -235,6 +257,8 @@ github.com/olivere/elastic/v7 v7.0.32/go.mod h1:c7PVmLe3Fxq77PIfY/bZmxY/TAamBhCz github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c h1:Lgl0gzECD8GnQ5QCWA8o6BtfL6mDH5rQgM4/fX3avOs= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -270,6 +294,8 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/samuel/go-thrift v0.0.0-20190219015601-e8b6b52668fe/go.mod h1:Vrkh1pnjV9Bl8c3P9zH0/D4NlOHWP5d4/hF4YTULaec= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod 
h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sirupsen/logrus v1.0.2-0.20170726183946-abee6f9b0679/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -377,6 +403,7 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -414,6 +441,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -442,6 +470,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -465,6 +495,7 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod 
h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -483,6 +514,7 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/interceptor/translation_interceptor.go b/interceptor/translation_interceptor.go index 56cf9622..3fc2b836 100644 --- a/interceptor/translation_interceptor.go +++ b/interceptor/translation_interceptor.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/log/tag" "google.golang.org/grpc" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/metrics" ) @@ -75,6 +76,16 @@ func (i *TranslationInterceptor) InterceptStream( info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { + + i.logger.Debug("InterceptStream", tag.NewAnyTag("method", info.FullMethod)) + // Skip translation for intra-proxy streams + if common.IsIntraProxy(ss.Context()) { + err := handler(srv, ss) + if err != nil { + i.logger.Error("grpc handler with error: %v", tag.Error(err)) + } + return err + } return handler(srv, newStreamTranslator(ss, i.logger, i.translators)) } diff --git a/metrics/prometheus_defs.go b/metrics/prometheus_defs.go index 45bd3ba1..f84c6b97 100644 --- a/metrics/prometheus_defs.go +++ b/metrics/prometheus_defs.go @@ -36,8 +36,9 @@ var ( GRPCServerStarted = DefaultCounterVec("grpc_server_started", "Emits when the grpc server is started", "service_name") GRPCServerStopped = DefaultCounterVec("grpc_server_stopped", "Emits when the grpc server is stopped", "service_name", "error") - GRPCOutboundClientMetrics = GetStandardGRPCClientInterceptor("outbound") - GRPCInboundClientMetrics = GetStandardGRPCClientInterceptor("inbound") + GRPCOutboundClientMetrics = GetStandardGRPCClientInterceptor("outbound") + GRPCInboundClientMetrics = GetStandardGRPCClientInterceptor("inbound") + GRPCIntraProxyClientMetrics = GetStandardGRPCClientInterceptor("intra_proxy") // /transport/mux @@ -86,6 +87,8 @@ func GetGRPCClientMetrics(directionLabel string) *grpcprom.ClientMetrics { return GRPCOutboundClientMetrics case "inbound": return GRPCInboundClientMetrics + case "intra_proxy": + return GRPCIntraProxyClientMetrics } panic("unknown direction label: " + directionLabel) } diff --git a/proxy/admin_stream_transfer.go b/proxy/admin_stream_transfer.go index 03d47608..ceae61c6 100644 --- a/proxy/admin_stream_transfer.go +++ b/proxy/admin_stream_transfer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "strings" "sync" "time" @@ -58,6 +59,7 @@ type StreamForwarder struct { targetClusterShardID history.ClusterShardID metricLabelValues []string logger log.Logger + streamID string sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient shutdownChan channel.ShutdownOnce @@ -72,7 +74,10 
@@ func newStreamForwarder( metricLabelValues []string, logger log.Logger, ) *StreamForwarder { + streamID := BuildForwarderStreamID(sourceClusterShardID, targetClusterShardID) + logger = log.With(logger, tag.NewStringTag("streamID", streamID)) return &StreamForwarder{ + streamID: streamID, adminClient: adminClient, targetStreamServer: targetStreamServer, targetMetadata: targetMetadata, @@ -87,6 +92,11 @@ func newStreamForwarder( // It sets up bidirectional forwarding with proper shutdown handling. // Returns the stream duration. func (f *StreamForwarder) Run() error { + f.logger = log.With(f.logger, + tag.NewStringTag("role", "forwarder"), + tag.NewStringTag("streamID", f.streamID), + ) + // simply forwarding target metadata outgoingContext := metadata.NewOutgoingContext(f.targetStreamServer.Context(), f.targetMetadata) outgoingContext, cancel := context.WithCancel(outgoingContext) @@ -105,6 +115,13 @@ func (f *StreamForwarder) Run() error { defer metrics.AdminServiceStreamsClosedCount.WithLabelValues(f.metricLabelValues...).Inc() streamStartTime := time.Now() + // Register the forwarder stream here + streamTracker := GetGlobalStreamTracker() + sourceShard := ClusterShardIDtoString(f.sourceClusterShardID) + targetShard := ClusterShardIDtoString(f.targetClusterShardID) + streamTracker.RegisterStream(f.streamID, "StreamWorkflowReplicationMessages", "forwarder", sourceShard, targetShard, StreamRoleForwarder) + defer streamTracker.UnregisterStream(f.streamID) + // When one side of the stream dies, we want to tell the other side to hang up // (see https://stackoverflow.com/questions/68218469/how-to-un-wedge-go-grpc-bidi-streaming-server-from-the-blocking-recv-call) // One call to StreamWorkflowReplicationMessages establishes a one-way channel through the proxy from one server to another. 
@@ -135,8 +152,10 @@ func (f *StreamForwarder) Run() error { } func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { + f.logger.Info("proxyStreamForwarder forwardReplicationMessages started") + defer f.logger.Info("proxyStreamForwarder forwardReplicationMessages finished") + defer func() { - f.logger.Debug("Shutdown sourceStreamClient.Recv loop.") f.shutdownChan.Shutdown() wg.Done() }() @@ -167,6 +186,15 @@ func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { case *adminservice.StreamWorkflowReplicationMessagesResponse_Messages: f.logger.Debug("forwarding ReplicationMessages", tag.NewInt64("exclusive", attr.Messages.GetExclusiveHighWatermark())) + msg := make([]string, 0, len(attr.Messages.ReplicationTasks)) + for i, task := range attr.Messages.ReplicationTasks { + msg = append(msg, fmt.Sprintf("[%d]: %v", i, task.SourceTaskId)) + } + f.logger.Info(fmt.Sprintf("forwarding ReplicationMessages: exclusive %v, tasks: %v", attr.Messages.ExclusiveHighWatermark, strings.Join(msg, ", "))) + + streamTracker := GetGlobalStreamTracker() + streamTracker.UpdateStreamReplicationMessages(f.streamID, attr.Messages.ExclusiveHighWatermark) + if err = f.targetStreamServer.Send(resp); err != nil { if err != io.EOF { f.logger.Error("targetStreamServer.Send encountered error", tag.Error(err)) @@ -188,7 +216,8 @@ func (f *StreamForwarder) forwardReplicationMessages(wg *sync.WaitGroup) { func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { defer func() { - f.logger.Debug("Shutdown targetStreamServer.Recv loop.") + f.logger.Info("StreamForwarder forwardAck started") + defer f.logger.Info("proxyStreamForwarder forwardAck finished") f.shutdownChan.Shutdown() var err error closeSent := make(chan struct{}) @@ -235,7 +264,15 @@ func (f *StreamForwarder) forwardAcks(wg *sync.WaitGroup) { switch attr := req.GetAttributes().(type) { case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: - f.logger.Debug("forwarding SyncReplicationState", tag.NewInt64("inclusive", attr.SyncReplicationState.GetInclusiveLowWatermark())) + f.logger.Info(fmt.Sprintf("forwarding SyncReplicationState: inclusive %v, attr: %v", attr.SyncReplicationState.InclusiveLowWatermark, attr)) + + var watermarkTime *time.Time + if attr.SyncReplicationState.InclusiveLowWatermarkTime != nil { + t := attr.SyncReplicationState.InclusiveLowWatermarkTime.AsTime() + watermarkTime = &t + } + streamTracker := GetGlobalStreamTracker() + streamTracker.UpdateStreamSyncReplicationState(f.streamID, attr.SyncReplicationState.InclusiveLowWatermark, watermarkTime) if err = f.sourceStreamClient.Send(req); err != nil { if err != io.EOF { f.logger.Error("sourceStreamClient.Send encountered error", tag.Error(err)) diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 29b27a4f..6cf491c7 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -4,11 +4,13 @@ import ( "context" "fmt" "strconv" + "sync" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" servercommon "go.temporal.io/server/common" + "go.temporal.io/server/common/channel" "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -21,19 +23,28 @@ import ( type ( LCMParameters struct { - LCM int32 `yaml:"lcm"` - TargetShardCount int32 `yaml:"targetShardCount"` + LCM int32 + TargetShardCount int32 + } + + RoutingParameters struct { + OverrideShardCount int32 + RoutingLocalShardCount int32 + 
DirectionLabel string } adminServiceProxyServer struct { adminservice.UnimplementedAdminServiceServer - adminClient adminservice.AdminServiceClient - logger log.Logger - apiOverrides *config.APIOverridesConfig - metricLabelValues []string - reportStreamValue func(idx int32, value int32) - shardCountConfig config.ShardCountConfig - lcmParameters LCMParameters + shardManager ShardManager + adminClient adminservice.AdminServiceClient + adminClientReverse adminservice.AdminServiceClient + logger log.Logger + apiOverrides *config.APIOverridesConfig + metricLabelValues []string + reportStreamValue func(idx int32, value int32) + shardCountConfig config.ShardCountConfig + lcmParameters LCMParameters + routingParameters RoutingParameters } ) @@ -41,25 +52,31 @@ type ( func NewAdminServiceProxyServer( serviceName string, adminClient adminservice.AdminServiceClient, + adminClientReverse adminservice.AdminServiceClient, apiOverrides *config.APIOverridesConfig, metricLabelValues []string, reportStreamValue func(idx int32, value int32), shardCountConfig config.ShardCountConfig, lcmParameters LCMParameters, + routingParameters RoutingParameters, logger log.Logger, + shardManager ShardManager, ) adminservice.AdminServiceServer { // The AdminServiceStreams will duplicate the same output for an underlying connection issue hundreds of times. // Limit their output to three times per minute logger = log.NewThrottledLogger(log.With(logger, common.ServiceTag(serviceName)), func() float64 { return 3.0 / 60.0 }) return &adminServiceProxyServer{ - adminClient: adminClient, - logger: logger, - apiOverrides: apiOverrides, - metricLabelValues: metricLabelValues, - reportStreamValue: reportStreamValue, - shardCountConfig: shardCountConfig, - lcmParameters: lcmParameters, + shardManager: shardManager, + adminClient: adminClient, + adminClientReverse: adminClientReverse, + logger: logger, + apiOverrides: apiOverrides, + metricLabelValues: metricLabelValues, + reportStreamValue: reportStreamValue, + shardCountConfig: shardCountConfig, + lcmParameters: lcmParameters, + routingParameters: routingParameters, } } @@ -103,10 +120,15 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi return resp, err } - if s.shardCountConfig.Mode == config.ShardCountLCM { + switch s.shardCountConfig.Mode { + case config.ShardCountLCM: // Present a fake number of shards. In LCM mode, we present the least // common multiple of both cluster shard counts. resp.HistoryShardCount = s.lcmParameters.LCM + case config.ShardCountRouting: + if s.routingParameters.OverrideShardCount > 0 { + resp.HistoryShardCount = s.routingParameters.OverrideShardCount + } } if s.apiOverrides != nil && s.apiOverrides.AdminService.DescribeCluster != nil { @@ -116,6 +138,8 @@ func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *admi } } + s.logger.Info("DescribeCluster response", tag.NewStringTag("response", fmt.Sprintf("%v", resp))) + return resp, err } @@ -243,26 +267,33 @@ func ClusterShardIDtoString(sd history.ClusterShardID) string { return fmt.Sprintf("(id: %d, shard: %d)", sd.ClusterID, sd.ShardID) } +func ClusterShardIDtoShortString(sd history.ClusterShardID) string { + return fmt.Sprintf("%d:%d", sd.ClusterID, sd.ShardID) +} + // StreamWorkflowReplicationMessages establishes an HTTP/2 stream. gRPC passes us a stream that represents the initiating server, // and we can freely Send and Recv on that "server". Because this is a proxy, we also establish a bidirectional // stream using our configured adminClient. 
When we Recv on the initiator, we Send to the client. // When we Recv on the client, we Send to the initiator func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( - targetStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, ) (retError error) { defer log.CapturePanic(s.logger, &retError) - targetMetadata, ok := metadata.FromIncomingContext(targetStreamServer.Context()) + targetMetadata, ok := metadata.FromIncomingContext(streamServer.Context()) if !ok { return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata") } targetClusterShardID, sourceClusterShardID, err := history.DecodeClusterShardMD( - headers.NewGRPCHeaderGetter(targetStreamServer.Context()), + headers.NewGRPCHeaderGetter(streamServer.Context()), ) if err != nil { return err } + // Detect intra-proxy streams early for logging/behavior toggles + isIntraProxy := common.IsIntraProxy(streamServer.Context()) + logger := log.With(s.logger, tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)), tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID))) @@ -303,9 +334,17 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newSourceShardID.ShardID))) } + if isIntraProxy { + return s.streamIntraProxyRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) + } + + if s.shardCountConfig.Mode == config.ShardCountRouting { + return s.streamRouting(logger, streamServer, sourceClusterShardID, targetClusterShardID) + } + forwarder := newStreamForwarder( s.adminClient, - targetStreamServer, + streamServer, targetMetadata, sourceClusterShardID, targetClusterShardID, @@ -321,6 +360,99 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( return nil } +func (s *adminServiceProxyServer) streamIntraProxyRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, +) error { + logger.Info("streamIntraProxyRouting started") + defer logger.Info("streamIntraProxyRouting finished") + + // Determine remote peer identity from intra-proxy headers + peerNodeName := "" + if md, ok := metadata.FromIncomingContext(streamServer.Context()); ok { + vals := md.Get(common.IntraProxyOriginProxyIDHeader) + if len(vals) > 0 { + peerNodeName = vals[0] + } + } + + // Only allow intra-proxy when at least one shard is local to this proxy instance + isLocalSource := s.shardManager.IsLocalShard(sourceShardID) + isLocalTarget := s.shardManager.IsLocalShard(targetShardID) + if isLocalTarget || !isLocalSource { + logger.Info("Skipping intra-proxy between two local shards or two remote shards. 
Client may use outdated shard info.", + tag.NewBoolTag("isLocalSource", isLocalSource), + tag.NewBoolTag("isLocalTarget", isLocalTarget), + ) + return nil + } + + // Sender: handle ACKs coming from peer and forward to original owner + sender := &intraProxyStreamSender{ + logger: logger, + shardManager: s.shardManager, + peerNodeName: peerNodeName, + sourceShardID: sourceShardID, + targetShardID: targetShardID, + } + + shutdownChan := channel.NewShutdownOnce() + go func() { + if err := sender.Run(streamServer, shutdownChan); err != nil { + logger.Error("intraProxyStreamSender.Run error", tag.Error(err)) + } + }() + <-shutdownChan.Channel() + return nil +} + +func (s *adminServiceProxyServer) streamRouting( + logger log.Logger, + streamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + sourceShardID history.ClusterShardID, + targetShardID history.ClusterShardID, +) error { + logger.Info("streamRouting started") + defer logger.Info("streamRouting stopped") + + // client: stream receiver + // server: stream sender + proxyStreamSender := &proxyStreamSender{ + logger: logger, + shardManager: s.shardManager, + sourceShardID: sourceShardID, + targetShardID: targetShardID, + directionLabel: s.routingParameters.DirectionLabel, + } + + proxyStreamReceiver := &proxyStreamReceiver{ + logger: s.logger, + shardManager: s.shardManager, + adminClient: s.adminClientReverse, + localShardCount: s.routingParameters.RoutingLocalShardCount, + sourceShardID: targetShardID, // reverse direction + targetShardID: sourceShardID, // reverse direction + directionLabel: s.routingParameters.DirectionLabel, + } + + shutdownChan := channel.NewShutdownOnce() + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + proxyStreamSender.Run(streamServer, shutdownChan) + }() + go func() { + defer wg.Done() + proxyStreamReceiver.Run(shutdownChan) + }() + wg.Wait() + + return nil +} + func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 { targetShardID := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID) if len(targetShardID) != 1 { diff --git a/proxy/adminservice_test.go b/proxy/adminservice_test.go index 456b4cc2..e964ba95 100644 --- a/proxy/adminservice_test.go +++ b/proxy/adminservice_test.go @@ -45,7 +45,8 @@ type adminProxyServerInput struct { func (s *adminserviceSuite) newAdminServiceProxyServer(in adminProxyServerInput, observer *ReplicationStreamObserver) adminservice.AdminServiceServer { return NewAdminServiceProxyServer("test-service-name", s.adminClientMock, - in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, log.NewTestLogger()) + s.adminClientMock, + in.apiOverrides, in.metricLabels, observer.ReportStreamValue, config.ShardCountConfig{}, LCMParameters{}, RoutingParameters{}, log.NewTestLogger(), nil) } func (s *adminserviceSuite) TestAddOrUpdateRemoteCluster() { @@ -234,7 +235,7 @@ func (s *adminserviceSuite) TestAPIOverrides_FailoverVersionIncrement() { s.adminClientMock.EXPECT().DescribeCluster(ctx, gomock.Any()).Return(c.mockResp, nil) resp, err := server.DescribeCluster(ctx, req) s.NoError(err) - s.Equal(c.expResp, resp) + s.Equal(c.expResp.FailoverVersionIncrement, resp.FailoverVersionIncrement) s.Equal("[]", observer.PrintActiveStreams()) }) } diff --git a/proxy/cluster_connection.go b/proxy/cluster_connection.go index 5ee388e3..8fe3f9e8 100644 --- a/proxy/cluster_connection.go +++ b/proxy/cluster_connection.go @@ -59,6 +59,7 @@ type ( inboundClient 
closableClientConn inboundObserver *ReplicationStreamObserver outboundObserver *ReplicationStreamObserver + shardManager ShardManager logger log.Logger } // contextAwareServer represents a startable gRPC server used to provide the Temporal interface on some connection. @@ -92,8 +93,11 @@ type ( nsTranslations collect.StaticBiMap[string, string] saTranslations config.SearchAttributeTranslation shardCountConfig config.ShardCountConfig - targetShardCount int32 logger log.Logger + + clusterConnection *ClusterConnection + lcmParameters LCMParameters + routingParameters RoutingParameters } ) @@ -128,6 +132,42 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon return nil, err } + cc.shardManager = NewShardManager(connConfig.MemberlistConfig, connConfig.ShardCountConfig, connConfig.LocalServer.Connection.TcpClient.TLSConfig, logger) + + getLCMParameters := func(shardCountConfig config.ShardCountConfig, inverse bool) LCMParameters { + if shardCountConfig.Mode != config.ShardCountLCM { + return LCMParameters{} + } + lcm := common.LCM(shardCountConfig.LocalShardCount, shardCountConfig.RemoteShardCount) + if inverse { + return LCMParameters{ + LCM: lcm, + TargetShardCount: shardCountConfig.LocalShardCount, + } + } + return LCMParameters{ + LCM: lcm, + TargetShardCount: shardCountConfig.RemoteShardCount, + } + } + getRoutingParameters := func(shardCountConfig config.ShardCountConfig, inverse bool, directionLabel string) RoutingParameters { + if shardCountConfig.Mode != config.ShardCountRouting { + return RoutingParameters{} + } + if inverse { + return RoutingParameters{ + OverrideShardCount: shardCountConfig.RemoteShardCount, + RoutingLocalShardCount: shardCountConfig.LocalShardCount, + DirectionLabel: directionLabel, + } + } + return RoutingParameters{ + OverrideShardCount: shardCountConfig.LocalShardCount, + RoutingLocalShardCount: shardCountConfig.RemoteShardCount, + DirectionLabel: directionLabel, + } + } + cc.inboundServer, cc.inboundObserver, err = createServer(lifetime, serverConfiguration{ name: sanitizedConnectionName, clusterDefinition: connConfig.RemoteServer, @@ -137,12 +177,15 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon nsTranslations: nsTranslations.Inverse(), saTranslations: saTranslations.Inverse(), shardCountConfig: connConfig.ShardCountConfig, - targetShardCount: connConfig.ShardCountConfig.LocalShardCount, logger: cc.logger, + clusterConnection: cc, + lcmParameters: getLCMParameters(connConfig.ShardCountConfig, true), + routingParameters: getRoutingParameters(connConfig.ShardCountConfig, true, "inbound"), }) if err != nil { return nil, err } + cc.outboundServer, cc.outboundObserver, err = createServer(lifetime, serverConfiguration{ name: sanitizedConnectionName, clusterDefinition: connConfig.LocalServer, @@ -152,12 +195,15 @@ func NewClusterConnection(lifetime context.Context, connConfig config.ClusterCon nsTranslations: nsTranslations, saTranslations: saTranslations, shardCountConfig: connConfig.ShardCountConfig, - targetShardCount: connConfig.ShardCountConfig.RemoteShardCount, logger: cc.logger, + clusterConnection: cc, + lcmParameters: getLCMParameters(connConfig.ShardCountConfig, false), + routingParameters: getRoutingParameters(connConfig.ShardCountConfig, false, "outbound"), }) if err != nil { return nil, err } + return cc, nil } @@ -239,11 +285,18 @@ func buildTLSTCPClient(lifetime context.Context, serverAddress string, tlsCfg en } func (c *ClusterConnection) Start() { + if c.shardManager != nil { + err 
:= c.shardManager.Start(c.lifetime) + if err != nil { + c.logger.Error("Failed to start shard manager", tag.Error(err)) + } + } c.inboundServer.Start() c.inboundObserver.Start(c.lifetime, c.inboundServer.Name(), "inbound") c.outboundServer.Start() c.outboundObserver.Start(c.lifetime, c.outboundServer.Name(), "outbound") } + func (c *ClusterConnection) Describe() string { return fmt.Sprintf("[ClusterConnection connects outbound server %s to outbound client %s, inbound server %s to inbound client %s]", c.outboundServer.Describe(), c.outboundClient.Describe(), c.inboundServer.Describe(), c.inboundClient.Describe()) @@ -252,6 +305,7 @@ func (c *ClusterConnection) Describe() string { func (c *ClusterConnection) AcceptingInboundTraffic() bool { return c.inboundClient.CanMakeCalls() && c.inboundServer.CanAcceptConnections() } + func (c *ClusterConnection) AcceptingOutboundTraffic() bool { return c.outboundClient.CanMakeCalls() && c.outboundServer.CanAcceptConnections() } @@ -265,16 +319,19 @@ func buildProxyServer(c serverConfiguration, tlsConfig encryption.TLSConfig, obs } server := grpc.NewServer(serverOpts...) - var lcmParameters LCMParameters - if c.shardCountConfig.Mode == config.ShardCountLCM { - lcmParameters = LCMParameters{ - LCM: common.LCM(c.shardCountConfig.LocalShardCount, c.shardCountConfig.RemoteShardCount), - TargetShardCount: c.targetShardCount, - } - } - - adminServiceImpl := NewAdminServiceProxyServer(fmt.Sprintf("%sAdminService", c.directionLabel), adminservice.NewAdminServiceClient(c.client), - c.clusterDefinition.APIOverrides, []string{c.directionLabel}, observeFn, c.shardCountConfig, lcmParameters, c.logger) + adminServiceImpl := NewAdminServiceProxyServer( + fmt.Sprintf("%sAdminService", c.directionLabel), + adminservice.NewAdminServiceClient(c.client), + adminservice.NewAdminServiceClient(c.managedClient), + c.clusterDefinition.APIOverrides, + []string{c.directionLabel}, + observeFn, + c.shardCountConfig, + c.lcmParameters, + c.routingParameters, + c.logger, + c.clusterConnection.shardManager, + ) var accessControl *auth.AccessControl if c.clusterDefinition.ACLPolicy != nil { accessControl = auth.NewAccesControl(c.clusterDefinition.ACLPolicy.AllowedNamespaces) diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 00000000..4af633e1 --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,207 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "time" + + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + + "github.com/temporalio/s2s-proxy/transport/mux" + "github.com/temporalio/s2s-proxy/transport/mux/session" +) + +type ( + + // ProxyIDEntry is a preview of a ring buffer entry + ProxyIDEntry struct { + ProxyID int64 `json:"proxy_id"` + SourceShard string `json:"source_shard"` + SourceTask int64 `json:"source_task"` + } + + // SenderDebugInfo captures proxy-stream-sender internals for debugging + SenderDebugInfo struct { + RingStartProxyID int64 `json:"ring_start_proxy_id"` + RingSize int `json:"ring_size"` + RingMaxSize int `json:"ring_max_size"` + RingCapacity int `json:"ring_capacity"` + RingHead int `json:"ring_head"` + NextProxyTaskID int64 `json:"next_proxy_task_id"` + PrevAckBySource map[string]int64 `json:"prev_ack_by_source"` + LastHighBySource map[string]int64 `json:"last_high_by_source"` + LastProxyHighBySource map[string]int64 `json:"last_proxy_high_by_source"` + EntriesPreview []ProxyIDEntry `json:"entries_preview"` + } + + // ReceiverDebugInfo captures proxy-stream-receiver 
ack aggregation state + ReceiverDebugInfo struct { + AckByTarget map[string]int64 `json:"ack_by_target"` + LastAggregatedMin int64 `json:"last_aggregated_min"` + LastExclusiveHighOriginal int64 `json:"last_exclusive_high_original"` + } + + // StreamInfo represents information about an active gRPC stream + StreamInfo struct { + ID string `json:"id"` + Method string `json:"method"` + Direction string `json:"direction"` + Role string `json:"role"` + ClientShard string `json:"client_shard"` + ServerShard string `json:"server_shard"` + StartTime time.Time `json:"start_time"` + LastSeen time.Time `json:"last_seen"` + TotalDuration string `json:"total_duration"` + IdleDuration string `json:"idle_duration"` + LastSyncWatermark *int64 `json:"last_sync_watermark"` + LastSyncWatermarkTime *time.Time `json:"last_sync_watermark_time"` + LastExclusiveHighWatermark *int64 `json:"last_exclusive_high_watermark"` + LastTaskIDs []int64 `json:"last_task_ids"` + SenderDebug *SenderDebugInfo `json:"sender_debug"` + ReceiverDebug *ReceiverDebugInfo `json:"receiver_debug"` + } + + // ShardDebugInfo contains debug information about shard distribution + ShardDebugInfo struct { + Enabled bool `json:"enabled"` + NodeName string `json:"node_name"` + LocalShards map[string]history.ClusterShardID `json:"local_shards"` // key: "clusterID:shardID" + LocalShardCount int `json:"local_shard_count"` + ClusterNodes []string `json:"cluster_nodes"` + ClusterSize int `json:"cluster_size"` + RemoteShards map[string]string `json:"remote_shards"` // shard_id -> node_name + RemoteShardCounts map[string]int `json:"remote_shard_counts"` // node_name -> shard_count + } + + // ChannelDebugInfo holds debug information about channels + ChannelDebugInfo struct { + RemoteSendChannels map[string]int `json:"remote_send_channels"` // shard ID -> buffer size + LocalAckChannels map[string]int `json:"local_ack_channels"` // shard ID -> buffer size + TotalSendChannels int `json:"total_send_channels"` + TotalAckChannels int `json:"total_ack_channels"` + } + + // MuxConnectionInfo holds debug information about a mux connection + MuxConnectionInfo struct { + ID string `json:"id"` + LocalAddr string `json:"local_addr"` + RemoteAddr string `json:"remote_addr"` + State string `json:"state"` + IsClosed bool `json:"is_closed"` + } + + // MuxConnectionsDebugInfo holds debug information about mux connections for a cluster connection + MuxConnectionsDebugInfo struct { + ConnectionName string `json:"connection_name"` + Direction string `json:"direction"` + Address string `json:"address"` + Connections []MuxConnectionInfo `json:"connections"` + ConnectionCount int `json:"connection_count"` + } + + DebugResponse struct { + Timestamp time.Time `json:"timestamp"` + ActiveStreams []StreamInfo `json:"active_streams"` + StreamCount int `json:"stream_count"` + ShardInfos []ShardDebugInfo `json:"shard_infos"` + ChannelInfos []ChannelDebugInfo `json:"channel_infos"` + MuxConnections []MuxConnectionsDebugInfo `json:"mux_connections"` + } +) + +func HandleDebugInfo(w http.ResponseWriter, r *http.Request, proxyInstance *Proxy, logger log.Logger) { + w.Header().Set("Content-Type", "application/json") + + var activeStreams []StreamInfo + var streamCount int + var shardInfos []ShardDebugInfo + var channelInfos []ChannelDebugInfo + var muxConnections []MuxConnectionsDebugInfo + + // Get active streams information + streamTracker := GetGlobalStreamTracker() + activeStreams = streamTracker.GetActiveStreams() + streamCount = streamTracker.GetStreamCount() + for _, 
clusterConnection := range proxyInstance.clusterConnections { + if clusterConnection.shardManager != nil { + shardInfos = append(shardInfos, clusterConnection.shardManager.GetShardInfos()...) + channelInfos = append(channelInfos, clusterConnection.shardManager.GetChannelInfo()) + } + + // Collect mux connection info from inbound and outbound servers + muxConnections = append(muxConnections, getMuxConnectionsInfo(clusterConnection.inboundServer, "inbound")...) + muxConnections = append(muxConnections, getMuxConnectionsInfo(clusterConnection.outboundServer, "outbound")...) + } + + response := DebugResponse{ + Timestamp: time.Now(), + ActiveStreams: activeStreams, + StreamCount: streamCount, + ShardInfos: shardInfos, + ChannelInfos: channelInfos, + MuxConnections: muxConnections, + } + + if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Failed to encode debug response", tag.Error(err)) + http.Error(w, "Internal server error", http.StatusInternalServerError) + } +} + +func getMuxConnectionsInfo(server contextAwareServer, direction string) []MuxConnectionsDebugInfo { + muxMgr, ok := server.(mux.MultiMuxManager) + if !ok { + return nil + } + + connections := muxMgr.GetMuxConnections() + if len(connections) == 0 { + return nil + } + + var connInfos []MuxConnectionInfo + for id, muxSession := range connections { + localAddr, remoteAddr := muxSession.GetConnectionInfo() + state := muxSession.State() + stateStr := "unknown" + if state != nil { + switch state.State { + case session.Connected: + stateStr = "connected" + case session.Closed: + stateStr = "closed" + case session.Error: + stateStr = "error" + } + } + + localAddrStr := "" + if localAddr != nil { + localAddrStr = localAddr.String() + } + remoteAddrStr := "" + if remoteAddr != nil { + remoteAddrStr = remoteAddr.String() + } + + connInfos = append(connInfos, MuxConnectionInfo{ + ID: id, + LocalAddr: localAddrStr, + RemoteAddr: remoteAddrStr, + State: stateStr, + IsClosed: muxSession.IsClosed(), + }) + } + + return []MuxConnectionsDebugInfo{ + { + ConnectionName: muxMgr.Name(), + Direction: direction, + Address: muxMgr.Address(), + Connections: connInfos, + ConnectionCount: len(connInfos), + }, + } +} diff --git a/proxy/intra_proxy_router.go b/proxy/intra_proxy_router.go new file mode 100644 index 00000000..24e0655b --- /dev/null +++ b/proxy/intra_proxy_router.go @@ -0,0 +1,905 @@ +package proxy + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "sync" + "time" + + "go.temporal.io/server/api/adminservice/v1" + replicationv1 "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + + "github.com/temporalio/s2s-proxy/common" + "github.com/temporalio/s2s-proxy/encryption" + "github.com/temporalio/s2s-proxy/metrics" + "github.com/temporalio/s2s-proxy/transport/grpcutil" +) + +// intraProxyManager maintains long-lived intra-proxy streams to peer proxies and +// provides simple send helpers (e.g., forwarding ACKs). 
+type intraProxyManager struct { + logger log.Logger + streamsMu sync.RWMutex + shardManager ShardManager + notifyCh chan struct{} + // Group state by remote peer for unified lifecycle ops + peers map[string]*peerState +} + +type peerState struct { + conn *grpc.ClientConn + receivers map[peerStreamKey]*intraProxyStreamReceiver + senders map[peerStreamKey]*intraProxyStreamSender + recvShutdown map[peerStreamKey]channel.ShutdownOnce +} + +type peerStreamKey struct { + targetShard history.ClusterShardID + sourceShard history.ClusterShardID +} + +func newIntraProxyManager(logger log.Logger, shardManager ShardManager) *intraProxyManager { + return &intraProxyManager{ + logger: logger, + shardManager: shardManager, + peers: make(map[string]*peerState), + notifyCh: make(chan struct{}), + } +} + +// intraProxyStreamSender registers server stream and forwards upstream ACKs to shard owners (local or remote). +// Replication messages are sent by intraProxyManager.sendMessages using the registered server stream. +type intraProxyStreamSender struct { + logger log.Logger + shardManager ShardManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamID string + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer +} + +func (s *intraProxyStreamSender) Run( + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + s.streamID = BuildIntraProxySenderStreamID(s.peerNodeName, s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID)) + + s.logger.Info("intraProxyStreamSender Run") + defer s.logger.Info("intraProxyStreamSender Run finished") + + // Register server-side intra-proxy stream in tracker + st := GetGlobalStreamTracker() + st.RegisterStream(s.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(s.sourceShardID), ClusterShardIDtoString(s.targetShardID), StreamRoleForwarder) + defer st.UnregisterStream(s.streamID) + + s.sourceStreamServer = sourceStreamServer + + // register this sender so sendMessages can use it + s.shardManager.GetIntraProxyManager().RegisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID, s) + defer s.shardManager.GetIntraProxyManager().UnregisterSender(s.peerNodeName, s.targetShardID, s.sourceShardID) + + // Send pending watermarks to late-registering shards + // When a sender is registered, check if there's an active receiver for the source shard + // that has a pending watermark, and send it immediately to the peer + if receiver, ok := s.shardManager.GetActiveReceiver(s.sourceShardID); ok { + if lastWatermark := receiver.GetLastWatermark(); lastWatermark != nil && lastWatermark.ExclusiveHighWatermark > 0 { + s.logger.Info("Sending pending watermark to peer on sender registration", + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark), + tag.NewStringTag("peer", s.peerNodeName)) + resp := &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + } + if err := s.sendReplicationMessages(resp); err != nil { + s.logger.Warn("Failed to send pending watermark to peer on sender registration", tag.Error(err)) + } + } + } + + // recv ACKs from peer and route to original source 
shard owner + return s.recvAck(shutdownChan) +} + +// recvAck reads ACKs from the peer and routes them to the source shard owner. +func (s *intraProxyStreamSender) recvAck(shutdownChan channel.ShutdownOnce) error { + s.logger.Info("intraProxyStreamSender recvAck") + defer func() { + s.logger.Info("intraProxyStreamSender recvAck finished") + shutdownChan.Shutdown() + }() + + for !shutdownChan.IsShutdown() { + req, err := s.sourceStreamServer.Recv() + if err == io.EOF { + s.logger.Info("intraProxyStreamSender recvAck encountered EOF") + return nil + } + if err != nil { + s.logger.Error("intraProxyStreamSender recvAck encountered error", tag.Error(err)) + return err + } + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + ack := attr.SyncReplicationState.InclusiveLowWatermark + + s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", ack)) + + // Update server-side intra-proxy stream tracker with sync watermark + st := GetGlobalStreamTracker() + st.UpdateStreamSyncReplicationState(s.streamID, ack, nil) + st.UpdateStream(s.streamID) + + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{InclusiveLowWatermark: ack}, + }, + }, + } + + s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) + // FIXME: should retry. If not succeed, return and shutdown the stream + sent := s.shardManager.DeliverAckToShardOwner(s.sourceShardID, routedAck, shutdownChan, s.logger, ack, false) + if !sent { + s.logger.Error("Sender failed to forward ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(s.sourceShardID)), tag.NewInt64("ack", ack)) + return fmt.Errorf("failed to forward ACK to source shard") + } + } + } + return nil +} + +// sendReplicationMessages sends replication messages to the peer via the server stream. +func (s *intraProxyStreamSender) sendReplicationMessages(resp *adminservice.StreamWorkflowReplicationMessagesResponse) error { + s.logger.Info("intraProxyStreamSender sendReplicationMessages started") + defer s.logger.Info("intraProxyStreamSender sendReplicationMessages finished") + + // Update server-side intra-proxy tracker for outgoing messages + if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { + st := GetGlobalStreamTracker() + ids := make([]int64, 0, len(msgs.Messages.ReplicationTasks)) + for _, t := range msgs.Messages.ReplicationTasks { + ids = append(ids, t.SourceTaskId) + } + st.UpdateStreamLastTaskIDs(s.streamID, ids) + st.UpdateStreamReplicationMessages(s.streamID, msgs.Messages.ExclusiveHighWatermark) + st.UpdateStream(s.streamID) + } + if err := s.sourceStreamServer.Send(resp); err != nil { + return err + } + return nil +} + +// intraProxyStreamReceiver ensures a client stream to peer exists and sends aggregated ACKs upstream. 
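+// A receiver is created per (peer, target shard, source shard) key by ensureStream. Run opens
+// the client stream over the shared per-peer connection with client=targetShard /
+// server=sourceShard metadata plus the intra-proxy origin header, forwards received
+// replication messages to the local send channel of the target shard, and remembers the last
+// watermark seen so late-registering target shards can be caught up.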
+type intraProxyStreamReceiver struct { + logger log.Logger + shardManager ShardManager + intraMgr *intraProxyManager + peerNodeName string + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + streamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + streamID string + shutdown channel.ShutdownOnce + cancel context.CancelFunc + // lastWatermark tracks the last watermark received from source shard for late-registering target shards + lastWatermarkMu sync.RWMutex + lastWatermark *replicationv1.WorkflowReplicationMessages +} + +// Run opens the client stream with metadata, registers tracking, and starts receiver goroutines. +func (r *intraProxyStreamReceiver) Run(ctx context.Context, shardManager ShardManager, conn *grpc.ClientConn) error { + r.streamID = BuildIntraProxyReceiverStreamID(r.peerNodeName, r.sourceShardID, r.targetShardID) + r.logger = log.With(r.logger, tag.NewStringTag("streamID", r.streamID)) + + r.logger.Info("intraProxyStreamReceiver Run") + // Build metadata according to receiver pattern: client=targetShard, server=sourceShard + md := metadata.New(map[string]string{}) + md.Set(history.MetadataKeyClientClusterID, fmt.Sprintf("%d", r.targetShardID.ClusterID)) + md.Set(history.MetadataKeyClientShardID, fmt.Sprintf("%d", r.targetShardID.ShardID)) + md.Set(history.MetadataKeyServerClusterID, fmt.Sprintf("%d", r.sourceShardID.ClusterID)) + md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) + ctx = metadata.NewOutgoingContext(ctx, md) + ctx = common.WithIntraProxyHeaders(ctx, map[string]string{ + common.IntraProxyOriginProxyIDHeader: shardManager.GetShardInfo().NodeName, + }) + + // Ensure we can cancel Recv() by canceling the context when tearing down + ctx, cancel := context.WithCancel(ctx) + r.cancel = cancel + + client := adminservice.NewAdminServiceClient(conn) + streamClient, err := client.StreamWorkflowReplicationMessages(ctx) + if err != nil { + if r.cancel != nil { + r.cancel() + } + return err + } + r.streamClient = streamClient + + r.shardManager.RegisterActiveReceiver(r.sourceShardID, r) + defer r.shardManager.UnregisterActiveReceiver(r.sourceShardID) + + // Register client-side intra-proxy stream in tracker + st := GetGlobalStreamTracker() + st.RegisterStream(r.streamID, "StreamWorkflowReplicationMessages", "intra-proxy", ClusterShardIDtoString(r.sourceShardID), ClusterShardIDtoString(r.targetShardID), StreamRoleForwarder) + defer st.UnregisterStream(r.streamID) + + // Start replication receiver loop + return r.recvReplicationMessages() +} + +// recvReplicationMessages receives replication messages and forwards to local shard owner. 
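+// If the local send channel for the target shard has not been registered yet, delivery is
+// retried with exponential backoff (starting at 10ms and roughly capped at one second) until
+// the channel appears or the stream shuts down.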
+func (r *intraProxyStreamReceiver) recvReplicationMessages() error { + r.logger.Info("intraProxyStreamReceiver recvReplicationMessages started") + defer r.logger.Info("intraProxyStreamReceiver recvReplicationMessages finished") + + shutdown := r.shutdown + defer shutdown.Shutdown() + backoff := 10 * time.Millisecond + for !shutdown.IsShutdown() { + resp, err := r.streamClient.Recv() + if err == io.EOF { + r.logger.Info("recvReplicationMessages encountered EOF") + return nil + } + if err != nil { + r.logger.Error("intra-proxy stream Recv error", tag.Error(err)) + return err + } + if msgs, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && msgs.Messages != nil { + // Capture watermark value immediately to avoid data race with sender + exclusiveHighWatermark := msgs.Messages.ExclusiveHighWatermark + priority := msgs.Messages.Priority + + // Update client-side intra-proxy tracker for received messages + st := GetGlobalStreamTracker() + ids := make([]int64, 0, len(msgs.Messages.ReplicationTasks)) + for _, t := range msgs.Messages.ReplicationTasks { + ids = append(ids, t.SourceTaskId) + } + st.UpdateStreamLastTaskIDs(r.streamID, ids) + st.UpdateStreamReplicationMessages(r.streamID, exclusiveHighWatermark) + st.UpdateStream(r.streamID) + + // Track last watermark for late-registering shards + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: exclusiveHighWatermark, + Priority: priority, + } + r.lastWatermarkMu.Unlock() + + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", exclusiveHighWatermark, ids)) + + msg := RoutedMessage{SourceShard: r.sourceShardID, Resp: resp} + sent := false + logged := false + for !sent { + if ch, ok := r.shardManager.GetRemoteSendChan(r.targetShardID); ok { + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + r.logger.Warn("Failed to send to local target shard (channel closed)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID))) + } + }() + select { + case ch <- msg: + sent = true + r.logger.Info("Receiver sent ReplicationTasks to local target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID)), tag.NewInt64("exclusive_high", exclusiveHighWatermark)) + case <-shutdown.Channel(): + // Will be handled outside the func + } + }() + if shutdown.IsShutdown() { + return nil + } + } else { + if !logged { + r.logger.Warn("No local send channel yet for target shard; waiting", + tag.NewStringTag("targetShard", ClusterShardIDtoString(r.targetShardID))) + logged = true + } + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } + } + backoff = 10 * time.Millisecond + } + } + return nil +} + +// sendAck sends an ACK upstream via the client stream and updates tracker. 
+func (r *intraProxyStreamReceiver) sendAck(req *adminservice.StreamWorkflowReplicationMessagesRequest) error { + r.logger.Info("intraProxyStreamReceiver sendAck started") + defer r.logger.Info("intraProxyStreamReceiver sendAck finished") + + if err := r.streamClient.Send(req); err != nil { + return err + } + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + st := GetGlobalStreamTracker() + st.UpdateStreamSyncReplicationState(r.streamID, attr.SyncReplicationState.InclusiveLowWatermark, nil) + st.UpdateStream(r.streamID) + } + return nil +} + +// GetTargetShardID returns the target shard ID for this receiver +func (r *intraProxyStreamReceiver) GetTargetShardID() history.ClusterShardID { + return r.targetShardID +} + +// GetSourceShardID returns the source shard ID for this receiver +func (r *intraProxyStreamReceiver) GetSourceShardID() history.ClusterShardID { + return r.sourceShardID +} + +// GetLastWatermark returns the last watermark received from the source shard +func (r *intraProxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { + r.lastWatermarkMu.RLock() + defer r.lastWatermarkMu.RUnlock() + return r.lastWatermark +} + +// NotifyNewTargetShard notifies the receiver about a newly registered target shard +func (r *intraProxyStreamReceiver) NotifyNewTargetShard(targetShardID history.ClusterShardID) { + r.sendPendingWatermarkToShard(targetShardID) +} + +// sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard +// This ensures late-registering shards receive watermarks that were sent before they registered +func (r *intraProxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { + r.lastWatermarkMu.RLock() + lastWatermark := r.lastWatermark + r.lastWatermarkMu.RUnlock() + + if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { + // No pending watermark to send + return + } + + r.logger.Info("Sending pending watermark to newly registered shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) + + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + }, + } + + // Try to send to local shard first + if sendChan, exists := r.shardManager.GetRemoteSendChan(targetShardID); exists { + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + select { + case sendChan <- clonedMsg: + r.logger.Info("Sent pending watermark to local shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + default: + r.logger.Warn("Failed to send pending watermark to local shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + return + } + + // If not local, try to send to remote shard + if r.shardManager != nil { + shutdownChan := channel.NewShutdownOnce() + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: 
msg.SourceShard, + Resp: clonedResp, + } + if r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { + r.logger.Info("Sent pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } else { + r.logger.Warn("Failed to send pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + } +} + +func (m *intraProxyManager) RegisterSender( + peerNodeName string, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, + sender *intraProxyStreamSender, +) { + // Cross-cluster only + if targetShard.ClusterID == sourceShard.ClusterID { + return + } + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + m.logger.Info("RegisterSender", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("sender", sender.streamID)) + m.streamsMu.Lock() + ps := m.peers[peerNodeName] + if ps == nil { + ps = &peerState{receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} + m.peers[peerNodeName] = ps + } + if ps.senders == nil { + ps.senders = make(map[peerStreamKey]*intraProxyStreamSender) + } + ps.senders[key] = sender + m.streamsMu.Unlock() +} + +func (m *intraProxyManager) UnregisterSender( + peerNodeName string, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, +) { + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + m.logger.Info("UnregisterSender", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key))) + m.streamsMu.Lock() + if ps := m.peers[peerNodeName]; ps != nil && ps.senders != nil { + delete(ps.senders, key) + } + m.streamsMu.Unlock() +} + +// EnsureReceiverForPeerShard ensures a client stream and an ACK aggregator exist for the given peer/shard pair. 
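+// It is a no-op when both shards belong to the same cluster, when the peer is this proxy
+// instance itself, or when neither shard is owned locally; otherwise it ensures the shared
+// per-peer connection and starts a receiver stream via ensureStream.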
+func (m *intraProxyManager) EnsureReceiverForPeerShard(peerNodeName string, targetShard history.ClusterShardID, sourceShard history.ClusterShardID) { + logger := log.With(m.logger, + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), + tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard))) + logger.Info("EnsureReceiverForPeerShard") + + // Cross-cluster only + if targetShard.ClusterID == sourceShard.ClusterID { + return + } + // Do not create intra-proxy streams to self instance + if peerNodeName == m.shardManager.GetNodeName() { + return + } + // Require at least one shard to be local to this instance + isLocalTargetShard := m.shardManager.IsLocalShard(targetShard) + isLocalSourceShard := m.shardManager.IsLocalShard(sourceShard) + if !isLocalTargetShard && !isLocalSourceShard { + logger.Info("EnsureReceiverForPeerShard skipping because neither shard is local", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewBoolTag("isLocalTargetShard", isLocalTargetShard), tag.NewBoolTag("isLocalSourceShard", isLocalSourceShard)) + return + } + // Consolidated path: ensure stream and background loops + err := m.ensureStream(context.Background(), logger, peerNodeName, targetShard, sourceShard) + if err != nil { + logger.Error("failed to ensureStream", tag.Error(err)) + } +} + +// ensurePeer ensures a per-peer state with a shared gRPC connection exists. +func (m *intraProxyManager) ensurePeer( + ctx context.Context, + peerNodeName string, +) (*peerState, error) { + logger := log.With(m.logger, tag.NewStringTag("peerNodeName", peerNodeName)) + logger.Info("ensurePeer started") + defer logger.Info("ensurePeer finished") + + m.streamsMu.RLock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil && ps.conn != nil { + m.streamsMu.RUnlock() + logger.Info("ensurePeer found existing peer with connection") + return ps, nil + } + m.streamsMu.RUnlock() + + logger.Info("ensurePeer creating new peer connection") + + // Build TLS from this proxy's outbound client TLS config if available + tlsCfg := m.shardManager.GetIntraProxyTLSConfig() + var parsedTLSCfg *tls.Config + if tlsCfg.IsEnabled() { + logger.Info("ensurePeer TLS enabled, building TLS config") + var err error + parsedTLSCfg, err = encryption.GetClientTLSConfig(tlsCfg) + if err != nil { + logger.Error("ensurePeer failed to create TLS config", tag.Error(err)) + return nil, fmt.Errorf("config error when creating tls config: %w", err) + } + } else { + logger.Info("ensurePeer TLS disabled") + } + dialOpts := grpcutil.MakeDialOptions(parsedTLSCfg, metrics.GetGRPCClientMetrics("intra_proxy")) + + proxyAddresses, ok := m.shardManager.GetProxyAddress(peerNodeName) + if !ok { + logger.Error("ensurePeer proxy address not found") + return nil, fmt.Errorf("proxy address not found") + } + logger.Info("ensurePeer dialing peer", tag.NewStringTag("proxyAddresses", proxyAddresses)) + + cc, err := grpc.NewClient(proxyAddresses, dialOpts...) 
+ if err != nil { + logger.Error("ensurePeer failed to dial peer", tag.Error(err)) + return nil, err + } + logger.Info("ensurePeer successfully dialed peer") + + m.streamsMu.Lock() + ps := m.peers[peerNodeName] + if ps == nil { + logger.Info("ensurePeer creating new peer state") + ps = &peerState{conn: cc, receivers: make(map[peerStreamKey]*intraProxyStreamReceiver), senders: make(map[peerStreamKey]*intraProxyStreamSender), recvShutdown: make(map[peerStreamKey]channel.ShutdownOnce)} + m.peers[peerNodeName] = ps + } else { + logger.Info("ensurePeer updating existing peer state with new connection") + old := ps.conn + ps.conn = cc + if old != nil { + logger.Info("ensurePeer closing old connection") + _ = old.Close() + } + if ps.receivers == nil { + ps.receivers = make(map[peerStreamKey]*intraProxyStreamReceiver) + } + if ps.senders == nil { + ps.senders = make(map[peerStreamKey]*intraProxyStreamSender) + } + if ps.recvShutdown == nil { + ps.recvShutdown = make(map[peerStreamKey]channel.ShutdownOnce) + } + } + m.streamsMu.Unlock() + return ps, nil +} + +// ensureStream dials a peer proxy outbound server and opens a replication stream. +func (m *intraProxyManager) ensureStream( + ctx context.Context, + logger log.Logger, + peerNodeName string, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, +) error { + logger.Info("ensureStream") + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + + // Fast path: already exists + m.streamsMu.RLock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + if r, ok2 := ps.receivers[key]; ok2 && r != nil && r.streamClient != nil { + m.streamsMu.RUnlock() + logger.Info("ensureStream reused") + return nil + } + } + m.streamsMu.RUnlock() + + // Reuse shared connection per peer + ps, err := m.ensurePeer(ctx, peerNodeName) + if err != nil { + logger.Error("Failed to ensure peer", tag.Error(err)) + return err + } + + // Create receiver and register tracking + recv := &intraProxyStreamReceiver{ + logger: log.With(m.logger, + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("targetShardID", ClusterShardIDtoString(targetShard)), + tag.NewStringTag("sourceShardID", ClusterShardIDtoString(sourceShard))), + shardManager: m.shardManager, + intraMgr: m, + peerNodeName: peerNodeName, + targetShardID: targetShard, + sourceShardID: sourceShard, + } + // initialize shutdown handle and register it for lifecycle management + recv.shutdown = channel.NewShutdownOnce() + m.streamsMu.Lock() + ps.receivers[key] = recv + ps.recvShutdown[key] = recv.shutdown + m.streamsMu.Unlock() + m.logger.Info("intraProxyStreamReceiver added", tag.NewStringTag("peerNodeName", peerNodeName), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", recv.streamID)) + + // Let the receiver open stream, register tracking, and start goroutines + go func() { + if err := recv.Run(ctx, m.shardManager, ps.conn); err != nil { + m.logger.Error("intraProxyStreamReceiver.Run error", tag.Error(err)) + } + // remove the receiver from the peer state + m.streamsMu.Lock() + delete(ps.receivers, key) + delete(ps.recvShutdown, key) + m.streamsMu.Unlock() + }() + return nil +} + +// sendAck forwards an ACK to the specified peer stream (creates it on demand). 
+func (m *intraProxyManager) sendAck( + ctx context.Context, + peerNodeName string, + clientShard history.ClusterShardID, + serverShard history.ClusterShardID, + req *adminservice.StreamWorkflowReplicationMessagesRequest, +) error { + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} + m.streamsMu.RLock() + defer m.streamsMu.RUnlock() + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + if r, ok2 := ps.receivers[key]; ok2 && r != nil && r.streamClient != nil { + if err := r.sendAck(req); err != nil { + m.logger.Error("Failed to send intra-proxy ACK", tag.Error(err)) + return err + } + return nil + } + } + return fmt.Errorf("peer not found") +} + +// sendReplicationMessages sends replication messages to the peer via the server stream. +func (m *intraProxyManager) sendReplicationMessages( + ctx context.Context, + peerNodeName string, + targetShard history.ClusterShardID, + sourceShard history.ClusterShardID, + resp *adminservice.StreamWorkflowReplicationMessagesResponse, +) error { + key := peerStreamKey{targetShard: targetShard, sourceShard: sourceShard} + logger := log.With(m.logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard)), tag.NewStringTag("task-source-shard", ClusterShardIDtoString(sourceShard))) + logger.Info("sendReplicationMessages") + defer logger.Info("sendReplicationMessages finished") + + // Try server stream first with short retry/backoff to await registration + deadline := time.Now().Add(2 * time.Second) + backoff := 10 * time.Millisecond + for { + var sender *intraProxyStreamSender + m.streamsMu.RLock() + ps, ok := m.peers[peerNodeName] + if ok && ps != nil && ps.senders != nil { + logger.Info("sendReplicationMessages senders for node", tag.NewStringTag("node", peerNodeName), tag.NewStringTag("senders", fmt.Sprintf("%v", ps.senders))) + if s, ok2 := ps.senders[key]; ok2 && s != nil { + sender = s + } + } + m.streamsMu.RUnlock() + logger.Info("sendReplicationMessages sender", tag.NewStringTag("sender", fmt.Sprintf("%v", sender))) + + if sender != nil { + if err := sender.sendReplicationMessages(resp); err != nil { + logger.Error("Failed to send intra-proxy replication messages via server stream", tag.Error(err)) + return err + } + return nil + } + + if time.Now().After(deadline) { + break + } + time.Sleep(backoff) + if backoff < 200*time.Millisecond { + backoff *= 2 + } + } + + return fmt.Errorf("failed to send replication messages") +} + +// closePeerLocked shuts down and removes all resources for a peer. Caller must hold m.streamsMu. 
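+// Receiver streams are stopped via their shutdown handles, and both client- and server-side
+// tracker entries are unregistered before the shared gRPC connection is closed and the peer
+// entry is removed from the map.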
+func (m *intraProxyManager) closePeerLocked(peer string, ps *peerState) { + // Shutdown receivers and unregister client-side tracker entries + for key, shut := range ps.recvShutdown { + if shut != nil { + shut.Shutdown() + } + st := GetGlobalStreamTracker() + cliID := BuildIntraProxyReceiverStreamID(peer, key.targetShard, key.sourceShard) + st.UnregisterStream(cliID) + delete(ps.recvShutdown, key) + } + // Close client streams (receiver cleanup is handled by its own goroutine) + for key := range ps.receivers { + m.logger.Info("intraProxyStreamReceiver deleted", tag.NewStringTag("peerNodeName", peer), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", ps.receivers[key].streamID)) + delete(ps.receivers, key) + } + // Unregister server-side tracker entries + for key := range ps.senders { + st := GetGlobalStreamTracker() + srvID := BuildIntraProxySenderStreamID(peer, key.targetShard, key.sourceShard) + st.UnregisterStream(srvID) + delete(ps.senders, key) + } + if ps.conn != nil { + _ = ps.conn.Close() + ps.conn = nil + } + delete(m.peers, peer) +} + +// closePeerShardLocked shuts down and removes resources for a specific peer/shard pair. Caller must hold m.streamsMu. +func (m *intraProxyManager) closePeerShardLocked(peer string, ps *peerState, key peerStreamKey) { + m.logger.Info("closePeerShardLocked", tag.NewStringTag("peer", peer), tag.NewStringTag("clientShard", ClusterShardIDtoString(key.targetShard)), tag.NewStringTag("serverShard", ClusterShardIDtoString(key.sourceShard))) + if shut, ok := ps.recvShutdown[key]; ok && shut != nil { + shut.Shutdown() + st := GetGlobalStreamTracker() + cliID := BuildIntraProxyReceiverStreamID(peer, key.targetShard, key.sourceShard) + st.UnregisterStream(cliID) + delete(ps.recvShutdown, key) + } + if r, ok := ps.receivers[key]; ok { + // cancel stream context and attempt to close client send side + if r.cancel != nil { + r.cancel() + } + if r.streamClient != nil { + _ = r.streamClient.CloseSend() + } + m.logger.Info("intraProxyStreamReceiver deleted", tag.NewStringTag("peerNodeName", peer), tag.NewStringTag("key", fmt.Sprintf("%v", key)), tag.NewStringTag("receiver", r.streamID)) + delete(ps.receivers, key) + } + st := GetGlobalStreamTracker() + srvID := BuildIntraProxySenderStreamID(peer, key.targetShard, key.sourceShard) + st.UnregisterStream(srvID) + delete(ps.senders, key) +} + +// ClosePeer closes and removes all resources for a specific peer. +func (m *intraProxyManager) ClosePeer(peer string) { + m.streamsMu.Lock() + defer m.streamsMu.Unlock() + if ps, ok := m.peers[peer]; ok { + m.closePeerLocked(peer, ps) + } +} + +// ClosePeerShard closes resources for a specific peer/shard pair. 
+func (m *intraProxyManager) ClosePeerShard(peer string, clientShard, serverShard history.ClusterShardID) { + key := peerStreamKey{targetShard: clientShard, sourceShard: serverShard} + m.streamsMu.Lock() + defer m.streamsMu.Unlock() + if ps, ok := m.peers[peer]; ok { + m.closePeerShardLocked(peer, ps, key) + } +} + +func (m *intraProxyManager) Start() { + m.logger.Info("intraProxyManager starting") + defer m.logger.Info("intraProxyManager started") + go func() { + for { + // timer + timer := time.NewTimer(1 * time.Second) + select { + case <-timer.C: + m.ReconcilePeerStreams("") + case <-m.notifyCh: + m.ReconcilePeerStreams("") + } + } + }() +} + +func (m *intraProxyManager) Notify() { + select { + case m.notifyCh <- struct{}{}: + default: + } +} + +// ReconcilePeerStreams ensures receivers exist for desired (local shard, remote shard) pairs +// for a given peer and closes any sender/receiver not in the desired set. +// This mirrors the Temporal StreamReceiverMonitor approach. +func (m *intraProxyManager) ReconcilePeerStreams(peerNodeName string) { + m.logger.Info("ReconcilePeerStreams started", tag.NewStringTag("peerNodeName", peerNodeName)) + defer m.logger.Info("ReconcilePeerStreams done", tag.NewStringTag("peerNodeName", peerNodeName)) + + localShards := m.shardManager.GetLocalShards() + remoteShards, err := m.shardManager.GetRemoteShardsForPeer(peerNodeName) + if err != nil { + m.logger.Error("Failed to get remote shards for peer", tag.Error(err)) + return + } + m.logger.Info("ReconcilePeerStreams remote and local shards", + tag.NewStringTag("peerNodeName", peerNodeName), + tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards)), + tag.NewStringTag("localShards", fmt.Sprintf("%v", localShards)), + ) + + // Build desiredReceivers receiver set of cross-cluster pairs + desiredReceivers := make(map[peerStreamKey]string) + for _, l := range localShards { + for peer, shards := range remoteShards { + for _, r := range shards.Shards { + if l.ClusterID == r.ID.ClusterID { + continue + } + desiredReceivers[peerStreamKey{targetShard: l, sourceShard: r.ID}] = peer + } + } + } + + // Build desiredSenders set: inverted direction of desiredReceivers + // Senders exist when remote shard is the target and local shard is the source + desiredSenders := make(map[peerStreamKey]string) + for _, l := range localShards { + for peer, shards := range remoteShards { + for _, r := range shards.Shards { + if l.ClusterID == r.ID.ClusterID { + continue + } + desiredSenders[peerStreamKey{targetShard: r.ID, sourceShard: l}] = peer + } + } + } + + m.logger.Info("ReconcilePeerStreams desired receivers and senders", tag.NewStringTag("desiredReceivers", fmt.Sprintf("%v", desiredReceivers)), tag.NewStringTag("desiredSenders", fmt.Sprintf("%v", desiredSenders))) + + // Ensure all desired receivers exist + for key := range desiredReceivers { + m.EnsureReceiverForPeerShard(desiredReceivers[key], key.targetShard, key.sourceShard) + } + + // Prune anything not desired + check := func(peer string, ps *peerState) { + // Collect keys to close for receivers + var receiversToClose []peerStreamKey + for key := range ps.receivers { + if _, ok2 := desiredReceivers[key]; !ok2 { + receiversToClose = append(receiversToClose, key) + } + } + for _, key := range receiversToClose { + m.closePeerShardLocked(peer, ps, key) + } + // Collect keys to close for senders + var sendersToClose []peerStreamKey + for key := range ps.senders { + if _, ok2 := desiredSenders[key]; !ok2 { + sendersToClose = append(sendersToClose, key) + } + } + 
for _, key := range sendersToClose { + m.closePeerShardLocked(peer, ps, key) + } + } + + m.streamsMu.Lock() + if peerNodeName != "" { + if ps, ok := m.peers[peerNodeName]; ok && ps != nil { + check(peerNodeName, ps) + } + } else { + for peer, ps := range m.peers { + check(peer, ps) + } + } + m.streamsMu.Unlock() +} diff --git a/proxy/proxy.go b/proxy/proxy.go index bba7ee9c..8541ef4e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -7,6 +7,8 @@ import ( "net/http" "strings" + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" @@ -20,6 +22,19 @@ type ( // Needs some config revision before uncommenting: //accountId string } + + // RoutedAck wraps an ACK with the target shard it originated from + RoutedAck struct { + TargetShard history.ClusterShardID + Req *adminservice.StreamWorkflowReplicationMessagesRequest + } + + // RoutedMessage wraps a replication response with originating client shard info + RoutedMessage struct { + SourceShard history.ClusterShardID + Resp *adminservice.StreamWorkflowReplicationMessagesResponse + } + Proxy struct { lifetime context.Context cancel context.CancelFunc @@ -54,13 +69,15 @@ func NewProxy(configProvider config.ConfigProvider, logger log.Logger) *Proxy { if s2sConfig.Metrics != nil { proxy.metricsConfig = s2sConfig.Metrics } + for _, clusterCfg := range s2sConfig.ClusterConnections { cc, err := NewClusterConnection(ctx, clusterCfg, logger) if err != nil { logger.Fatal("Incorrectly configured Mux cluster connection", tag.Error(err), tag.NewStringTag("name", clusterCfg.Name)) continue } - proxy.clusterConnections[migrationId{clusterCfg.Name}] = cc + migrationId := migrationId{clusterCfg.Name} + proxy.clusterConnections[migrationId] = cc } // TODO: correctly host multiple health checks if len(s2sConfig.ClusterConnections) > 0 && s2sConfig.ClusterConnections[0].InboundHealthCheck.ListenAddress != "" { diff --git a/proxy/proxy_streams.go b/proxy/proxy_streams.go new file mode 100644 index 00000000..8ae63597 --- /dev/null +++ b/proxy/proxy_streams.go @@ -0,0 +1,1061 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "go.temporal.io/server/api/adminservice/v1" + replicationv1 "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/client/history" + servercommon "go.temporal.io/server/common" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" +) + +// proxyIDMapping stores the original source shard and task for a given proxy task ID +// Entries are kept in strictly increasing proxyID order. +type proxyIDMapping struct { + sourceShard history.ClusterShardID + sourceTask int64 +} + +// proxyIDRingBuffer is a dynamically growing ring buffer keyed by monotonically increasing proxy IDs. +// It supports O(1) append and O(k) pop up to a given watermark, while preserving insertion order. +type proxyIDRingBuffer struct { + entries []proxyIDMapping + head int + size int + maxSize int // Maximum size ever reached + startProxyID int64 // proxyID of the current head element when size > 0 +} + +func newProxyIDRingBuffer(capacity int) *proxyIDRingBuffer { + if capacity < 1 { + capacity = 1 + } + return &proxyIDRingBuffer{entries: make([]proxyIDMapping, capacity)} +} + +// ensureCapacity grows the buffer if it is full, preserving order. 
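+// For example, a full buffer of capacity 4 holding proxy IDs 7..10 with head == 2 is copied,
+// in proxy-ID order, into a new slice of capacity 8 at indices 0..3, and head is reset to 0.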
+func (b *proxyIDRingBuffer) ensureCapacity() { + if b.size < len(b.entries) { + return + } + newCap := len(b.entries) * 2 + if newCap == 0 { + newCap = 1 + } + newEntries := make([]proxyIDMapping, newCap) + // copy existing elements in order starting from head + for i := 0; i < b.size; i++ { + idx := (b.head + i) % len(b.entries) + newEntries[i] = b.entries[idx] + } + b.entries = newEntries + b.head = 0 +} + +// Append appends a mapping for the given proxyID. ProxyIDs must be strictly increasing and contiguous. +func (b *proxyIDRingBuffer) Append(proxyID int64, sourceShard history.ClusterShardID, sourceTask int64) { + b.ensureCapacity() + if b.size == 0 { + b.startProxyID = proxyID + } else { + // Maintain contiguity: next proxyID must be startProxyID + size + expected := b.startProxyID + int64(b.size) + if proxyID != expected { + // If contiguity is violated, grow holes by inserting empty mappings until aligned. + // In practice proxyID is always increasing by 1, so this branch should not trigger. + for expected < proxyID { + b.ensureCapacity() + pos := (b.head + b.size) % len(b.entries) + b.entries[pos] = proxyIDMapping{sourceShard: history.ClusterShardID{}, sourceTask: 0} + b.size++ + if b.size > b.maxSize { + b.maxSize = b.size + } + expected++ + } + } + } + pos := (b.head + b.size) % len(b.entries) + b.entries[pos] = proxyIDMapping{sourceShard: sourceShard, sourceTask: sourceTask} + b.size++ + if b.size > b.maxSize { + b.maxSize = b.size + } +} + +// AggregateUpTo computes the per-shard aggregation up to watermark without removing entries. +// Returns (aggregation, count) where count is the number of entries covered. +func (b *proxyIDRingBuffer) AggregateUpTo(watermark int64) (map[history.ClusterShardID]int64, int) { + result := make(map[history.ClusterShardID]int64) + if b.size == 0 { + return result, 0 + } + if watermark < b.startProxyID { + return result, 0 + } + count64 := watermark - b.startProxyID + 1 + if count64 <= 0 { + return result, 0 + } + count := int(count64) + if count > b.size { + count = b.size + } + for i := 0; i < count; i++ { + idx := (b.head + i) % len(b.entries) + m := b.entries[idx] + if m.sourceShard.ClusterID == 0 && m.sourceShard.ShardID == 0 { + continue + } + if current, ok := result[m.sourceShard]; !ok || m.sourceTask > current { + result[m.sourceShard] = m.sourceTask + } + } + return result, count +} + +// Discard advances the head by count entries, effectively removing them. +func (b *proxyIDRingBuffer) Discard(count int) { + if count <= 0 { + return + } + if count > b.size { + count = b.size + } + b.head = (b.head + count) % len(b.entries) + b.size -= count + b.startProxyID += int64(count) +} + +// proxyStreamSender is responsible for sending replication messages to the next hop +// (another proxy or a target server) and receiving ACKs back. +// This is scaffolding only – the concrete behavior will be wired in later. +type proxyStreamSender struct { + logger log.Logger + shardManager ShardManager + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + streamID string + streamTracker *StreamTracker + // sendMsgChan carries replication messages to be sent to the remote side. 
+ sendMsgChan chan RoutedMessage + + mu sync.Mutex + nextProxyTaskID int64 + idRing *proxyIDRingBuffer + // prevAckBySource tracks the last ack level sent per original source shard + prevAckBySource map[history.ClusterShardID]int64 + // keepalive state + lastMsgSendTime time.Time + lastSentWatermark int64 +} + +// buildSenderDebugSnapshot returns a snapshot of the sender's ring buffer and related state +func (s *proxyStreamSender) buildSenderDebugSnapshot(maxEntries int) *SenderDebugInfo { + s.mu.Lock() + defer s.mu.Unlock() + + info := &SenderDebugInfo{ + PrevAckBySource: make(map[string]int64), + } + + info.NextProxyTaskID = s.nextProxyTaskID + + for k, v := range s.prevAckBySource { + info.PrevAckBySource[ClusterShardIDtoString(k)] = v + } + + if s.idRing != nil { + info.RingStartProxyID = s.idRing.startProxyID + info.RingSize = s.idRing.size + info.RingMaxSize = s.idRing.maxSize + info.RingCapacity = len(s.idRing.entries) + info.RingHead = s.idRing.head + + // Build entries preview + if maxEntries <= 0 { + maxEntries = 20 + } + limit := s.idRing.size + if limit > maxEntries { + limit = maxEntries + } + info.EntriesPreview = make([]ProxyIDEntry, 0, limit) + for i := 0; i < limit; i++ { + idx := (s.idRing.head + i) % len(s.idRing.entries) + e := s.idRing.entries[idx] + info.EntriesPreview = append(info.EntriesPreview, ProxyIDEntry{ + ProxyID: s.idRing.startProxyID + int64(i), + SourceShard: ClusterShardIDtoString(e.sourceShard), + SourceTask: e.sourceTask, + }) + } + } + + return info +} + +func (s *proxyStreamSender) Run( + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) { + s.streamID = BuildSenderStreamID(s.sourceShardID, s.targetShardID) + s.logger = log.With(s.logger, tag.NewStringTag("streamID", s.streamID), tag.NewStringTag("role", "sender")) + + s.logger.Info("proxyStreamSender Run") + defer s.logger.Info("proxyStreamSender Run finished") + + s.streamTracker = GetGlobalStreamTracker() + s.streamTracker.RegisterStream( + s.streamID, + "StreamWorkflowReplicationMessages", + s.directionLabel, + ClusterShardIDtoString(s.sourceShardID), + ClusterShardIDtoString(s.targetShardID), + StreamRoleSender, + ) + defer s.streamTracker.UnregisterStream(s.streamID) + + // lazy init maps + s.mu.Lock() + if s.idRing == nil { + s.idRing = newProxyIDRingBuffer(1024) + } + if s.prevAckBySource == nil { + s.prevAckBySource = make(map[history.ClusterShardID]int64) + } + s.mu.Unlock() + + // Register remote send channel for this shard so receiver can forward tasks locally + s.sendMsgChan = make(chan RoutedMessage, 100) + + s.shardManager.SetRemoteSendChan(s.targetShardID, s.sendMsgChan) + defer s.shardManager.RemoveRemoteSendChan(s.targetShardID, s.sendMsgChan) + + registeredAt := s.shardManager.RegisterShard(s.targetShardID) + defer s.shardManager.UnregisterShard(s.targetShardID, registeredAt) + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + err := s.sendReplicationMessages(sourceStreamServer, shutdownChan) + if err != nil { + s.logger.Error("proxyStreamSender sendReplicationMessages error", tag.Error(err)) + } + }() + go func() { + defer wg.Done() + err := s.recvAck(sourceStreamServer, shutdownChan) + if err != nil { + s.logger.Error("proxyStreamSender recvAck error", tag.Error(err)) + } + }() + // Wait for shutdown signal (triggered by receiver or stream errors) + <-shutdownChan.Channel() + // Ensure send loop exits promptly + close(s.sendMsgChan) + // Do not block waiting for ack goroutine; it will 
terminate when stream ends +} + +// recvAck receives ACKs from the remote side and forwards them to the provided +// channel for aggregation/routing. Non-blocking shutdown is coordinated via +// shutdownChan. This is a placeholder implementation. +func (s *proxyStreamSender) recvAck( + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + s.logger.Info("proxyStreamSender recvAck started") + defer func() { + s.logger.Info("proxyStreamSender recvAck finished") + shutdownChan.Shutdown() + }() + for !shutdownChan.IsShutdown() { + req, err := sourceStreamServer.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + + // Unmap proxy task IDs back to original source shard/task and ACK by source shard + if attr, ok := req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + proxyAckWatermark := attr.SyncReplicationState.InclusiveLowWatermark + + // track sync watermark + s.streamTracker.UpdateStreamSyncReplicationState(s.streamID, proxyAckWatermark, nil) + s.streamTracker.UpdateStream(s.streamID) + + s.mu.Lock() + shardToAck, pendingDiscard := s.idRing.AggregateUpTo(proxyAckWatermark) + s.mu.Unlock() + + s.logger.Info("Sender received upstream ACK", tag.NewInt64("inclusive_low", proxyAckWatermark), tag.NewStringTag("shardToAck", fmt.Sprintf("%v", shardToAck)), tag.NewInt("pendingDiscard", pendingDiscard)) + + if len(shardToAck) > 0 { + sent := make(map[history.ClusterShardID]bool, len(shardToAck)) + logged := make(map[history.ClusterShardID]bool, len(shardToAck)) + numRemaining := len(shardToAck) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: + } + progress := false + for srcShard, originalAck := range shardToAck { + if sent[srcShard] { + continue + } + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: originalAck, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, + }, + }, + } + + s.logger.Info("Sender forwarding ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", originalAck)) + + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, originalAck, true) { + sent[srcShard] = true + numRemaining-- + progress = true + // record last ack per source shard after forwarding + s.mu.Lock() + s.prevAckBySource[srcShard] = originalAck + s.mu.Unlock() + } else if !logged[srcShard] { + s.logger.Warn("No local ack channel for source shard; retrying until available", tag.NewStringTag("shard", ClusterShardIDtoString(srcShard))) + logged[srcShard] = true + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond + } + } + + // TODO: ack to idle shards using prevAckBySource + + } else { + // No new shards to ACK: send previous ack levels per source shard (if known) + s.mu.Lock() + pendingPrev := make(map[history.ClusterShardID]int64, len(s.prevAckBySource)) + for srcShard, prev := range s.prevAckBySource { + pendingPrev[srcShard] = prev + } + s.mu.Unlock() + + sent 
:= make(map[history.ClusterShardID]bool, len(pendingPrev)) + logged := make(map[history.ClusterShardID]bool, len(pendingPrev)) + numRemaining := len(pendingPrev) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: + } + progress := false + for srcShard, prev := range pendingPrev { + if sent[srcShard] { + continue + } + routedAck := &RoutedAck{ + TargetShard: s.targetShardID, + Req: &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: prev, + InclusiveLowWatermarkTime: attr.SyncReplicationState.InclusiveLowWatermarkTime, + }, + }, + }, + } + // Log fallback ACK for this source shard + s.logger.Info("Sender forwarding fallback ACK to source shard", tag.NewStringTag("sourceShard", ClusterShardIDtoString(srcShard)), tag.NewInt64("ack", prev)) + if s.shardManager.DeliverAckToShardOwner(srcShard, routedAck, shutdownChan, s.logger, prev, true) { + sent[srcShard] = true + numRemaining-- + progress = true + } else if !logged[srcShard] { + s.logger.Warn("No local ack channel for source shard; retrying until available", tag.NewStringTag("shard", ClusterShardIDtoString(srcShard))) + logged[srcShard] = true + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond + } + } + } + + // Only after forwarding ACKs, discard the entries from the ring buffer + if pendingDiscard > 0 { + s.mu.Lock() + s.idRing.Discard(pendingDiscard) + s.mu.Unlock() + } + + // Update debug snapshot after ack processing + s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) + } + } + return nil +} + +// sendReplicationMessages sends replication messages read from sendMsgChan to +// the remote side. This is a placeholder implementation. 
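+// Each batch read from sendMsgChan has its task IDs rewritten into this stream's monotonically
+// increasing proxy ID space; the proxyID -> (source shard, original task ID) mapping is
+// recorded in the ring buffer so recvAck can translate upstream ACKs back, and the exclusive
+// high watermark is rewritten to match. When nothing has been sent for at least one second, an
+// empty keepalive carrying the last sent watermark is emitted.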
+func (s *proxyStreamSender) sendReplicationMessages( + sourceStreamServer adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + shutdownChan channel.ShutdownOnce, +) error { + s.logger.Info("proxyStreamSender sendReplicationMessages started") + defer func() { + s.logger.Info("proxyStreamSender sendReplicationMessages finished") + shutdownChan.Shutdown() + }() + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for !shutdownChan.IsShutdown() { + if s.sendMsgChan == nil { + return nil + } + select { + case routed, ok := <-s.sendMsgChan: + if !ok { + return nil + } + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: routed.Resp=%p", routed.Resp), tag.NewStringTag("routed", fmt.Sprintf("%v", routed))) + resp := routed.Resp + m, ok := resp.Attributes.(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages) + if !ok || m.Messages == nil { + return nil + } + + sourceTaskIds := make([]int64, 0, len(m.Messages.ReplicationTasks)) + for _, t := range m.Messages.ReplicationTasks { + sourceTaskIds = append(sourceTaskIds, t.SourceTaskId) + } + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d ids=%v", m.Messages.ExclusiveHighWatermark, sourceTaskIds)) + + // rewrite task ids + s.mu.Lock() + var originalIDs []int64 + var proxyIDs []int64 + // capture original exclusive high watermark before rewriting + originalHigh := m.Messages.ExclusiveHighWatermark + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d", m.Messages.ExclusiveHighWatermark, originalHigh)) + // Ensure exclusive high watermark is in proxy task ID space + var proxyExclusiveHigh int64 + if len(m.Messages.ReplicationTasks) > 0 { + for _, t := range m.Messages.ReplicationTasks { + // allocate proxy task id + s.nextProxyTaskID++ + proxyID := s.nextProxyTaskID + // remember original + original := t.SourceTaskId + s.idRing.Append(proxyID, routed.SourceShard, original) + // rewrite id + t.SourceTaskId = proxyID + if t.RawTaskInfo != nil { + t.RawTaskInfo.TaskId = proxyID + } + originalIDs = append(originalIDs, original) + proxyIDs = append(proxyIDs, proxyID) + } + proxyExclusiveHigh = m.Messages.ReplicationTasks[len(m.Messages.ReplicationTasks)-1].SourceTaskId + 1 + m.Messages.ExclusiveHighWatermark = proxyExclusiveHigh + } else { + // No tasks in this batch: allocate a synthetic proxy task id mapping + s.nextProxyTaskID++ + proxyHigh := s.nextProxyTaskID + s.idRing.Append(proxyHigh, routed.SourceShard, originalHigh) + originalIDs = append(originalIDs, originalHigh) + proxyIDs = append(proxyIDs, proxyHigh) + proxyExclusiveHigh = proxyHigh + m.Messages.ExclusiveHighWatermark = proxyExclusiveHigh + s.logger.Info(fmt.Sprintf("Sender received ReplicationTasks: exclusive_high=%d original_high=%d proxy_high=%d original", proxyExclusiveHigh, originalHigh, proxyHigh)) + } + s.mu.Unlock() + // Log mapping from original -> proxy IDs (use captured value to avoid data race) + s.logger.Info(fmt.Sprintf("Sender sending ReplicationTasks from shard %s: original=%v proxy=%v", ClusterShardIDtoString(routed.SourceShard), originalIDs, proxyIDs), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) + + if err := sourceStreamServer.Send(resp); err != nil { + return err + } + s.logger.Info("Sender sent ReplicationTasks", tag.NewStringTag("sourceShard", ClusterShardIDtoString(routed.SourceShard)), tag.NewInt64("exclusive_high", proxyExclusiveHigh)) + + // Update keepalive state + s.mu.Lock() + s.lastMsgSendTime = time.Now() + 
s.lastSentWatermark = m.Messages.ExclusiveHighWatermark + s.mu.Unlock() + + s.streamTracker.UpdateStreamLastTaskIDs(s.streamID, sourceTaskIds) + s.streamTracker.UpdateStreamReplicationMessages(s.streamID, m.Messages.ExclusiveHighWatermark) + s.streamTracker.UpdateStreamSenderDebug(s.streamID, s.buildSenderDebugSnapshot(20)) + s.streamTracker.UpdateStream(s.streamID) + case <-ticker.C: + // Send keepalive if idle for 1 second + s.mu.Lock() + shouldSendKeepalive := s.lastSentWatermark > 0 && time.Since(s.lastMsgSendTime) >= 1*time.Second + watermark := s.lastSentWatermark + s.mu.Unlock() + + if shouldSendKeepalive { + keepaliveResp := &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: []*replicationv1.ReplicationTask{}, + ExclusiveHighWatermark: watermark, + }, + }, + } + s.logger.Info("Sender sending keepalive message", tag.NewInt64("watermark", watermark)) + if err := sourceStreamServer.Send(keepaliveResp); err != nil { + return err + } + s.mu.Lock() + s.lastMsgSendTime = time.Now() + s.mu.Unlock() + } + case <-shutdownChan.Channel(): + return nil + } + } + return nil +} + +// proxyStreamReceiver receives replication messages from a local/remote server and +// produces ACKs destined for the original sender. +type proxyStreamReceiver struct { + logger log.Logger + shardManager ShardManager + adminClient adminservice.AdminServiceClient + localShardCount int32 + targetShardID history.ClusterShardID + sourceShardID history.ClusterShardID + directionLabel string + ackChan chan RoutedAck + // ack aggregation across target shards + ackByTarget map[history.ClusterShardID]int64 + lastSentMin int64 + // lastExclusiveHighOriginal tracks last exclusive high watermark seen from source (original id space) + lastExclusiveHighOriginal int64 + streamID string + streamTracker *StreamTracker + // keepalive state + ackMu sync.RWMutex + lastAckSendTime time.Time + lastSentAck *adminservice.StreamWorkflowReplicationMessagesRequest + // lastWatermark tracks the last watermark received from source shard for late-registering target shards + lastWatermarkMu sync.RWMutex + lastWatermark *replicationv1.WorkflowReplicationMessages +} + +// buildReceiverDebugSnapshot builds receiver ACK aggregation state for debugging +func (r *proxyStreamReceiver) buildReceiverDebugSnapshot() *ReceiverDebugInfo { + r.ackMu.RLock() + defer r.ackMu.RUnlock() + info := &ReceiverDebugInfo{ + AckByTarget: make(map[string]int64), + } + for k, v := range r.ackByTarget { + info.AckByTarget[ClusterShardIDtoString(k)] = v + } + info.LastAggregatedMin = r.lastSentMin + info.LastExclusiveHighOriginal = r.lastExclusiveHighOriginal + return info +} + +func (r *proxyStreamReceiver) Run( + shutdownChan channel.ShutdownOnce, +) { + // Terminate any previous local receiver for this shard + if r.shardManager != nil { + r.shardManager.TerminatePreviousLocalReceiver(r.sourceShardID, r.logger) + } + + r.streamID = BuildReceiverStreamID(r.sourceShardID, r.targetShardID) + r.logger = log.With(r.logger, + tag.NewStringTag("streamID", r.streamID), + tag.NewStringTag("source", ClusterShardIDtoString(r.sourceShardID)), + tag.NewStringTag("target", ClusterShardIDtoString(r.targetShardID)), + tag.NewStringTag("role", "receiver"), + ) + r.logger.Info("proxyStreamReceiver Run") + defer r.logger.Info("proxyStreamReceiver Run finished") + + // Build metadata for local server stream + md := 
metadata.New(map[string]string{}) + md.Set(history.MetadataKeyClientClusterID, fmt.Sprintf("%d", r.targetShardID.ClusterID)) + md.Set(history.MetadataKeyClientShardID, fmt.Sprintf("%d", r.targetShardID.ShardID)) + md.Set(history.MetadataKeyServerClusterID, fmt.Sprintf("%d", r.sourceShardID.ClusterID)) + md.Set(history.MetadataKeyServerShardID, fmt.Sprintf("%d", r.sourceShardID.ShardID)) + + outgoingContext := metadata.NewOutgoingContext(context.Background(), md) + outgoingContext, cancel := context.WithCancel(outgoingContext) + defer cancel() + + r.logger.Info("proxyStreamReceiver outgoingContext created") + + // Open stream receiver -> local server's stream sender for clientShardID + var sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient + var err error + sourceStreamClient, err = r.adminClient.StreamWorkflowReplicationMessages(outgoingContext) + if err != nil { + r.logger.Error("adminClient.StreamWorkflowReplicationMessages error", tag.Error(err)) + return + } + + r.logger.Info("proxyStreamReceiver sourceStreamClient created") + + // Setup ack channel and cancel func bookkeeping + r.ackChan = make(chan RoutedAck, 100) + if r.shardManager != nil { + r.shardManager.SetLocalAckChan(r.sourceShardID, r.ackChan) + r.shardManager.SetLocalReceiverCancelFunc(r.sourceShardID, cancel) + // Register receiver for watermark propagation to late-registering shards + r.shardManager.RegisterActiveReceiver(r.sourceShardID, r) + defer func() { + r.shardManager.RemoveLocalAckChan(r.sourceShardID, r.ackChan) + r.shardManager.RemoveLocalReceiverCancelFunc(r.sourceShardID) + r.shardManager.UnregisterActiveReceiver(r.sourceShardID) + }() + } + + // init aggregation state + r.ackByTarget = make(map[history.ClusterShardID]int64) + r.lastSentMin = 0 + + // Register a new local stream for tracking (short id, include role) + r.streamTracker = GetGlobalStreamTracker() + r.streamTracker.RegisterStream( + r.streamID, + "StreamWorkflowReplicationMessages", + r.directionLabel, + ClusterShardIDtoString(r.sourceShardID), + ClusterShardIDtoString(r.targetShardID), + StreamRoleReceiver, + ) + defer r.streamTracker.UnregisterStream(r.streamID) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer func() { + shutdownChan.Shutdown() + wg.Done() + }() + _ = r.recvReplicationMessages(sourceStreamClient, shutdownChan) + }() + + go func() { + defer func() { + shutdownChan.Shutdown() + _ = sourceStreamClient.CloseSend() + wg.Done() + }() + _ = r.sendAck(sourceStreamClient, shutdownChan) + }() + + wg.Wait() +} + +// recvReplicationMessages receives from local server and routes to target shard owners. 
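+// Tasks are regrouped by recomputing the owning history shard from each task's namespace and
+// workflow IDs via WorkflowIDToHistoryShard and the receiver's configured shard count; empty
+// batches still broadcast the exclusive high watermark to every registered send channel for
+// the target cluster.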
+func (r *proxyStreamReceiver) recvReplicationMessages( + sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, + shutdownChan channel.ShutdownOnce, +) error { + r.logger.Info("proxyStreamReceiver recvReplicationMessages started") + defer r.logger.Info("proxyStreamReceiver recvReplicationMessages finished") + + for !shutdownChan.IsShutdown() { + resp, err := sourceStreamClient.Recv() + if err == io.EOF { + r.logger.Info("sourceStreamClient.Recv encountered EOF", tag.Error(err)) + return nil + } + if err != nil { + r.logger.Error("sourceStreamClient.Recv encountered error", tag.Error(err)) + return err + } + + if attr, ok := resp.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesResponse_Messages); ok && attr.Messages != nil { + // Group by recalculated target shard using namespace/workflow hash + tasksByTargetShard := make(map[history.ClusterShardID][]*replicationv1.ReplicationTask) + ids := make([]int64, 0, len(attr.Messages.ReplicationTasks)) + for _, task := range attr.Messages.ReplicationTasks { + if task.RawTaskInfo != nil && task.RawTaskInfo.NamespaceId != "" && task.RawTaskInfo.WorkflowId != "" { + targetShard := servercommon.WorkflowIDToHistoryShard(task.RawTaskInfo.NamespaceId, task.RawTaskInfo.WorkflowId, r.localShardCount) + targetClusterShard := history.ClusterShardID{ClusterID: r.targetShardID.ClusterID, ShardID: targetShard} + tasksByTargetShard[targetClusterShard] = append(tasksByTargetShard[targetClusterShard], task) + ids = append(ids, task.SourceTaskId) + } + } + + // Log every replication task id received at receiver + r.logger.Info(fmt.Sprintf("Receiver received ReplicationTasks: exclusive_high=%d ids=%v", attr.Messages.ExclusiveHighWatermark, ids)) + + // record last source exclusive high watermark (original id space) + r.ackMu.Lock() + r.lastExclusiveHighOriginal = attr.Messages.ExclusiveHighWatermark + r.ackMu.Unlock() + + // update tracker for incoming messages + if r.streamTracker != nil && r.streamID != "" { + r.streamTracker.UpdateStreamLastTaskIDs(r.streamID, ids) + r.streamTracker.UpdateStreamReplicationMessages(r.streamID, attr.Messages.ExclusiveHighWatermark) + r.streamTracker.UpdateStreamReceiverDebug(r.streamID, r.buildReceiverDebugSnapshot()) + r.streamTracker.UpdateStream(r.streamID) + } + + // If replication tasks are empty, still log the empty batch and send watermark + if len(attr.Messages.ReplicationTasks) == 0 { + r.logger.Info("Receiver received empty replication batch", tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + + // Track last watermark for late-registering shards + r.lastWatermarkMu.Lock() + r.lastWatermark = &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, + } + r.lastWatermarkMu.Unlock() + + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: attr.Messages.ExclusiveHighWatermark, + Priority: attr.Messages.Priority, + }, + }, + }, + } + localShardsToSend := r.shardManager.GetRemoteSendChansByCluster(r.targetShardID.ClusterID) + r.logger.Info("Going to broadcast high watermark to local shards", tag.NewStringTag("localShardsToSend", fmt.Sprintf("%v", localShardsToSend))) + for targetShardID, sendChan := range localShardsToSend { + // Clone the message for 
each recipient to prevent shared mutation + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + r.logger.Info(fmt.Sprintf("Sending high watermark to target shard, msg.Resp=%p", clonedMsg.Resp), tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark), tag.NewStringTag("msg", fmt.Sprintf("%v", clonedMsg))) + // Use non-blocking send with recover to handle closed channels + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + // Channel was closed while we were trying to send + r.logger.Warn("Failed to send high watermark to target shard (channel closed)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + } + }() + select { + case sendChan <- clonedMsg: + // Message sent successfully + default: + // Channel is full or closed, log and skip + r.logger.Warn("Failed to send high watermark to target shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", attr.Messages.ExclusiveHighWatermark)) + } + }() + } + // send to all remote shards on other nodes as well + remoteShards, err := r.shardManager.GetRemoteShardsForPeer("") + if err != nil { + r.logger.Error("Failed to get remote shards", tag.Error(err)) + return err + } + r.logger.Info("Going to broadcast high watermark to remote shards", tag.NewStringTag("remoteShards", fmt.Sprintf("%v", remoteShards))) + for _, shards := range remoteShards { + for _, shard := range shards.Shards { + if shard.ID.ClusterID != r.targetShardID.ClusterID { + continue + } + // Clone the message for each remote recipient + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + if !r.shardManager.DeliverMessagesToShardOwner(shard.ID, &clonedMsg, shutdownChan, r.logger) { + r.logger.Warn("Failed to send ReplicationTasks to remote shard", tag.NewStringTag("shard", ClusterShardIDtoString(shard.ID))) + } + } + } + continue + } + + // Retry across the whole target set until all sends succeed (or shutdown) + sentByTarget := make(map[history.ClusterShardID]bool, len(tasksByTargetShard)) + loggedByTarget := make(map[history.ClusterShardID]bool, len(tasksByTargetShard)) + for targetShardID := range tasksByTargetShard { + sentByTarget[targetShardID] = false + } + r.logger.Info("Going to broadcast ReplicationTasks to target shards", tag.NewStringTag("tasksByTargetShard", fmt.Sprintf("%v", tasksByTargetShard))) + numRemaining := len(tasksByTargetShard) + backoff := 10 * time.Millisecond + for numRemaining > 0 { + select { + case <-shutdownChan.Channel(): + return nil + default: + } + progress := false + for targetShardID, tasks := range tasksByTargetShard { + if sentByTarget[targetShardID] { + continue + } + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ReplicationTasks: tasks, + ExclusiveHighWatermark: tasks[len(tasks)-1].RawTaskInfo.TaskId + 1, + Priority: attr.Messages.Priority, + }, + }, + }, + } + if 
r.shardManager.DeliverMessagesToShardOwner(targetShardID, &msg, shutdownChan, r.logger) { + sentByTarget[targetShardID] = true + numRemaining-- + progress = true + } else { + if !loggedByTarget[targetShardID] { + r.logger.Warn("No send channel found for target shard; retrying until available", tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShardID))) + loggedByTarget[targetShardID] = true + } + } + } + if !progress { + time.Sleep(backoff) + if backoff < time.Second { + backoff *= 2 + } + } else if backoff > 10*time.Millisecond { + backoff = 10 * time.Millisecond + } + } + } + } + return nil +} + +// GetTargetShardID returns the target shard ID for this receiver +func (r *proxyStreamReceiver) GetTargetShardID() history.ClusterShardID { + return r.targetShardID +} + +// GetSourceShardID returns the source shard ID for this receiver +func (r *proxyStreamReceiver) GetSourceShardID() history.ClusterShardID { + return r.sourceShardID +} + +// GetLastWatermark returns the last watermark received from the source shard +func (r *proxyStreamReceiver) GetLastWatermark() *replicationv1.WorkflowReplicationMessages { + r.lastWatermarkMu.RLock() + defer r.lastWatermarkMu.RUnlock() + return r.lastWatermark +} + +// NotifyNewTargetShard notifies the receiver about a newly registered target shard +func (r *proxyStreamReceiver) NotifyNewTargetShard(targetShardID history.ClusterShardID) { + r.sendPendingWatermarkToShard(targetShardID) +} + +// sendPendingWatermarkToShard sends the last known watermark to a newly registered target shard +// This ensures late-registering shards receive watermarks that were sent before they registered +func (r *proxyStreamReceiver) sendPendingWatermarkToShard(targetShardID history.ClusterShardID) { + r.lastWatermarkMu.RLock() + lastWatermark := r.lastWatermark + r.lastWatermarkMu.RUnlock() + + if lastWatermark == nil || lastWatermark.ExclusiveHighWatermark == 0 { + // No pending watermark to send + return + } + + r.logger.Info("Sending pending watermark to newly registered shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID)), + tag.NewInt64("exclusive_high", lastWatermark.ExclusiveHighWatermark)) + + msg := RoutedMessage{ + SourceShard: r.sourceShardID, + Resp: &adminservice.StreamWorkflowReplicationMessagesResponse{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ + Messages: &replicationv1.WorkflowReplicationMessages{ + ExclusiveHighWatermark: lastWatermark.ExclusiveHighWatermark, + Priority: lastWatermark.Priority, + }, + }, + }, + } + + // Try to send to local shard first + if sendChan, exists := r.shardManager.GetRemoteSendChan(targetShardID); exists { + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + select { + case sendChan <- clonedMsg: + r.logger.Info("Sent pending watermark to local shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + default: + r.logger.Warn("Failed to send pending watermark to local shard (channel full)", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + return + } + + // If not local, try to send to remote shard + if r.shardManager != nil { + shutdownChan := channel.NewShutdownOnce() + clonedResp := proto.Clone(msg.Resp).(*adminservice.StreamWorkflowReplicationMessagesResponse) + clonedMsg := RoutedMessage{ + SourceShard: msg.SourceShard, + Resp: clonedResp, + } + if 
r.shardManager.DeliverMessagesToShardOwner(targetShardID, &clonedMsg, shutdownChan, r.logger) { + r.logger.Info("Sent pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } else { + r.logger.Warn("Failed to send pending watermark to remote shard", + tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShardID))) + } + } +} + +// sendAck forwards ACKs from local ack channel upstream to the local server. +func (r *proxyStreamReceiver) sendAck( + sourceStreamClient adminservice.AdminService_StreamWorkflowReplicationMessagesClient, + shutdownChan channel.ShutdownOnce, +) error { + r.logger.Info("proxyStreamReceiver sendAck started") + defer r.logger.Info("proxyStreamReceiver sendAck finished") + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for !shutdownChan.IsShutdown() { + select { + case routed := <-r.ackChan: + // Update per-target watermark + if attr, ok := routed.Req.GetAttributes().(*adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState); ok && attr.SyncReplicationState != nil { + r.logger.Info("Receiver received upstream ACK", tag.NewInt64("inclusive_low", attr.SyncReplicationState.InclusiveLowWatermark), tag.NewStringTag("targetShard", ClusterShardIDtoString(routed.TargetShard))) + r.ackMu.Lock() + r.ackByTarget[routed.TargetShard] = attr.SyncReplicationState.InclusiveLowWatermark + // Compute minimal watermark across targets + min := int64(0) + first := true + for _, wm := range r.ackByTarget { + if first || wm < min { + min = wm + first = false + } + } + lastSentMin := r.lastSentMin + lastExclusiveHighOriginal := r.lastExclusiveHighOriginal + r.ackMu.Unlock() + if !first && min >= lastSentMin { + // Clamp ACK to last known exclusive high watermark from source + if lastExclusiveHighOriginal > 0 && min > lastExclusiveHighOriginal { + r.logger.Warn("Aggregated ACK exceeds last source high watermark; clamping", + tag.NewInt64("ack_min", min), + tag.NewInt64("source_exclusive_high", lastExclusiveHighOriginal)) + min = lastExclusiveHighOriginal + } + // Send aggregated minimal ack upstream + aggregated := &adminservice.StreamWorkflowReplicationMessagesRequest{ + Attributes: &adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ + SyncReplicationState: &replicationv1.SyncReplicationState{ + InclusiveLowWatermark: min, + }, + }, + } + r.logger.Info("Receiver sending aggregated ACK upstream", tag.NewInt64("inclusive_low", min)) + if err := sourceStreamClient.Send(aggregated); err != nil { + if err != io.EOF { + r.logger.Error("sourceStreamClient.Send encountered error", tag.Error(err)) + } else { + r.logger.Info("sourceStreamClient.Send encountered EOF", tag.Error(err)) + } + return err + } + // Track sync watermark for receiver stream + if r.streamTracker != nil && r.streamID != "" { + r.streamTracker.UpdateStreamSyncReplicationState(r.streamID, min, nil) + r.streamTracker.UpdateStream(r.streamID) + // Update receiver debug snapshot when we send an aggregated ACK + r.streamTracker.UpdateStreamReceiverDebug(r.streamID, r.buildReceiverDebugSnapshot()) + } + r.lastSentMin = min + + // Update keepalive state + r.ackMu.Lock() + r.lastAckSendTime = time.Now() + r.lastSentAck = aggregated + r.ackMu.Unlock() + } + } + case <-ticker.C: + // Send keepalive if idle for 1 second + r.ackMu.RLock() + shouldSendKeepalive := r.lastSentAck != nil && time.Since(r.lastAckSendTime) >= 1*time.Second + lastAck := r.lastSentAck + r.ackMu.RUnlock() + + if shouldSendKeepalive { + 
r.logger.Info("Receiver sending keepalive ACK") + if err := sourceStreamClient.Send(lastAck); err != nil { + if err != io.EOF { + r.logger.Error("sourceStreamClient.Send keepalive encountered error", tag.Error(err)) + } + return err + } + r.ackMu.Lock() + r.lastAckSendTime = time.Now() + r.ackMu.Unlock() + } + case <-shutdownChan.Channel(): + return nil + } + } + return nil +} diff --git a/proxy/shard_manager.go b/proxy/shard_manager.go new file mode 100644 index 00000000..d6c27c90 --- /dev/null +++ b/proxy/shard_manager.go @@ -0,0 +1,1345 @@ +package proxy + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "sync" + "time" + + "github.com/hashicorp/memberlist" + replicationv1 "go.temporal.io/server/api/replication/v1" + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/channel" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + + "github.com/temporalio/s2s-proxy/config" + "github.com/temporalio/s2s-proxy/encryption" +) + +type ( + // ActiveReceiver is an interface for receivers that can be notified of new target shards + ActiveReceiver interface { + GetTargetShardID() history.ClusterShardID + GetSourceShardID() history.ClusterShardID + NotifyNewTargetShard(targetShardID history.ClusterShardID) + GetLastWatermark() *replicationv1.WorkflowReplicationMessages + } + + // ShardManager manages distributed shard ownership across proxy instances + ShardManager interface { + // Start initializes the memberlist cluster and starts the manager + Start(lifetime context.Context) error + // Stop shuts down the manager and leaves the cluster + Stop() + // RegisterShard registers a clientShardID as owned by this proxy instance and returns the registration timestamp + RegisterShard(clientShardID history.ClusterShardID) time.Time + // UnregisterShard removes a clientShardID from this proxy's ownership only if the timestamp matches + UnregisterShard(clientShardID history.ClusterShardID, expectedRegisteredAt time.Time) + // GetProxyAddress returns the proxy service address for the given node name + GetProxyAddress(nodeName string) (string, bool) + // IsLocalShard checks if this proxy instance owns the given shard + IsLocalShard(clientShardID history.ClusterShardID) bool + // GetNodeName returns the name of this proxy instance + GetNodeName() string + // GetMemberNodes returns all active proxy nodes in the cluster + GetMemberNodes() []string + // GetLocalShards returns all shards currently handled by this proxy instance, keyed by short id + GetLocalShards() map[string]history.ClusterShardID + // GetRemoteShardsForPeer returns all shards owned by the specified peer node, keyed by short id + GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) + // GetShardInfo returns debug information about shard distribution + GetShardInfo() ShardDebugInfo + // GetShardInfos returns debug information about shard distribution as a slice + GetShardInfos() []ShardDebugInfo + // GetChannelInfo returns debug information about active channels + GetChannelInfo() ChannelDebugInfo + // GetShardOwner returns the node name that owns the given shard + GetShardOwner(shard history.ClusterShardID) (string, bool) + // TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed + TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) + // GetIntraProxyManager returns the intra-proxy manager if it exists + GetIntraProxyManager() *intraProxyManager + // 
GetIntraProxyTLSConfig returns the TLS config for intra-proxy connections + GetIntraProxyTLSConfig() encryption.TLSConfig + // DeliverAckToShardOwner routes an ACK request to the appropriate shard owner (local or remote) + DeliverAckToShardOwner(srcShard history.ClusterShardID, routedAck *RoutedAck, shutdownChan channel.ShutdownOnce, logger log.Logger, ack int64, allowForward bool) bool + // DeliverMessagesToShardOwner routes replication messages to the appropriate shard owner (local or remote) + DeliverMessagesToShardOwner(targetShard history.ClusterShardID, routedMsg *RoutedMessage, shutdownChan channel.ShutdownOnce, logger log.Logger) bool + // SetOnPeerJoin registers a callback invoked when a new peer joins + SetOnPeerJoin(handler func(nodeName string)) + // SetOnPeerLeave registers a callback invoked when a peer leaves. + SetOnPeerLeave(handler func(nodeName string)) + // New: notify when local shard set changes + SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) + // New: notify when remote shard set changes for a peer + SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) + // RegisterActiveReceiver registers an active receiver for watermark propagation + RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver ActiveReceiver) + // UnregisterActiveReceiver removes an active receiver + UnregisterActiveReceiver(sourceShardID history.ClusterShardID) + // GetActiveReceiver returns the active receiver for the given source shard + GetActiveReceiver(sourceShardID history.ClusterShardID) (ActiveReceiver, bool) + // SetRemoteSendChan registers a send channel for a specific shard ID + SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) + // GetRemoteSendChan retrieves the send channel for a specific shard ID + GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) + // GetAllRemoteSendChans returns a map of all remote send channels + GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage + // GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID + GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage + // RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel + RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) + // SetLocalAckChan registers an ack channel for a specific shard ID + SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) + // GetLocalAckChan retrieves the ack channel for a specific shard ID + GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) + // GetAllLocalAckChans returns a map of all local ack channels + GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck + // RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel + RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) + // ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID + ForceRemoveLocalAckChan(shardID history.ClusterShardID) + // SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID + SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) + // GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID + 
GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) + // RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID + RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) + } + + shardManagerImpl struct { + memberlistConfig *config.MemberlistConfig + logger log.Logger + ml *memberlist.Memberlist + delegate *shardDelegate + mutex sync.RWMutex + mlMutex sync.RWMutex // Protects memberlist operations (Members, NumMembers, UpdateNode, etc.) + localAddr string + started bool + onPeerJoin func(nodeName string) + onPeerLeave func(nodeName string) + // New callbacks + onLocalShardChange func(shard history.ClusterShardID, added bool) + onRemoteShardChange func(peer string, shard history.ClusterShardID, added bool) + // Local shards owned by this node, keyed by short id + localShards map[string]ShardInfo + intraMgr *intraProxyManager + intraProxyTLSConfig encryption.TLSConfig + // Join retry control + stopJoinRetry chan struct{} + joinWg sync.WaitGroup + joinLoopRunning bool + // activeReceivers tracks active receiver instances by source shard for watermark propagation + activeReceivers map[history.ClusterShardID]ActiveReceiver + activeReceiversMu sync.RWMutex + // remoteSendChannels maps shard IDs to send channels for replication message routing + remoteSendChannels map[history.ClusterShardID]chan RoutedMessage + remoteSendChannelsMu sync.RWMutex + // localAckChannels maps shard IDs to ack channels for local acknowledgment handling + localAckChannels map[history.ClusterShardID]chan RoutedAck + localAckChannelsMu sync.RWMutex + // localReceiverCancelFuncs maps shard IDs to context cancel functions for local receiver termination + localReceiverCancelFuncs map[history.ClusterShardID]context.CancelFunc + localReceiverCancelFuncsMu sync.RWMutex + // remoteNodeStates stores remote node shard states (from MergeRemoteState) + // keyed by node name, includes the meta (shard state) information + remoteNodeStates map[string]NodeShardState + remoteNodeStatesMu sync.RWMutex + } + + // shardDelegate implements memberlist.Delegate for shard state management + shardDelegate struct { + manager *shardManagerImpl + logger log.Logger + } + + // ShardInfo describes a local shard and its creation time + ShardInfo struct { + ID history.ClusterShardID `json:"id"` + Created time.Time `json:"created"` + } + + // ShardMessage represents shard ownership changes broadcast to cluster + ShardMessage struct { + Type string `json:"type"` // "register" or "unregister" + NodeName string `json:"node"` + ClientShard history.ClusterShardID `json:"shard"` + Timestamp time.Time `json:"timestamp"` + } + + // NodeShardState represents all shards owned by a node + NodeShardState struct { + NodeName string `json:"node"` + Shards map[string]ShardInfo `json:"shards"` + Updated time.Time `json:"updated"` + } + + // memberSnapshot is a thread-safe copy of memberlist node data + memberSnapshot struct { + Name string + Meta []byte + } +) + +// getMembersSnapshot returns a thread-safe snapshot of remote node states. +// Uses the remoteNodeStates map instead of ml.Members() to avoid data races. 
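For reference alongside the types above: NodeMeta/LocalState marshal a NodeShardState to JSON, MergeRemoteState unmarshals it into remoteNodeStates, and the snapshot helper below reads from that map. A throwaway sketch for inspecting the payload locally; the node name and shard key are made-up examples and the helper is not part of this change:

// debugPrintNodeState prints the JSON blob a node would gossip for one example shard.
func debugPrintNodeState() {
	state := NodeShardState{
		NodeName: "proxy-a", // example node name
		Shards: map[string]ShardInfo{
			"1-7": {ID: history.ClusterShardID{ClusterID: 1, ShardID: 7}, Created: time.Now()}, // example short key
		},
		Updated: time.Now(),
	}
	if data, err := json.Marshal(state); err == nil {
		fmt.Println(string(data))
	}
}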
+func (sm *shardManagerImpl) getMembersSnapshot() []memberSnapshot { + sm.remoteNodeStatesMu.RLock() + defer sm.remoteNodeStatesMu.RUnlock() + + snapshots := make([]memberSnapshot, 0, len(sm.remoteNodeStates)) + for nodeName, state := range sm.remoteNodeStates { + // Marshal the state to get the meta bytes + metaBytes, err := json.Marshal(state) + if err != nil { + sm.logger.Warn("Failed to marshal node state for snapshot", + tag.NewStringTag("node", nodeName), + tag.Error(err)) + continue + } + snapshot := memberSnapshot{ + Name: nodeName, + Meta: metaBytes, + } + snapshots = append(snapshots, snapshot) + } + return snapshots +} + +// NewShardManager creates a new shard manager instance +func NewShardManager(memberlistConfig *config.MemberlistConfig, shardCountConfig config.ShardCountConfig, intraProxyTLSConfig encryption.TLSConfig, logger log.Logger) ShardManager { + delegate := &shardDelegate{ + logger: logger, + } + + sm := &shardManagerImpl{ + memberlistConfig: memberlistConfig, + logger: logger, + delegate: delegate, + localShards: make(map[string]ShardInfo), + intraMgr: nil, + intraProxyTLSConfig: intraProxyTLSConfig, + stopJoinRetry: make(chan struct{}), + activeReceivers: make(map[history.ClusterShardID]ActiveReceiver), + remoteSendChannels: make(map[history.ClusterShardID]chan RoutedMessage), + localAckChannels: make(map[history.ClusterShardID]chan RoutedAck), + localReceiverCancelFuncs: make(map[history.ClusterShardID]context.CancelFunc), + remoteNodeStates: make(map[string]NodeShardState), + } + + delegate.manager = sm + + if memberlistConfig != nil && shardCountConfig.Mode == config.ShardCountRouting { + sm.intraMgr = newIntraProxyManager(logger, sm) + } + + return sm +} + +// SetOnPeerJoin registers a callback invoked on new peer joins. +func (sm *shardManagerImpl) SetOnPeerJoin(handler func(nodeName string)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onPeerJoin = handler +} + +// SetOnPeerLeave registers a callback invoked when a peer leaves. +func (sm *shardManagerImpl) SetOnPeerLeave(handler func(nodeName string)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onPeerLeave = handler +} + +// SetOnLocalShardChange registers local shard change callback. +func (sm *shardManagerImpl) SetOnLocalShardChange(handler func(shard history.ClusterShardID, added bool)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onLocalShardChange = handler +} + +// SetOnRemoteShardChange registers remote shard change callback. 
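Before the remaining setters and Start below, a wiring sketch may help show how a routing-mode manager could be constructed and started. All addresses, ports, and node names are placeholders, newRoutingShardManager is a hypothetical helper that is not part of this change, and tlsCfg/logger are assumed to come from the surrounding proxy setup:

// newRoutingShardManager builds and starts a shard manager in shard-count routing mode.
func newRoutingShardManager(lifetime context.Context, tlsCfg encryption.TLSConfig, logger log.Logger) (ShardManager, error) {
	mlCfg := &config.MemberlistConfig{
		NodeName:  "proxy-a",                 // placeholder
		BindAddr:  "10.0.0.1",                // placeholder
		BindPort:  7946,                      // placeholder
		JoinAddrs: []string{"10.0.0.2:7946"}, // placeholder peer
		ProxyAddresses: map[string]string{
			"proxy-b": "10.0.0.2:7233", // placeholder proxy service address
		},
		TCPOnly: true,
	}
	sm := NewShardManager(mlCfg, config.ShardCountConfig{Mode: config.ShardCountRouting}, tlsCfg, logger)
	if err := sm.Start(lifetime); err != nil {
		return nil, err
	}
	return sm, nil
}

Start wires the peer-join/leave and shard-change callbacks itself via SetupCallbacks, so a caller only needs the constructor and Start.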
+func (sm *shardManagerImpl) SetOnRemoteShardChange(handler func(peer string, shard history.ClusterShardID, added bool)) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.onRemoteShardChange = handler +} + +func (sm *shardManagerImpl) Start(lifetime context.Context) error { + sm.logger.Info("Starting shard manager") + + if sm.started { + sm.logger.Info("Shard manager already started") + return nil + } + + if sm.intraMgr != nil { + sm.intraMgr.Start() + } + + sm.SetupCallbacks() + + if err := sm.initializeMemberlist(); err != nil { + return err + } + + sm.mutex.Lock() + sm.started = true + sm.mutex.Unlock() + + sm.logger.Info("Shard manager started", + tag.NewStringTag("node", sm.GetNodeName()), + tag.NewStringTag("addr", sm.localAddr)) + + context.AfterFunc(lifetime, func() { + sm.Stop() + }) + return nil +} + +func (sm *shardManagerImpl) initializeMemberlist() error { + if sm.memberlistConfig == nil { + sm.logger.Info("Shard manager not configured, skipping") + return nil + } + + // Configure memberlist + var mlConfig *memberlist.Config + if sm.memberlistConfig.TCPOnly { + // Use LAN config as base for TCP-only mode + mlConfig = memberlist.DefaultLANConfig() + mlConfig.DisableTcpPings = sm.memberlistConfig.DisableTCPPings + // Set default timeouts for TCP-only if not specified + if sm.memberlistConfig.ProbeTimeoutMs == 0 { + mlConfig.ProbeTimeout = 1 * time.Second + } + if sm.memberlistConfig.ProbeIntervalMs == 0 { + mlConfig.ProbeInterval = 2 * time.Second + } + } else { + mlConfig = memberlist.DefaultLocalConfig() + } + mlConfig.Name = sm.memberlistConfig.NodeName + mlConfig.BindAddr = sm.memberlistConfig.BindAddr + mlConfig.BindPort = sm.memberlistConfig.BindPort + mlConfig.AdvertiseAddr = sm.memberlistConfig.BindAddr + mlConfig.AdvertisePort = sm.memberlistConfig.BindPort + + mlConfig.Delegate = sm.delegate + mlConfig.Events = &shardEventDelegate{manager: sm, logger: sm.logger} + + // Configure custom timeouts if specified + if sm.memberlistConfig.ProbeTimeoutMs > 0 { + mlConfig.ProbeTimeout = time.Duration(sm.memberlistConfig.ProbeTimeoutMs) * time.Millisecond + } + if sm.memberlistConfig.ProbeIntervalMs > 0 { + mlConfig.ProbeInterval = time.Duration(sm.memberlistConfig.ProbeIntervalMs) * time.Millisecond + } + + sm.logger.Info("Creating memberlist", + tag.NewStringTag("nodeName", mlConfig.Name), + tag.NewStringTag("bindAddr", mlConfig.BindAddr), + tag.NewStringTag("bindPort", fmt.Sprintf("%d", mlConfig.BindPort)), + tag.NewBoolTag("tcpOnly", sm.memberlistConfig.TCPOnly), + tag.NewBoolTag("disableTcpPings", mlConfig.DisableTcpPings), + tag.NewStringTag("probeTimeout", mlConfig.ProbeTimeout.String()), + tag.NewStringTag("probeInterval", mlConfig.ProbeInterval.String())) + + // Create memberlist with timeout protection + type result struct { + ml *memberlist.Memberlist + err error + } + resultCh := make(chan result, 1) + go func() { + ml, err := memberlist.Create(mlConfig) + resultCh <- result{ml: ml, err: err} + }() + + var ml *memberlist.Memberlist + select { + case res := <-resultCh: + ml = res.ml + if res.err != nil { + return fmt.Errorf("failed to create memberlist: %w", res.err) + } + sm.logger.Info("Memberlist created successfully") + case <-time.After(10 * time.Second): + return fmt.Errorf("memberlist.Create() timed out after 10s - check bind address/port availability") + } + + sm.mutex.Lock() + sm.ml = ml + sm.localAddr = fmt.Sprintf("%s:%d", sm.memberlistConfig.BindAddr, sm.memberlistConfig.BindPort) + sm.mutex.Unlock() + + sm.logger.Info("Shard manager base initialization 
complete", + tag.NewStringTag("node", sm.GetNodeName()), + tag.NewStringTag("addr", sm.localAddr)) + + // Join existing cluster if configured + if len(sm.memberlistConfig.JoinAddrs) > 0 { + sm.startJoinLoop() + } + + return nil +} + +func (sm *shardManagerImpl) Stop() { + sm.mutex.Lock() + + if !sm.started { + sm.mutex.Unlock() + return + } + sm.mutex.Unlock() + + sm.shutdownMemberlist() + + sm.mutex.Lock() + sm.started = false + sm.mutex.Unlock() + sm.logger.Info("Shard manager stopped") +} + +func (sm *shardManagerImpl) shutdownMemberlist() { + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + + if ml == nil { + return + } + + // Stop any ongoing join retry + close(sm.stopJoinRetry) + sm.joinWg.Wait() + + // Leave the cluster gracefully + sm.mlMutex.Lock() + err := ml.Leave(5 * time.Second) + if err != nil { + sm.logger.Error("Error leaving memberlist cluster", tag.Error(err)) + } + + err = ml.Shutdown() + if err != nil { + sm.logger.Error("Error shutting down memberlist", tag.Error(err)) + } + sm.mlMutex.Unlock() + + // Clear pointer under main mutex + sm.mutex.Lock() + sm.ml = nil + sm.mutex.Unlock() +} + +// startJoinLoop starts the join retry loop if not already running +func (sm *shardManagerImpl) startJoinLoop() { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + if sm.joinLoopRunning { + sm.logger.Info("Join loop already running, skipping") + return + } + + sm.logger.Info("Starting join loop") + sm.joinLoopRunning = true + sm.joinWg.Add(1) + go sm.retryJoinCluster() +} + +// retryJoinCluster attempts to join the cluster infinitely with exponential backoff +func (sm *shardManagerImpl) retryJoinCluster() { + defer func() { + sm.joinWg.Done() + sm.mutex.Lock() + sm.joinLoopRunning = false + sm.mutex.Unlock() + }() + + const ( + initialInterval = 2 * time.Second + maxInterval = 60 * time.Second + ) + + interval := initialInterval + attempt := 0 + + sm.logger.Info("Starting join retry loop", + tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", sm.memberlistConfig.JoinAddrs))) + + for { + attempt++ + + sm.mutex.RLock() + ml := sm.ml + joinAddrs := sm.memberlistConfig.JoinAddrs + sm.mutex.RUnlock() + + if ml == nil { + sm.logger.Warn("Memberlist not initialized, stopping retry") + return + } + + sm.logger.Info("Attempting to join cluster", + tag.NewStringTag("attempt", strconv.Itoa(attempt)), + tag.NewStringTag("joinAddrs", fmt.Sprintf("%v", joinAddrs))) + + // Serialize Join with other memberlist operations + sm.mlMutex.Lock() + num, err := ml.Join(joinAddrs) + sm.mlMutex.Unlock() + if err != nil { + sm.logger.Warn("Failed to join cluster", tag.Error(err)) + + // Exponential backoff with cap + select { + case <-sm.stopJoinRetry: + sm.logger.Info("Join retry cancelled") + return + case <-time.After(interval): + interval *= 2 + if interval > maxInterval { + interval = maxInterval + } + } + } else { + sm.logger.Info("Successfully joined memberlist cluster", + tag.NewStringTag("members", strconv.Itoa(num)), + tag.NewStringTag("attempt", strconv.Itoa(attempt))) + return + } + } +} + +func (sm *shardManagerImpl) RegisterShard(clientShardID history.ClusterShardID) time.Time { + sm.logger.Info("RegisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) + registeredAt := sm.addLocalShard(clientShardID) + sm.broadcastShardChange("register", clientShardID) + + // Trigger memberlist metadata update to propagate NodeMeta to other nodes + // Run asynchronously to avoid blocking callers + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + if ml != nil { + go func() { + // Use 
mlMutex to serialize with getMembersSnapshot and other memberlist operations + sm.mlMutex.Lock() + err := ml.UpdateNode(0) // 0 timeout means immediate update + sm.mlMutex.Unlock() + if err != nil { + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + }() + } + // Notify listeners + if sm.onLocalShardChange != nil { + sm.onLocalShardChange(clientShardID, true) + } + return registeredAt +} + +func (sm *shardManagerImpl) UnregisterShard(clientShardID history.ClusterShardID, expectedRegisteredAt time.Time) { + sm.logger.Info("UnregisterShard", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) + + // Only unregister if the registration timestamp matches (prevents old senders from removing new registrations) + sm.mutex.Lock() + key := ClusterShardIDtoShortString(clientShardID) + if shardInfo, exists := sm.localShards[key]; exists && shardInfo.Created.Equal(expectedRegisteredAt) { + delete(sm.localShards, key) + // Update metrics after local shards change + sm.mutex.Unlock() + + sm.removeLocalShard(clientShardID) + sm.broadcastShardChange("unregister", clientShardID) + + // Trigger memberlist metadata update to propagate NodeMeta to other nodes + // Run asynchronously to avoid blocking callers + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + if ml != nil { + go func() { + // Use mlMutex to serialize with getMembersSnapshot and other memberlist operations + sm.mlMutex.Lock() + err := ml.UpdateNode(0) // 0 timeout means immediate update + sm.mlMutex.Unlock() + if err != nil { + sm.logger.Warn("Failed to update memberlist node metadata", tag.Error(err)) + } + }() + } + // Notify listeners + if sm.onLocalShardChange != nil { + sm.onLocalShardChange(clientShardID, false) + } + sm.logger.Info("UnregisterShard completed", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) + } else { + sm.mutex.Unlock() + sm.logger.Info("Skipped unregistering shard (timestamp mismatch or already unregistered)", tag.NewStringTag("shard", ClusterShardIDtoString(clientShardID))) + } +} + +func (sm *shardManagerImpl) IsLocalShard(clientShardID history.ClusterShardID) bool { + if !sm.started { + return true // If not using memberlist, handle locally + } + + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + _, found := sm.localShards[ClusterShardIDtoShortString(clientShardID)] + return found +} + +func (sm *shardManagerImpl) GetProxyAddress(nodeName string) (string, bool) { + // TODO: get the proxy address from the memberlist metadata + if sm.memberlistConfig == nil || sm.memberlistConfig.ProxyAddresses == nil { + return "", false + } + addr, found := sm.memberlistConfig.ProxyAddresses[nodeName] + return addr, found +} + +func (sm *shardManagerImpl) GetNodeName() string { + if sm.memberlistConfig == nil { + return "" + } + return sm.memberlistConfig.NodeName +} + +func (sm *shardManagerImpl) GetMemberNodes() []string { + if !sm.started || sm.ml == nil { + return []string{sm.GetNodeName()} + } + + // Use a timeout to prevent deadlocks when memberlist is busy + membersChan := make(chan []memberSnapshot, 1) + go func() { + defer func() { + if r := recover(); r != nil { + sm.logger.Error("Panic in GetMemberNodes", tag.NewStringTag("error", fmt.Sprintf("%v", r))) + } + }() + membersChan <- sm.getMembersSnapshot() + }() + + select { + case members := <-membersChan: + nodes := make([]string, len(members)) + for i, member := range members { + nodes[i] = member.Name + } + return nodes + case <-time.After(100 * time.Millisecond): + // Timeout: return cached node name to 
prevent hanging + sm.logger.Warn("GetMemberNodes timeout, returning self node", + tag.NewStringTag("node", sm.GetNodeName())) + return []string{sm.GetNodeName()} + } +} + +func (sm *shardManagerImpl) GetLocalShards() map[string]history.ClusterShardID { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + shards := make(map[string]history.ClusterShardID, len(sm.localShards)) + for k, v := range sm.localShards { + shards[k] = v.ID + } + return shards +} + +func (sm *shardManagerImpl) GetShardInfo() ShardDebugInfo { + localShardMap := sm.GetLocalShards() + remoteShards, err := sm.GetRemoteShardsForPeer("") + if err != nil { + sm.logger.Error("Failed to get remote shards", tag.Error(err)) + } + + remoteShardsMap := make(map[string]string) + remoteShardCounts := make(map[string]int) + + for nodeName, shards := range remoteShards { + for _, shard := range shards.Shards { + shardKey := ClusterShardIDtoShortString(shard.ID) + remoteShardsMap[shardKey] = nodeName + } + remoteShardCounts[nodeName] = len(shards.Shards) + } + + return ShardDebugInfo{ + Enabled: true, + NodeName: sm.GetNodeName(), + LocalShards: localShardMap, + LocalShardCount: len(localShardMap), + RemoteShards: remoteShardsMap, + RemoteShardCounts: remoteShardCounts, + } +} + +// GetShardInfos returns debug information about shard distribution as a slice +func (sm *shardManagerImpl) GetShardInfos() []ShardDebugInfo { + if sm.memberlistConfig == nil { + return []ShardDebugInfo{} + } + return []ShardDebugInfo{sm.GetShardInfo()} +} + +// GetChannelInfo returns debug information about active channels +func (sm *shardManagerImpl) GetChannelInfo() ChannelDebugInfo { + remoteSendChannels := make(map[string]int) + var totalSendChannels int + + // Collect remote send channel info first + allSendChans := sm.GetAllRemoteSendChans() + for shardID, ch := range allSendChans { + shardKey := ClusterShardIDtoString(shardID) + remoteSendChannels[shardKey] = len(ch) + } + totalSendChannels = len(allSendChans) + + localAckChannels := make(map[string]int) + var totalAckChannels int + + // Collect local ack channel info separately + allAckChans := sm.GetAllLocalAckChans() + for shardID, ch := range allAckChans { + shardKey := ClusterShardIDtoString(shardID) + localAckChannels[shardKey] = len(ch) + } + totalAckChannels = len(allAckChans) + + return ChannelDebugInfo{ + RemoteSendChannels: remoteSendChannels, + LocalAckChannels: localAckChannels, + TotalSendChannels: totalSendChannels, + TotalAckChannels: totalAckChannels, + } +} + +// TerminatePreviousLocalReceiver checks if there is a previous local receiver for this shard and terminates it if needed +func (sm *shardManagerImpl) TerminatePreviousLocalReceiver(shardID history.ClusterShardID, logger log.Logger) { + // Check if there's a previous cancel function for this shard + if prevCancelFunc, exists := sm.GetLocalReceiverCancelFunc(shardID); exists { + logger.Info("Terminating previous local receiver for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + + // Cancel the previous receiver's context + prevCancelFunc() + + // Force remove the cancel function and ack channel from tracking + sm.RemoveLocalReceiverCancelFunc(shardID) + sm.ForceRemoveLocalAckChan(shardID) + } +} + +func (sm *shardManagerImpl) GetShardOwner(shard history.ClusterShardID) (string, bool) { + remoteShards, err := sm.GetRemoteShardsForPeer("") + if err != nil { + sm.logger.Error("Failed to get remote shards", tag.Error(err)) + } + for nodeName, shards := range remoteShards { + for _, s := range shards.Shards { + if 
s.ID == shard { + return nodeName, true + } + } + } + return "", false +} + +// GetRemoteShardsForPeer returns all shards owned by the specified peer node. +// Uses the remoteNodeStates map instead of ml.Members() to avoid data races. +func (sm *shardManagerImpl) GetRemoteShardsForPeer(peerNodeName string) (map[string]NodeShardState, error) { + result := make(map[string]NodeShardState) + + sm.remoteNodeStatesMu.RLock() + defer sm.remoteNodeStatesMu.RUnlock() + + for nodeName, state := range sm.remoteNodeStates { + if nodeName == sm.GetNodeName() { + continue + } + if peerNodeName != "" && nodeName != peerNodeName { + continue + } + result[nodeName] = state + } + + return result, nil +} + +// DeliverAckToShardOwner routes an ACK to the local shard owner or records intent for remote forwarding. +func (sm *shardManagerImpl) DeliverAckToShardOwner( + sourceShard history.ClusterShardID, + routedAck *RoutedAck, + shutdownChan channel.ShutdownOnce, + logger log.Logger, + ack int64, + allowForward bool, +) bool { + logger = log.With(logger, tag.NewStringTag("sourceShard", ClusterShardIDtoString(sourceShard)), tag.NewInt64("ack", ack)) + if ackCh, ok := sm.GetLocalAckChan(sourceShard); ok { + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + logger.Warn("Failed to deliver ACK to local shard owner (channel closed)") + } + }() + select { + case ackCh <- *routedAck: + logger.Info("Delivered ACK to local shard owner") + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { + return true + } + if shutdownChan.IsShutdown() { + return false + } + } + if !allowForward { + logger.Warn("No local ack channel for source shard, forwarding ACK to shard owner is not allowed") + return false + } + + // Attempt remote delivery via intra-proxy when enabled and shard is remote + if sm.memberlistConfig != nil { + if owner, ok := sm.GetShardOwner(sourceShard); ok && owner != sm.GetNodeName() { + if addr, found := sm.GetProxyAddress(owner); found { + clientShard := routedAck.TargetShard + serverShard := sourceShard + // Synchronous send to preserve ordering + if err := sm.intraMgr.sendAck(context.Background(), owner, clientShard, serverShard, routedAck.Req); err != nil { + logger.Error("Failed to forward ACK to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return false + } + logger.Info("Forwarded ACK to shard owner via intra-proxy", tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return true + } + logger.Warn("Owner proxy address not found for shard") + return false + } + } + + logger.Warn("No remote shard owner found for source shard") + return false +} + +// DeliverMessagesToShardOwner routes replication messages to the local target shard owner +// or forwards to the remote owner via intra-proxy stream synchronously. 
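A usage sketch for the method below; routeBatch is a hypothetical wrapper and not part of this change. The call tries the local send channel first and falls back to the intra-proxy stream for a remote owner, so a false return simply means no owner was reachable yet and the caller should retry, as the receiver's retry loop above does:

// routeBatch forwards one re-sharded replication batch to whichever node owns targetShard.
func routeBatch(sm ShardManager, targetShard history.ClusterShardID, msg *RoutedMessage, shutdownChan channel.ShutdownOnce, logger log.Logger) bool {
	if sm.DeliverMessagesToShardOwner(targetShard, msg, shutdownChan, logger) {
		return true
	}
	logger.Warn("replication batch not delivered; no local channel or reachable remote owner",
		tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard)))
	return false
}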
+func (sm *shardManagerImpl) DeliverMessagesToShardOwner( + targetShard history.ClusterShardID, + routedMsg *RoutedMessage, + shutdownChan channel.ShutdownOnce, + logger log.Logger, +) bool { + logger = log.With(logger, tag.NewStringTag("task-target-shard", ClusterShardIDtoString(targetShard))) + + // Try local delivery first + if ch, ok := sm.GetRemoteSendChan(targetShard); ok { + delivered := false + func() { + defer func() { + if panicErr := recover(); panicErr != nil { + logger.Warn("Failed to deliver messages to local shard owner (channel closed)") + } + }() + select { + case ch <- *routedMsg: + logger.Info("Delivered messages to local shard owner") + delivered = true + case <-shutdownChan.Channel(): + // Shutdown signal received + } + }() + if delivered { + return true + } + if shutdownChan.IsShutdown() { + return false + } + } + + // Attempt remote delivery via intra-proxy when enabled and shard is remote + if sm.memberlistConfig != nil { + if owner, ok := sm.GetShardOwner(targetShard); ok && owner != sm.GetNodeName() { + if addr, found := sm.GetProxyAddress(owner); found { + if mgr := sm.GetIntraProxyManager(); mgr != nil { + resp := routedMsg.Resp + if err := mgr.sendReplicationMessages(context.Background(), owner, targetShard, routedMsg.SourceShard, resp); err != nil { + logger.Error("Failed to forward replication messages to shard owner via intra-proxy", tag.Error(err), tag.NewStringTag("owner", owner), tag.NewStringTag("addr", addr)) + return false + } + return true + } + } else { + logger.Warn("Owner proxy address not found for target shard", tag.NewStringTag("owner", owner), tag.NewStringTag("shard", ClusterShardIDtoString(targetShard))) + } + } + } + + logger.Warn("No local send channel for target shard", tag.NewStringTag("targetShard", ClusterShardIDtoString(targetShard))) + return false +} + +func (sm *shardManagerImpl) SetupCallbacks() { + // Wire memberlist peer-join callback to reconcile intra-proxy receivers for local/remote pairs + sm.SetOnPeerJoin(func(nodeName string) { + sm.logger.Info("OnPeerJoin", tag.NewStringTag("nodeName", nodeName)) + defer sm.logger.Info("OnPeerJoin done", tag.NewStringTag("nodeName", nodeName)) + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } + }) + + // Wire peer-leave to cleanup intra-proxy resources for that peer + sm.SetOnPeerLeave(func(nodeName string) { + sm.logger.Info("OnPeerLeave", tag.NewStringTag("nodeName", nodeName)) + defer sm.logger.Info("OnPeerLeave done", tag.NewStringTag("nodeName", nodeName)) + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } + }) + + // Wire local shard changes to reconcile intra-proxy receivers + sm.SetOnLocalShardChange(func(shard history.ClusterShardID, added bool) { + sm.logger.Info("OnLocalShardChange", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + defer sm.logger.Info("OnLocalShardChange done", tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + if added { + sm.notifyReceiversOfNewShard(shard) + } + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } + }) + + // Wire remote shard changes to reconcile intra-proxy receivers + sm.SetOnRemoteShardChange(func(peer string, shard history.ClusterShardID, added bool) { + sm.logger.Info("OnRemoteShardChange", tag.NewStringTag("peer", peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + defer sm.logger.Info("OnRemoteShardChange done", tag.NewStringTag("peer", 
peer), tag.NewStringTag("shard", ClusterShardIDtoString(shard)), tag.NewStringTag("added", strconv.FormatBool(added))) + if added { + sm.notifyReceiversOfNewShard(shard) + } + if sm.intraMgr != nil { + sm.intraMgr.Notify() + } + }) +} + +func (sm *shardManagerImpl) GetIntraProxyManager() *intraProxyManager { + return sm.intraMgr +} + +func (sm *shardManagerImpl) GetIntraProxyTLSConfig() encryption.TLSConfig { + return sm.intraProxyTLSConfig +} + +func (sm *shardManagerImpl) broadcastShardChange(msgType string, shard history.ClusterShardID) { + if !sm.started || sm.ml == nil || sm.memberlistConfig == nil { + return + } + + msg := ShardMessage{ + Type: msgType, + NodeName: sm.GetNodeName(), + ClientShard: shard, + Timestamp: time.Now(), + } + + data, err := json.Marshal(msg) + if err != nil { + sm.logger.Error("Failed to marshal shard message", tag.Error(err)) + return + } + + // Use remoteNodeStates map to get list of nodes to send to + sm.remoteNodeStatesMu.RLock() + nodeNames := make([]string, 0, len(sm.remoteNodeStates)) + for nodeName := range sm.remoteNodeStates { + // Skip sending to self node + if nodeName == sm.GetNodeName() { + continue + } + nodeNames = append(nodeNames, nodeName) + } + sm.remoteNodeStatesMu.RUnlock() + + for _, nodeName := range nodeNames { + // Send in goroutine to make it non-blocking + // Look up fresh node pointer when sending to avoid race with memberlist updates + go func(targetNodeName string) { + sm.mutex.RLock() + ml := sm.ml + sm.mutex.RUnlock() + + if ml == nil { + return + } + + // Find the node by name from current members list + // Serialize memberlist operations to prevent races with UpdateNode + sm.mlMutex.RLock() + var targetNode *memberlist.Node + for _, n := range ml.Members() { + if n != nil && n.Name == targetNodeName { + targetNode = n + break + } + } + if targetNode == nil { + sm.mlMutex.RUnlock() + sm.logger.Warn("Node not found for broadcast", + tag.NewStringTag("target_node", targetNodeName)) + return + } + // SendReliable reads node fields (via FullAddress()), so we must hold the lock + // to prevent races with memberlist's internal updates. + // Note: This is a blocking network call, but we need the lock to prevent + // memberlist from modifying the node while SendReliable reads its fields. 
+ err := ml.SendReliable(targetNode, data) + sm.mlMutex.RUnlock() + if err != nil { + sm.logger.Error("Failed to broadcast shard change", + tag.Error(err), + tag.NewStringTag("target_node", targetNodeName)) + } + }(nodeName) + } +} + +// shardDelegate implements memberlist.Delegate +func (sd *shardDelegate) NodeMeta(limit int) []byte { + if sd.manager == nil || sd.manager.memberlistConfig == nil { + return nil + } + // Copy shard map under read lock to avoid concurrent map iteration/modification + sd.manager.mutex.RLock() + shardsCopy := make(map[string]ShardInfo, len(sd.manager.localShards)) + for k, v := range sd.manager.localShards { + shardsCopy[k] = v + } + nodeName := sd.manager.GetNodeName() + sd.manager.mutex.RUnlock() + + state := NodeShardState{ + NodeName: nodeName, + Shards: shardsCopy, + Updated: time.Now(), + } + + data, err := json.Marshal(state) + if err != nil { + sd.logger.Error("Failed to marshal node meta", tag.Error(err)) + return nil + } + + if len(data) > limit { + // If metadata is too large, just send node name + return []byte(sd.manager.GetNodeName()) + } + + return data +} + +func (sd *shardDelegate) NotifyMsg(data []byte) { + var msg ShardMessage + if err := json.Unmarshal(data, &msg); err != nil { + sd.logger.Error("Failed to unmarshal shard message", tag.Error(err)) + return + } + + sd.logger.Info("Received shard message", + tag.NewStringTag("type", msg.Type), + tag.NewStringTag("node", msg.NodeName), + tag.NewStringTag("shard", ClusterShardIDtoString(msg.ClientShard))) + + // Inform listeners about remote shard changes + if sd.manager != nil && sd.manager.onRemoteShardChange != nil { + added := msg.Type == "register" + + // if shard is previously registered as local shard, but now is registered as remote shard, + // check if the remote shard is newer than the local shard. If so, unregister the local shard. 
+ if added { + // Lock when reading localShards to prevent race with concurrent writes + sd.manager.mutex.RLock() + localShard, ok := sd.manager.localShards[ClusterShardIDtoShortString(msg.ClientShard)] + sd.manager.mutex.RUnlock() + if ok { + if localShard.Created.Before(msg.Timestamp) { + // Force unregister the local shard by passing its own timestamp + sd.manager.UnregisterShard(msg.ClientShard, localShard.Created) + } + } + } + + sd.manager.onRemoteShardChange(msg.NodeName, msg.ClientShard, added) + } +} + +func (sd *shardDelegate) GetBroadcasts(overhead, limit int) [][]byte { + // Not implementing broadcasts for now + return nil +} + +func (sd *shardDelegate) LocalState(join bool) []byte { + return sd.NodeMeta(4096) // TODO: set this to a reasonable value +} + +func (sd *shardDelegate) MergeRemoteState(buf []byte, join bool) { + var state NodeShardState + if err := json.Unmarshal(buf, &state); err != nil { + sd.logger.Error("Failed to unmarshal remote state", tag.Error(err)) + return + } + + // Save the remote state to local map + if sd.manager != nil { + sd.manager.remoteNodeStatesMu.Lock() + sd.manager.remoteNodeStates[state.NodeName] = state + sd.manager.remoteNodeStatesMu.Unlock() + } + + sd.logger.Info("Merged remote shard state", + tag.NewStringTag("node", state.NodeName), + tag.NewStringTag("shards", strconv.Itoa(len(state.Shards))), + tag.NewStringTag("state", fmt.Sprintf("%+v", state))) +} + +func (sm *shardManagerImpl) addLocalShard(shard history.ClusterShardID) time.Time { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + key := ClusterShardIDtoShortString(shard) + now := time.Now() + sm.localShards[key] = ShardInfo{ID: shard, Created: now} + + return now +} + +func (sm *shardManagerImpl) removeLocalShard(shard history.ClusterShardID) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + key := ClusterShardIDtoShortString(shard) + delete(sm.localShards, key) +} + +// RegisterActiveReceiver registers an active receiver for watermark propagation +func (sm *shardManagerImpl) RegisterActiveReceiver(sourceShardID history.ClusterShardID, receiver ActiveReceiver) { + sm.activeReceiversMu.Lock() + defer sm.activeReceiversMu.Unlock() + sm.activeReceivers[sourceShardID] = receiver +} + +// UnregisterActiveReceiver removes an active receiver +func (sm *shardManagerImpl) UnregisterActiveReceiver(sourceShardID history.ClusterShardID) { + sm.activeReceiversMu.Lock() + defer sm.activeReceiversMu.Unlock() + delete(sm.activeReceivers, sourceShardID) +} + +// GetActiveReceiver returns the active receiver for the given source shard +func (sm *shardManagerImpl) GetActiveReceiver(sourceShardID history.ClusterShardID) (ActiveReceiver, bool) { + sm.activeReceiversMu.RLock() + defer sm.activeReceiversMu.RUnlock() + receiver, ok := sm.activeReceivers[sourceShardID] + return receiver, ok +} + +// notifyReceiversOfNewShard notifies all receivers about a newly registered target shard +// so they can send pending watermarks if available +func (sm *shardManagerImpl) notifyReceiversOfNewShard(targetShardID history.ClusterShardID) { + sm.activeReceiversMu.RLock() + receivers := make([]ActiveReceiver, 0, len(sm.activeReceivers)) + for _, receiver := range sm.activeReceivers { + receivers = append(receivers, receiver) + } + sm.activeReceiversMu.RUnlock() + + for _, receiver := range receivers { + // Only notify receivers that route to the same cluster as the newly registered shard + if receiver.GetTargetShardID().ClusterID == targetShardID.ClusterID { + receiver.NotifyNewTargetShard(targetShardID) + } + } +} + 
+// SetRemoteSendChan registers a send channel for a specific shard ID +func (sm *shardManagerImpl) SetRemoteSendChan(shardID history.ClusterShardID, sendChan chan RoutedMessage) { + sm.logger.Info("Register remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.remoteSendChannelsMu.Lock() + defer sm.remoteSendChannelsMu.Unlock() + sm.remoteSendChannels[shardID] = sendChan +} + +// GetRemoteSendChan retrieves the send channel for a specific shard ID +func (sm *shardManagerImpl) GetRemoteSendChan(shardID history.ClusterShardID) (chan RoutedMessage, bool) { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + ch, exists := sm.remoteSendChannels[shardID] + return ch, exists +} + +// GetAllRemoteSendChans returns a map of all remote send channels +func (sm *shardManagerImpl) GetAllRemoteSendChans() map[history.ClusterShardID]chan RoutedMessage { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan RoutedMessage, len(sm.remoteSendChannels)) + for k, v := range sm.remoteSendChannels { + result[k] = v + } + return result +} + +// GetRemoteSendChansByCluster returns a copy of remote send channels filtered by clusterID +func (sm *shardManagerImpl) GetRemoteSendChansByCluster(clusterID int32) map[history.ClusterShardID]chan RoutedMessage { + sm.remoteSendChannelsMu.RLock() + defer sm.remoteSendChannelsMu.RUnlock() + + result := make(map[history.ClusterShardID]chan RoutedMessage) + for k, v := range sm.remoteSendChannels { + if k.ClusterID == clusterID { + result[k] = v + } + } + return result +} + +// RemoveRemoteSendChan removes the send channel for a specific shard ID only if it matches the provided channel +func (sm *shardManagerImpl) RemoveRemoteSendChan(shardID history.ClusterShardID, expectedChan chan RoutedMessage) { + sm.remoteSendChannelsMu.Lock() + defer sm.remoteSendChannelsMu.Unlock() + if currentChan, exists := sm.remoteSendChannels[shardID]; exists && currentChan == expectedChan { + delete(sm.remoteSendChannels, shardID) + sm.logger.Info("Removed remote send channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } else { + sm.logger.Info("Skipped removing remote send channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// SetLocalAckChan registers an ack channel for a specific shard ID +func (sm *shardManagerImpl) SetLocalAckChan(shardID history.ClusterShardID, ackChan chan RoutedAck) { + sm.logger.Info("Register local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + sm.localAckChannels[shardID] = ackChan +} + +// GetLocalAckChan retrieves the ack channel for a specific shard ID +func (sm *shardManagerImpl) GetLocalAckChan(shardID history.ClusterShardID) (chan RoutedAck, bool) { + sm.localAckChannelsMu.RLock() + defer sm.localAckChannelsMu.RUnlock() + ch, exists := sm.localAckChannels[shardID] + return ch, exists +} + +// GetAllLocalAckChans returns a map of all local ack channels +func (sm *shardManagerImpl) GetAllLocalAckChans() map[history.ClusterShardID]chan RoutedAck { + sm.localAckChannelsMu.RLock() + defer sm.localAckChannelsMu.RUnlock() + + // Create a copy of the map + result := make(map[history.ClusterShardID]chan RoutedAck, len(sm.localAckChannels)) + for k, v := range sm.localAckChannels { + 
result[k] = v + } + return result +} + +// RemoveLocalAckChan removes the ack channel for a specific shard ID only if it matches the provided channel +func (sm *shardManagerImpl) RemoveLocalAckChan(shardID history.ClusterShardID, expectedChan chan RoutedAck) { + sm.logger.Info("Remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + if currentChan, exists := sm.localAckChannels[shardID]; exists && currentChan == expectedChan { + delete(sm.localAckChannels, shardID) + } else { + sm.logger.Info("Skipped removing local ack channel for shard (channel mismatch or already removed)", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + } +} + +// ForceRemoveLocalAckChan unconditionally removes the ack channel for a specific shard ID +func (sm *shardManagerImpl) ForceRemoveLocalAckChan(shardID history.ClusterShardID) { + sm.logger.Info("Force remove local ack channel for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localAckChannelsMu.Lock() + defer sm.localAckChannelsMu.Unlock() + delete(sm.localAckChannels, shardID) +} + +// SetLocalReceiverCancelFunc registers a cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) SetLocalReceiverCancelFunc(shardID history.ClusterShardID, cancelFunc context.CancelFunc) { + sm.logger.Info("Register local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localReceiverCancelFuncsMu.Lock() + defer sm.localReceiverCancelFuncsMu.Unlock() + sm.localReceiverCancelFuncs[shardID] = cancelFunc +} + +// GetLocalReceiverCancelFunc retrieves the cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) GetLocalReceiverCancelFunc(shardID history.ClusterShardID) (context.CancelFunc, bool) { + sm.localReceiverCancelFuncsMu.RLock() + defer sm.localReceiverCancelFuncsMu.RUnlock() + cancelFunc, exists := sm.localReceiverCancelFuncs[shardID] + return cancelFunc, exists +} + +// RemoveLocalReceiverCancelFunc unconditionally removes the cancel function for a local receiver for a specific shard ID +func (sm *shardManagerImpl) RemoveLocalReceiverCancelFunc(shardID history.ClusterShardID) { + sm.logger.Info("Remove local receiver cancel function for shard", tag.NewStringTag("shardID", ClusterShardIDtoString(shardID))) + sm.localReceiverCancelFuncsMu.Lock() + defer sm.localReceiverCancelFuncsMu.Unlock() + delete(sm.localReceiverCancelFuncs, shardID) +} + +// shardEventDelegate handles memberlist cluster events +type shardEventDelegate struct { + manager *shardManagerImpl + logger log.Logger +} + +func (sed *shardEventDelegate) NotifyJoin(node *memberlist.Node) { + sed.logger.Info("Node joined cluster", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) +} + +func (sed *shardEventDelegate) NotifyLeave(node *memberlist.Node) { + sed.logger.Info("Node left cluster", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) + + // Remove the node from remoteNodeStates map + if sed.manager != nil { + sed.manager.remoteNodeStatesMu.Lock() + delete(sed.manager.remoteNodeStates, node.Name) + sed.manager.remoteNodeStatesMu.Unlock() + } + + // If we're now isolated and have join addresses configured, restart join loop + if sed.manager != nil && sed.manager.ml != nil && sed.manager.memberlistConfig != nil { + sed.manager.mlMutex.RLock() + numMembers := 
sed.manager.ml.NumMembers() + sed.manager.mlMutex.RUnlock() + if numMembers == 1 && len(sed.manager.memberlistConfig.JoinAddrs) > 0 { + sed.logger.Info("Node is now isolated, restarting join loop", + tag.NewStringTag("numMembers", strconv.Itoa(numMembers))) + sed.manager.startJoinLoop() + } + } +} + +func (sed *shardEventDelegate) NotifyUpdate(node *memberlist.Node) { + sed.logger.Info("Node updated", + tag.NewStringTag("node", node.Name), + tag.NewStringTag("addr", node.Addr.String())) +} diff --git a/proxy/stream_tracker.go b/proxy/stream_tracker.go new file mode 100644 index 00000000..8a09e179 --- /dev/null +++ b/proxy/stream_tracker.go @@ -0,0 +1,214 @@ +package proxy + +import ( + "fmt" + "sync" + "time" + + "go.temporal.io/server/client/history" +) + +const ( + StreamRoleSender = "Sender" + StreamRoleReceiver = "Receiver" + StreamRoleForwarder = "Forwarder" +) + +// StreamTracker tracks active gRPC streams for debugging +type StreamTracker struct { + mu sync.RWMutex + streams map[string]*StreamInfo +} + +// NewStreamTracker creates a new stream tracker +func NewStreamTracker() *StreamTracker { + return &StreamTracker{ + streams: make(map[string]*StreamInfo), + } +} + +// RegisterStream adds a new active stream +func (st *StreamTracker) RegisterStream(id, method, direction, sourceShard, targetShard, role string) { + st.mu.Lock() + defer st.mu.Unlock() + + now := time.Now() + st.streams[id] = &StreamInfo{ + ID: id, + Method: method, + Direction: direction, + ClientShard: targetShard, + ServerShard: sourceShard, + Role: role, + StartTime: now, + LastSeen: now, + SenderDebug: &SenderDebugInfo{}, + ReceiverDebug: &ReceiverDebugInfo{}, + } +} + +// UpdateStream updates the last seen time for a stream +func (st *StreamTracker) UpdateStream(id string) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + } +} + +// UpdateStreamSyncReplicationState updates the sync replication state information for a stream +func (st *StreamTracker) UpdateStreamSyncReplicationState(id string, inclusiveLowWatermark int64, watermarkTime *time.Time) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastSyncWatermark = &inclusiveLowWatermark + stream.LastSyncWatermarkTime = watermarkTime + } +} + +// UpdateStreamReplicationMessages updates the replication messages information for a stream +func (st *StreamTracker) UpdateStreamReplicationMessages(id string, exclusiveHighWatermark int64) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastExclusiveHighWatermark = &exclusiveHighWatermark + } +} + +// UpdateStreamLastTaskIDs updates the last seen task ids for a stream +func (st *StreamTracker) UpdateStreamLastTaskIDs(id string, taskIDs []int64) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.LastSeen = time.Now() + stream.LastTaskIDs = taskIDs + } +} + +// UnregisterStream removes a stream from tracking +func (st *StreamTracker) UnregisterStream(id string) { + st.mu.Lock() + defer st.mu.Unlock() + + delete(st.streams, id) +} + +// UpdateStreamSenderDebug sets the sender debug snapshot for a stream +func (st *StreamTracker) UpdateStreamSenderDebug(id string, info *SenderDebugInfo) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.SenderDebug = info + stream.LastSeen = time.Now() + } +} + +// 
UpdateStreamReceiverDebug sets the receiver debug snapshot for a stream +func (st *StreamTracker) UpdateStreamReceiverDebug(id string, info *ReceiverDebugInfo) { + st.mu.Lock() + defer st.mu.Unlock() + + if stream, exists := st.streams[id]; exists { + stream.ReceiverDebug = info + stream.LastSeen = time.Now() + } +} + +// GetActiveStreams returns a copy of all active streams +func (st *StreamTracker) GetActiveStreams() []StreamInfo { + st.mu.RLock() + defer st.mu.RUnlock() + + now := time.Now() + streams := make([]StreamInfo, 0, len(st.streams)) + for _, stream := range st.streams { + // Create a copy and calculate both durations in seconds + streamCopy := *stream + totalSeconds := int(now.Sub(stream.StartTime).Seconds()) + idleSeconds := int(now.Sub(stream.LastSeen).Seconds()) + streamCopy.TotalDuration = formatDurationSeconds(totalSeconds) + streamCopy.IdleDuration = formatDurationSeconds(idleSeconds) + streams = append(streams, streamCopy) + } + + return streams +} + +// GetStreamCount returns the number of active streams +func (st *StreamTracker) GetStreamCount() int { + st.mu.RLock() + defer st.mu.RUnlock() + + return len(st.streams) +} + +// Global stream tracker instance +var globalStreamTracker = NewStreamTracker() + +// GetGlobalStreamTracker returns the global stream tracker instance +func GetGlobalStreamTracker() *StreamTracker { + return globalStreamTracker +} + +// BuildSenderStreamID returns the canonical sender stream ID. +func BuildSenderStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("snd-%s", ClusterShardIDtoShortString(source)) +} + +// BuildReceiverStreamID returns the canonical receiver stream ID. +func BuildReceiverStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("rcv-%s", ClusterShardIDtoShortString(source)) +} + +// BuildForwarderStreamID returns the canonical forwarder stream ID. +// Note: forwarder uses server-first ordering in the ID. +func BuildForwarderStreamID(source, target history.ClusterShardID) string { + return fmt.Sprintf("fwd-snd-%s", ClusterShardIDtoShortString(source)) +} + +// BuildIntraProxySenderStreamID returns the server-side intra-proxy stream ID for a peer and shard pair. +func BuildIntraProxySenderStreamID(peer string, source, target history.ClusterShardID) string { + return fmt.Sprintf("ip-snd-%s-%s|%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target), peer) +} + +// BuildIntraProxyReceiverStreamID returns the client-side intra-proxy stream ID for a peer and shard pair. 
+func BuildIntraProxyReceiverStreamID(peer string, source, target history.ClusterShardID) string { + return fmt.Sprintf("ip-rcv-%s-%s|%s", ClusterShardIDtoShortString(source), ClusterShardIDtoShortString(target), peer) +} + +// formatDurationSeconds formats a duration in seconds to a readable string +func formatDurationSeconds(totalSeconds int) string { + if totalSeconds < 60 { + return fmt.Sprintf("%ds", totalSeconds) + } + + minutes := totalSeconds / 60 + seconds := totalSeconds % 60 + + if minutes < 60 { + if seconds == 0 { + return fmt.Sprintf("%dm", minutes) + } + return fmt.Sprintf("%dm%ds", minutes, seconds) + } + + hours := minutes / 60 + minutes = minutes % 60 + + if minutes == 0 && seconds == 0 { + return fmt.Sprintf("%dh", hours) + } else if seconds == 0 { + return fmt.Sprintf("%dh%dm", hours, minutes) + } else if minutes == 0 { + return fmt.Sprintf("%dh%ds", hours, seconds) + } else { + return fmt.Sprintf("%dh%dm%ds", hours, minutes, seconds) + } +} diff --git a/proxy/test/bench_test.go b/proxy/test/bench_test.go index 60247c45..6f72943b 100644 --- a/proxy/test/bench_test.go +++ b/proxy/test/bench_test.go @@ -12,15 +12,86 @@ import ( "github.com/temporalio/s2s-proxy/endtoendtest" ) -func benchmarkStreamSendRecvWithoutProxy(b *testing.B, payloadSize int) { +func createEchoServerConfigWithPorts( + echoServerAddress string, + serverProxyInboundAddress string, + serverProxyOutboundAddress string, + opts ...cfgOption, +) *config.S2SProxyConfig { + return createS2SProxyConfig(&config.S2SProxyConfig{ + Inbound: &config.ProxyConfig{ + Name: "proxy1-inbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: serverProxyInboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: echoServerAddress, + }, + }, + }, + Outbound: &config.ProxyConfig{ + Name: "proxy1-outbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: serverProxyOutboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: "to-be-added", + }, + }, + }, + }, opts) +} + +func createEchoClientConfigWithPorts( + echoClientAddress string, + clientProxyInboundAddress string, + clientProxyOutboundAddress string, + opts ...cfgOption, +) *config.S2SProxyConfig { + return createS2SProxyConfig(&config.S2SProxyConfig{ + Inbound: &config.ProxyConfig{ + Name: "proxy2-inbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: clientProxyInboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: echoClientAddress, + }, + }, + }, + Outbound: &config.ProxyConfig{ + Name: "proxy2-outbound-server", + Server: config.ProxyServerConfig{ + TCPServerSetting: config.TCPServerSetting{ + ListenAddress: clientProxyOutboundAddress, + }, + }, + Client: config.ProxyClientConfig{ + TCPClientSetting: config.TCPClientSetting{ + ServerAddress: "to-be-added", + }, + }, + }, + }, opts) +} +func benchmarkStreamSendRecvWithoutProxy(b *testing.B, payloadSize int) { echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: GetLocalhostAddress(), ClusterShardID: serverClusterShard, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: GetLocalhostAddress(), ClusterShardID: clientClusterShard, } @@ -31,7 +102,18 @@ func 
benchmarkStreamSendRecvWithMuxProxy(b *testing.B, payloadSize int) { b.Log("Start BenchmarkStreamSendRecv") muxTransportName := "muxed" - echoServerConfig := createEchoServerConfig( + // Allocate ports dynamically + echoServerAddress := GetLocalhostAddress() + serverProxyInboundAddress := GetLocalhostAddress() + serverProxyOutboundAddress := GetLocalhostAddress() + echoClientAddress := GetLocalhostAddress() + clientProxyInboundAddress := GetLocalhostAddress() + clientProxyOutboundAddress := GetLocalhostAddress() + + echoServerConfig := createEchoServerConfigWithPorts( + echoServerAddress, + serverProxyInboundAddress, + serverProxyOutboundAddress, withMux( config.MuxTransportConfig{ Name: muxTransportName, @@ -54,7 +136,10 @@ func benchmarkStreamSendRecvWithMuxProxy(b *testing.B, payloadSize int) { }, false), ) - echoClientConfig := createEchoClientConfig( + echoClientConfig := createEchoClientConfigWithPorts( + echoClientAddress, + clientProxyInboundAddress, + clientProxyOutboundAddress, withMux( config.MuxTransportConfig{ Name: muxTransportName, diff --git a/proxy/test/echo_proxy_test.go b/proxy/test/echo_proxy_test.go index c5ec33e8..0788feaa 100644 --- a/proxy/test/echo_proxy_test.go +++ b/proxy/test/echo_proxy_test.go @@ -25,16 +25,6 @@ func init() { mux.MuxManagerStartDelay = 0 } -const ( - echoServerAddress = "localhost:7266" - serverProxyInboundAddress = "localhost:7366" - serverProxyOutboundAddress = "localhost:7466" - echoClientAddress = "localhost:8266" - clientProxyInboundAddress = "localhost:8366" - clientProxyOutboundAddress = "localhost:8466" - invalidAddress = "" -) - var ( serverClusterShard = history.ClusterShardID{ ClusterID: 1, @@ -56,8 +46,14 @@ var ( type ( proxyTestSuite struct { suite.Suite - originalPath string - developPath string + originalPath string + developPath string + echoServerAddress string + serverProxyInboundAddress string + serverProxyOutboundAddress string + echoClientAddress string + clientProxyInboundAddress string + clientProxyOutboundAddress string } cfgOption func(c *config.S2SProxyConfig) @@ -156,18 +152,18 @@ func createS2SProxyConfig(cfg *config.S2SProxyConfig, opts []cfgOption) *config. 
return cfg } -func createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { +func (s *proxyTestSuite) createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { return createS2SProxyConfig(&config.S2SProxyConfig{ Inbound: &config.ProxyConfig{ Name: "proxy1-inbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: serverProxyInboundAddress, + ListenAddress: s.serverProxyInboundAddress, }, }, Client: config.ProxyClientConfig{ TCPClientSetting: config.TCPClientSetting{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, }, }, }, @@ -175,7 +171,7 @@ func createEchoServerConfig(opts ...cfgOption) *config.S2SProxyConfig { Name: "proxy1-outbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: serverProxyOutboundAddress, + ListenAddress: s.serverProxyOutboundAddress, }, }, Client: config.ProxyClientConfig{ @@ -210,18 +206,18 @@ func EchoClientTLSOptions() []cfgOption { } } -func createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { +func (s *proxyTestSuite) createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { return createS2SProxyConfig(&config.S2SProxyConfig{ Inbound: &config.ProxyConfig{ Name: "proxy2-inbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: clientProxyInboundAddress, + ListenAddress: s.clientProxyInboundAddress, }, }, Client: config.ProxyClientConfig{ TCPClientSetting: config.TCPClientSetting{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, }, }, }, @@ -229,7 +225,7 @@ func createEchoClientConfig(opts ...cfgOption) *config.S2SProxyConfig { Name: "proxy2-outbound-server", Server: config.ProxyServerConfig{ TCPServerSetting: config.TCPServerSetting{ - ListenAddress: clientProxyOutboundAddress, + ListenAddress: s.clientProxyOutboundAddress, }, }, Client: config.ProxyClientConfig{ @@ -252,6 +248,14 @@ func (s *proxyTestSuite) SetupTest() { s.developPath = filepath.Join("..", "..", "develop") err = os.Chdir(s.developPath) s.NoError(err) + + // Allocate free ports for each test + s.echoServerAddress = GetLocalhostAddress() + s.serverProxyInboundAddress = GetLocalhostAddress() + s.serverProxyOutboundAddress = GetLocalhostAddress() + s.echoClientAddress = GetLocalhostAddress() + s.clientProxyInboundAddress = GetLocalhostAddress() + s.clientProxyOutboundAddress = GetLocalhostAddress() } func (s *proxyTestSuite) TearDownTest() { @@ -300,11 +304,11 @@ func (s *proxyTestSuite) Test_Echo_Success() { // echo_server <- - -> echo_client name: "no-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, }, }, @@ -312,12 +316,12 @@ func (s *proxyTestSuite) Test_Echo_Success() { // echo_server <-> proxy.inbound <- - -> echo_client name: "server-side-only-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(), + S2sProxyConfig: s.createEchoServerConfig(), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, }, }, @@ -325,49 +329,49 @@ func (s 
*proxyTestSuite) Test_Echo_Success() { // echo_server <- - -> proxy.outbound <-> echo_client name: "client-side-only-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, { // echo_server <-> proxy.inbound <- - -> proxy.outbound <-> echo_client name: "server-and-client-side-proxy", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(), + S2sProxyConfig: s.createEchoServerConfig(), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, { // echo_server <-> proxy.inbound <- mTLS -> proxy.outbound <-> echo_client name: "server-and-client-side-proxy-mTLS", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(EchoServerTLSOptions()...), + S2sProxyConfig: s.createEchoServerConfig(EchoServerTLSOptions()...), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(EchoClientTLSOptions()...), + S2sProxyConfig: s.createEchoClientConfig(EchoClientTLSOptions()...), }, }, { name: "server-and-client-side-proxy-ACL", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(withACLPolicy( + S2sProxyConfig: s.createEchoServerConfig(withACLPolicy( &config.ACLPolicy{ AllowedMethods: config.AllowedMethods{ AdminService: []string{ @@ -383,9 +387,9 @@ func (s *proxyTestSuite) Test_Echo_Success() { )), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, }, } @@ -439,9 +443,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { { name: "server-and-client-side-proxy-namespacetrans", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig(withNamespaceTranslation( + S2sProxyConfig: s.createEchoServerConfig(withNamespaceTranslation( []config.NameMappingConfig{ { LocalName: "local", @@ -452,9 +456,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { )), }, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, serverNamespace: "local", clientNamespace: "remote", @@ -462,9 +466,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { { name: 
"server-and-client-side-proxy-namespacetrans-acl", echoServerInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, - S2sProxyConfig: createEchoServerConfig( + S2sProxyConfig: s.createEchoServerConfig( withNamespaceTranslation( []config.NameMappingConfig{ { @@ -489,9 +493,9 @@ func (s *proxyTestSuite) Test_Echo_WithNamespaceTranslation() { ), )}, echoClientInfo: endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), }, serverNamespace: "local", clientNamespace: "remote", @@ -533,13 +537,13 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { // // echoServer proxy1.inbound.Server(muxClient) <- proxy2.outbound.Client(muxServer) echoClient // echoServer proxy1.outbound.Client(muxClient) -> proxy2.inbound.Server(muxServer) echoClient - echoServerConfig := createEchoServerConfig( + echoServerConfig := s.createEchoServerConfig( withMux( config.MuxTransportConfig{ Name: muxTransportName, Mode: config.ClientMode, Client: config.TCPClientSetting{ - ServerAddress: clientProxyInboundAddress, + ServerAddress: s.clientProxyInboundAddress, }, }), withServerConfig( @@ -556,13 +560,13 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { }, false), ) - echoClientConfig := createEchoClientConfig( + echoClientConfig := s.createEchoClientConfig( withMux( config.MuxTransportConfig{ Name: muxTransportName, Mode: config.ServerMode, Server: config.TCPServerSetting{ - ListenAddress: clientProxyInboundAddress, + ListenAddress: s.clientProxyInboundAddress, }, }), withServerConfig( @@ -580,12 +584,12 @@ func (s *proxyTestSuite) Test_Echo_WithMuxTransport() { ) echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, S2sProxyConfig: echoServerConfig, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, S2sProxyConfig: echoClientConfig, } @@ -614,14 +618,14 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { logger := log.NewTestLogger() echoServerInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoServerAddress, + ServerAddress: s.echoServerAddress, ClusterShardID: serverClusterShard, } echoClientInfo := endtoendtest.ClusterInfo{ - ServerAddress: echoClientAddress, + ServerAddress: s.echoClientAddress, ClusterShardID: clientClusterShard, - S2sProxyConfig: createEchoClientConfig(), + S2sProxyConfig: s.createEchoClientConfig(), } echoServer := endtoendtest.NewEchoServer(echoServerInfo, echoClientInfo, "EchoServer", logger, nil) @@ -629,6 +633,10 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { echoServer.Start() echoClient.Start() + defer func() { + echoClient.Stop() + echoServer.Stop() + }() stream, err := echoClient.CreateStreamClient() s.NoError(err) @@ -649,5 +657,4 @@ func (s *proxyTestSuite) Test_ForceStopSourceServer() { s.ErrorContains(err, "EOF") _ = stream.CloseSend() - echoClient.Stop() } diff --git a/proxy/test/intra_proxy_routing_test.go b/proxy/test/intra_proxy_routing_test.go new file mode 100644 index 00000000..538c5284 --- /dev/null +++ b/proxy/test/intra_proxy_routing_test.go @@ -0,0 +1,262 @@ +package proxy + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + 
"github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/tests/testcore" + + "github.com/temporalio/s2s-proxy/config" + s2sproxy "github.com/temporalio/s2s-proxy/proxy" +) + +type ( + IntraProxyRoutingTestSuite struct { + suite.Suite + *require.Assertions + + logger log.Logger + + clusterA *testcore.TestCluster + clusterB *testcore.TestCluster + + proxyA1 *s2sproxy.Proxy + proxyA2 *s2sproxy.Proxy + proxyB1 *s2sproxy.Proxy + proxyB2 *s2sproxy.Proxy + + proxyA1Outbound string + proxyA2Outbound string + proxyB1Outbound string + proxyB2Outbound string + + proxyB1Mux string + proxyB2Mux string + + proxyA1MemberlistPort int + proxyA2MemberlistPort int + proxyB1MemberlistPort int + proxyB2MemberlistPort int + + loadBalancerA *trackingTCPProxy + loadBalancerB *trackingTCPProxy + loadBalancerC *trackingTCPProxy + + loadBalancerAPort string + loadBalancerBPort string + loadBalancerCPort string + + connectionCountsA1 atomic.Int64 + connectionCountsA2 atomic.Int64 + connectionCountsB1 atomic.Int64 + connectionCountsB2 atomic.Int64 + connectionCountsPA1 atomic.Int64 + connectionCountsPA2 atomic.Int64 + } +) + +func TestIntraProxyRoutingTestSuite(t *testing.T) { + s := &IntraProxyRoutingTestSuite{} + suite.Run(t, s) +} + +func (s *IntraProxyRoutingTestSuite) SetupSuite() { + s.Assertions = require.New(s.T()) + s.logger = log.NewTestLogger() + + s.logger.Info("Setting up intra-proxy routing test suite") + + s.clusterA = createCluster(s.logger, s.T(), "cluster-a", 2, 1, 1) + s.clusterB = createCluster(s.logger, s.T(), "cluster-b", 2, 2, 1) + + s.proxyA1Outbound = GetLocalhostAddress() + s.proxyA2Outbound = GetLocalhostAddress() + s.proxyB1Outbound = GetLocalhostAddress() + s.proxyB2Outbound = GetLocalhostAddress() + + s.proxyB1Mux = GetLocalhostAddress() + s.proxyB2Mux = GetLocalhostAddress() + + loadBalancerAPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", GetFreePort()) + + s.loadBalancerAPort = loadBalancerAPort + s.loadBalancerBPort = loadBalancerBPort + s.loadBalancerCPort = loadBalancerCPort + + proxyA1Address := GetLocalhostAddress() + proxyA2Address := GetLocalhostAddress() + proxyB1Address := GetLocalhostAddress() + proxyB2Address := GetLocalhostAddress() + + proxyAddressesA := map[string]string{ + "proxy-node-a-1": proxyA1Address, + "proxy-node-a-2": proxyA2Address, + } + proxyAddressesB := map[string]string{ + "proxy-node-b-1": proxyB1Address, + "proxy-node-b-2": proxyB2Address, + } + + s.proxyA1MemberlistPort = GetFreePort() + s.proxyA2MemberlistPort = GetFreePort() + s.proxyB1MemberlistPort = GetFreePort() + s.proxyB2MemberlistPort = GetFreePort() + + s.proxyB1 = createProxy(s.logger, s.T(), "proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) + s.proxyB2 = createProxy(s.logger, s.T(), "proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, config.ShardCountConfig{}, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) + + s.logger.Info("Setting up load balancers") + + var err error + s.loadBalancerA, err = createLoadBalancer(s.logger, loadBalancerAPort, 
[]string{s.proxyA1Outbound, s.proxyA2Outbound}, &s.connectionCountsA1, &s.connectionCountsA2) + s.NoError(err, "Failed to start load balancer A") + s.loadBalancerB, err = createLoadBalancer(s.logger, loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &s.connectionCountsPA1, &s.connectionCountsPA2) + s.NoError(err, "Failed to start load balancer B") + s.loadBalancerC, err = createLoadBalancer(s.logger, loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &s.connectionCountsB1, &s.connectionCountsB2) + s.NoError(err, "Failed to start load balancer C") + + muxLoadBalancerBAddress := fmt.Sprintf("localhost:%s", loadBalancerBPort) + s.proxyA1 = createProxy(s.logger, s.T(), "proxy-a-1", proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) + s.proxyA2 = createProxy(s.logger, s.T(), "proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) + + s.logger.Info("Waiting for proxies to start and connect") + time.Sleep(15 * time.Second) + + s.logger.Info("Configuring remote clusters") + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerAPort)) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", loadBalancerCPort)) + waitForReplicationReady(s.logger, s.T(), s.clusterA, s.clusterB) +} + +func (s *IntraProxyRoutingTestSuite) TearDownSuite() { + s.logger.Info("Tearing down intra-proxy routing test suite") + if s.clusterA != nil && s.clusterB != nil { + s.logger.Info("Removing remote cluster A from cluster B") + removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) + s.logger.Info("Remote cluster A removed") + s.logger.Info("Removing remote cluster B from cluster A") + removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) + s.logger.Info("Remote cluster B removed") + } + if s.clusterA != nil { + s.NoError(s.clusterA.TearDownCluster()) + s.logger.Info("Cluster A torn down") + } + if s.clusterB != nil { + s.NoError(s.clusterB.TearDownCluster()) + s.logger.Info("Cluster B torn down") + } + if s.loadBalancerA != nil { + s.logger.Info("Stopping load balancer A") + s.loadBalancerA.Stop() + s.logger.Info("Load balancer A stopped") + } + if s.loadBalancerB != nil { + s.logger.Info("Stopping load balancer B") + s.loadBalancerB.Stop() + s.logger.Info("Load balancer B stopped") + } + if s.loadBalancerC != nil { + s.logger.Info("Stopping load balancer C") + s.loadBalancerC.Stop() + s.logger.Info("Load balancer C stopped") + } + if s.proxyA1 != nil { + s.logger.Info("Stopping proxy A1") + s.proxyA1.Stop() + s.logger.Info("Proxy A1 stopped") + } + if s.proxyA2 != nil { + s.logger.Info("Stopping proxy A2") + s.proxyA2.Stop() + s.logger.Info("Proxy A2 stopped") + } + if s.proxyB1 != nil { + s.logger.Info("Stopping proxy B1") + s.proxyB1.Stop() + s.logger.Info("Proxy B1 stopped") + } + if s.proxyB2 != nil { + s.logger.Info("Stopping proxy B2") + s.proxyB2.Stop() + s.logger.Info("Proxy B2 stopped") + } + s.logger.Info("Intra-proxy routing test suite torn down") +} + +func (s *IntraProxyRoutingTestSuite) TestIntraProxyRoutingDistribution() { + s.logger.Info("Testing intra-proxy routing 
distribution") + + ctx := context.Background() + + s.logger.Info("Triggering replication connections to verify distribution") + + var wg sync.WaitGroup + numConnections := 20 + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.clusterA.HistoryClient().GetReplicationStatus( + ctx, + &historyservice.GetReplicationStatusRequest{}, + ) + if err != nil { + s.logger.Warn("GetReplicationStatus failed", tag.Error(err)) + } + }() + } + + wg.Wait() + + time.Sleep(2 * time.Second) + + countA1 := s.connectionCountsA1.Load() + countA2 := s.connectionCountsA2.Load() + countB1 := s.connectionCountsB1.Load() + countB2 := s.connectionCountsB2.Load() + countPA1 := s.connectionCountsPA1.Load() + countPA2 := s.connectionCountsPA2.Load() + + s.logger.Info("Connection distribution results", + tag.NewInt64("loadBalancerA_pa1", countA1), + tag.NewInt64("loadBalancerA_pa2", countA2), + tag.NewInt64("loadBalancerB_pb1_from_pa", countPA1), + tag.NewInt64("loadBalancerB_pb2_from_pa", countPA2), + tag.NewInt64("loadBalancerC_pb1", countB1), + tag.NewInt64("loadBalancerC_pb2", countB2), + ) + + s.Greater(countA1, int64(0), "Load balancer A should route to pa1") + s.Greater(countA2, int64(0), "Load balancer A should route to pa2") + s.Greater(countB1, int64(0), "Load balancer C should route to pb1") + s.Greater(countB2, int64(0), "Load balancer C should route to pb2") + s.Greater(countPA1, int64(0), "Load balancer B should route to pb1 from pa") + s.Greater(countPA2, int64(0), "Load balancer B should route to pb2 from pa") + + totalA := countA1 + countA2 + totalB := countB1 + countB2 + totalPA := countPA1 + countPA2 + + s.logger.Info("Total connections", + tag.NewInt64("totalA", totalA), + tag.NewInt64("totalB", totalB), + tag.NewInt64("totalPA", totalPA), + ) + + s.Greater(totalA, int64(0), "Should have connections through load balancer A") + s.Greater(totalB, int64(0), "Should have connections through load balancer C") + s.Greater(totalPA, int64(0), "Should have connections through load balancer B") +} diff --git a/proxy/test/replication_failover_test.go b/proxy/test/replication_failover_test.go index c2918520..f8229b71 100644 --- a/proxy/test/replication_failover_test.go +++ b/proxy/test/replication_failover_test.go @@ -3,8 +3,8 @@ package proxy import ( "context" "fmt" - "math/rand" "sync" + "sync/atomic" "testing" "time" @@ -15,11 +15,8 @@ import ( replicationpb "go.temporal.io/api/replication/v1" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" - "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/common" - "go.temporal.io/server/common/cluster" - "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/namespace" @@ -30,6 +27,13 @@ import ( s2sproxy "github.com/temporalio/s2s-proxy/proxy" ) +type SetupMode string + +const ( + SetupModeSimple SetupMode = "simple" // Case B: Two proxies, direct connection, no load balancer, no memberlist + SetupModeMultiProxy SetupMode = "multiproxy" // Case A: Multi-proxy with load balancers and memberlist +) + type ( // ReplicationTestSuite tests s2s-proxy replication and failover across multiple shard configurations ReplicationTestSuite struct { @@ -41,18 +45,48 @@ type ( clusterA *testcore.TestCluster clusterB *testcore.TestCluster + // Case B: Simple setup proxyA *s2sproxy.Proxy proxyB *s2sproxy.Proxy - proxyAAddress string - proxyBAddress 
string + proxyAOutbound string + proxyBOutbound string + + // Case A: Multi-proxy setup + proxyA1 *s2sproxy.Proxy + proxyA2 *s2sproxy.Proxy + proxyB1 *s2sproxy.Proxy + proxyB2 *s2sproxy.Proxy + + proxyA1Outbound string + proxyA2Outbound string + proxyB1Outbound string + proxyB2Outbound string + + proxyB1Mux string + proxyB2Mux string + + proxyA1MemberlistPort int + proxyA2MemberlistPort int + proxyB1MemberlistPort int + proxyB2MemberlistPort int + + loadBalancerA *trackingTCPProxy + loadBalancerB *trackingTCPProxy + loadBalancerC *trackingTCPProxy - shardCountA int - shardCountB int - shardCountConfig config.ShardCountConfig - namespace string - namespaceID string - startTime time.Time + loadBalancerAPort string + loadBalancerBPort string + loadBalancerCPort string + + setupMode SetupMode + + shardCountA int + shardCountB int + shardCountConfigB config.ShardCountConfig + namespace string + namespaceID string + startTime time.Time workflows []*WorkflowDistribution @@ -69,47 +103,113 @@ type ( } TestConfig struct { - Name string - ShardCountA int - ShardCountB int - WorkflowsPerPair int - ShardCountConfig config.ShardCountConfig + Name string + ShardCountA int + ShardCountB int + WorkflowsPerPair int + ShardCountConfigB config.ShardCountConfig + SetupMode SetupMode } ) var testConfigs = []TestConfig{ + // Case B: Simple setup tests + { + Name: "Simple_SingleShard", + ShardCountA: 1, + ShardCountB: 1, + WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, + }, + { + Name: "Simple_FourShards", + ShardCountA: 4, + ShardCountB: 4, + WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, + }, + { + Name: "Simple_AsymmetricShards_4to2", + ShardCountA: 4, + ShardCountB: 2, + WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, + }, + { + Name: "Simple_AsymmetricShards_2to4", + ShardCountA: 2, + ShardCountB: 4, + WorkflowsPerPair: 1, + SetupMode: SetupModeSimple, + }, { - Name: "SingleShard", + Name: "Simple_ArbitraryShards_2to3_LCM", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountLCM, + }, + SetupMode: SetupModeSimple, + }, + { + Name: "Simple_ArbitraryShards_2to3_Routing", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountRouting, + }, + SetupMode: SetupModeSimple, + }, + // Case A: Multi-proxy setup tests + { + Name: "MultiProxy_SingleShard", ShardCountA: 1, ShardCountB: 1, WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, }, { - Name: "FourShards", + Name: "MultiProxy_FourShards", ShardCountA: 4, ShardCountB: 4, WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, }, { - Name: "AsymmetricShards_4to2", + Name: "MultiProxy_AsymmetricShards_4to2", ShardCountA: 4, ShardCountB: 2, WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, }, { - Name: "AsymmetricShards_2to4", + Name: "MultiProxy_AsymmetricShards_2to4", ShardCountA: 2, ShardCountB: 4, WorkflowsPerPair: 1, + SetupMode: SetupModeMultiProxy, }, { - Name: "ArbitraryShards_2to3_LCM", + Name: "MultiProxy_ArbitraryShards_2to3_LCM", ShardCountA: 2, ShardCountB: 3, WorkflowsPerPair: 1, - ShardCountConfig: config.ShardCountConfig{ + ShardCountConfigB: config.ShardCountConfig{ Mode: config.ShardCountLCM, }, + SetupMode: SetupModeMultiProxy, + }, + { + Name: "MultiProxy_ArbitraryShards_2to3_Routing", + ShardCountA: 2, + ShardCountB: 3, + WorkflowsPerPair: 1, + ShardCountConfigB: config.ShardCountConfig{ + Mode: config.ShardCountRouting, + }, + SetupMode: SetupModeMultiProxy, }, } @@ -117,10 
+217,11 @@ func TestReplicationFailoverTestSuite(t *testing.T) { for _, tc := range testConfigs { t.Run(tc.Name, func(t *testing.T) { s := &ReplicationTestSuite{ - shardCountA: tc.ShardCountA, - shardCountB: tc.ShardCountB, - shardCountConfig: tc.ShardCountConfig, - workflowsPerPair: tc.WorkflowsPerPair, + shardCountA: tc.ShardCountA, + shardCountB: tc.ShardCountB, + shardCountConfigB: tc.ShardCountConfigB, + workflowsPerPair: tc.WorkflowsPerPair, + setupMode: tc.SetupMode, } suite.Run(t, s) }) @@ -135,264 +236,173 @@ func (s *ReplicationTestSuite) SetupSuite() { s.logger.Info("Setting up replication test suite", tag.NewInt("shardCountA", s.shardCountA), tag.NewInt("shardCountB", s.shardCountB), + tag.NewStringTag("setupMode", string(s.setupMode)), ) - s.clusterA = s.createCluster("cluster-a", s.shardCountA, 1) - s.clusterB = s.createCluster("cluster-b", s.shardCountB, 2) + s.clusterA = createCluster(s.logger, s.T(), "cluster-a", s.shardCountA, 1, 1) + s.clusterB = createCluster(s.logger, s.T(), "cluster-b", s.shardCountB, 2, 1) - basePort := 17000 + rand.Intn(10000) - s.proxyAAddress = fmt.Sprintf("localhost:%d", basePort) - proxyAOutbound := fmt.Sprintf("localhost:%d", basePort+1) - s.proxyBAddress = fmt.Sprintf("localhost:%d", basePort+100) - proxyBOutbound := fmt.Sprintf("localhost:%d", basePort+101) - muxServerAddress := fmt.Sprintf("localhost:%d", basePort+200) - - proxyBShardConfig := s.shardCountConfig - if proxyBShardConfig.Mode == config.ShardCountLCM { - proxyBShardConfig.LocalShardCount = int32(s.shardCountB) - proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) + if s.setupMode == SetupModeSimple { + s.setupSimple() + } else { + s.setupMultiProxy() } - s.proxyA = s.createProxy("proxy-a", s.proxyAAddress, proxyAOutbound, muxServerAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}) - s.proxyB = s.createProxy("proxy-b", s.proxyBAddress, proxyBOutbound, muxServerAddress, s.clusterB, config.ServerMode, proxyBShardConfig) + s.logger.Info("Waiting for proxies to start and connect") + time.Sleep(10 * time.Second) - s.configureRemoteCluster(s.clusterA, s.clusterB.ClusterName(), proxyAOutbound) - s.configureRemoteCluster(s.clusterB, s.clusterA.ClusterName(), proxyBOutbound) - s.waitForReplicationReady() + s.logger.Info("Configuring remote clusters") + if s.setupMode == SetupModeSimple { + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), s.proxyAOutbound) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), s.proxyBOutbound) + } else { + configureRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName(), fmt.Sprintf("localhost:%s", s.loadBalancerAPort)) + configureRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName(), fmt.Sprintf("localhost:%s", s.loadBalancerCPort)) + } + + waitForReplicationReady(s.logger, s.T(), s.clusterA, s.clusterB) s.namespace = s.createGlobalNamespace() s.waitForClusterSynced() } -func (s *ReplicationTestSuite) TearDownSuite() { - if s.namespace != "" && s.clusterA != nil { - s.deglobalizeNamespace(s.namespace) - } +func (s *ReplicationTestSuite) setupSimple() { + s.logger.Info("Setting up simple two-proxy configuration") - if s.clusterA != nil && s.clusterB != nil { - s.removeRemoteCluster(s.clusterA, s.clusterB.ClusterName()) - s.removeRemoteCluster(s.clusterB, s.clusterA.ClusterName()) - } - if s.clusterA != nil { - s.NoError(s.clusterA.TearDownCluster()) - } - if s.clusterB != nil { - s.NoError(s.clusterB.TearDownCluster()) - } - if s.proxyA != nil { - 
s.proxyA.Stop() - } - if s.proxyB != nil { - s.proxyB.Stop() + proxyAOutbound := GetLocalhostAddress() + proxyBOutbound := GetLocalhostAddress() + muxServerAddress := GetLocalhostAddress() + + s.proxyAOutbound = proxyAOutbound + s.proxyBOutbound = proxyBOutbound + + proxyBShardConfig := s.shardCountConfigB + if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { + proxyBShardConfig.LocalShardCount = int32(s.shardCountB) + proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) } + s.proxyA = createProxy(s.logger, s.T(), "proxy-a", "", proxyAOutbound, muxServerAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "", "", 0, nil, nil) + s.proxyB = createProxy(s.logger, s.T(), "proxy-b", "", proxyBOutbound, muxServerAddress, s.clusterB, config.ServerMode, proxyBShardConfig, "", "", 0, nil, nil) } -func (s *ReplicationTestSuite) SetupTest() { - s.workflows = nil +func (s *ReplicationTestSuite) setupMultiProxy() { + s.logger.Info("Setting up multi-proxy configuration with load balancers") - if s.namespace != "" { - s.ensureNamespaceActive(s.clusterA.ClusterName()) - } -} + s.proxyA1Outbound = GetLocalhostAddress() + s.proxyA2Outbound = GetLocalhostAddress() + s.proxyB1Outbound = GetLocalhostAddress() + s.proxyB2Outbound = GetLocalhostAddress() -func (s *ReplicationTestSuite) createCluster( - clusterName string, - numShards int, - initialFailoverVersion int64, -) *testcore.TestCluster { - clusterSuffix := common.GenerateRandomString(8) - fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) - - clusterConfig := &testcore.TestClusterConfig{ - ClusterMetadata: cluster.Config{ - EnableGlobalNamespace: true, - FailoverVersionIncrement: 10, - MasterClusterName: fullClusterName, - CurrentClusterName: fullClusterName, - ClusterInformation: map[string]cluster.ClusterInformation{ - fullClusterName: { - Enabled: true, - InitialFailoverVersion: initialFailoverVersion, - }, - }, - }, - HistoryConfig: testcore.HistoryConfig{ - NumHistoryShards: int32(numShards), - NumHistoryHosts: 1, - }, - DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ - dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, - dynamicconfig.EnableReplicationStream.Key(): true, - dynamicconfig.EnableReplicationTaskBatching.Key(): true, - }, - } + s.proxyB1Mux = GetLocalhostAddress() + s.proxyB2Mux = GetLocalhostAddress() - testClusterFactory := testcore.NewTestClusterFactory() - cluster, err := testClusterFactory.NewCluster(s.T(), clusterConfig, s.logger) - s.NoError(err, "Failed to create cluster %s", clusterName) - s.NotNil(cluster) + loadBalancerAPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerBPort := fmt.Sprintf("%d", GetFreePort()) + loadBalancerCPort := fmt.Sprintf("%d", GetFreePort()) - return cluster -} + s.loadBalancerAPort = loadBalancerAPort + s.loadBalancerBPort = loadBalancerBPort + s.loadBalancerCPort = loadBalancerCPort -func (s *ReplicationTestSuite) createProxy( - name string, - inboundAddress string, - outboundAddress string, - muxAddress string, - cluster *testcore.TestCluster, - muxMode config.MuxMode, - shardCountConfig config.ShardCountConfig, -) *s2sproxy.Proxy { - var muxConnectionType config.ConnectionType - var muxAddressInfo config.TCPTLSInfo - if muxMode == config.ServerMode { - muxConnectionType = config.ConnTypeMuxServer - muxAddressInfo = config.TCPTLSInfo{ - ConnectionString: muxAddress, - } - } else { - muxConnectionType = config.ConnTypeMuxClient - muxAddressInfo = config.TCPTLSInfo{ - 
ConnectionString: muxAddress, - } - } + proxyA1Address := GetLocalhostAddress() + proxyA2Address := GetLocalhostAddress() + proxyB1Address := GetLocalhostAddress() + proxyB2Address := GetLocalhostAddress() - cfg := &config.S2SProxyConfig{ - ClusterConnections: []config.ClusterConnConfig{ - { - Name: name, - LocalServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: config.ConnTypeTCP, - TcpClient: config.TCPTLSInfo{ - ConnectionString: cluster.Host().FrontendGRPCAddress(), - }, - TcpServer: config.TCPTLSInfo{ - ConnectionString: outboundAddress, - }, - }, - }, - RemoteServer: config.ClusterDefinition{ - Connection: config.TransportInfo{ - ConnectionType: muxConnectionType, - MuxCount: 1, - MuxAddressInfo: muxAddressInfo, - }, - }, - ShardCountConfig: shardCountConfig, - }, - }, + // For intra-proxy communication, use outbound addresses where proxies listen + proxyAddressesA := map[string]string{ + "proxy-node-a-1": s.proxyA1Outbound, + "proxy-node-a-2": s.proxyA2Outbound, + } + proxyAddressesB := map[string]string{ + "proxy-node-b-1": s.proxyB1Outbound, + "proxy-node-b-2": s.proxyB2Outbound, } - configProvider := &simpleConfigProvider{cfg: *cfg} - proxy := s2sproxy.NewProxy(configProvider, s.logger) - s.NotNil(proxy) + s.proxyA1MemberlistPort = GetFreePort() + s.proxyA2MemberlistPort = GetFreePort() + s.proxyB1MemberlistPort = GetFreePort() + s.proxyB2MemberlistPort = GetFreePort() - err := proxy.Start() - s.NoError(err, "Failed to start proxy %s", name) + proxyBShardConfig := s.shardCountConfigB + if proxyBShardConfig.Mode == config.ShardCountLCM || proxyBShardConfig.Mode == config.ShardCountRouting { + proxyBShardConfig.LocalShardCount = int32(s.shardCountB) + proxyBShardConfig.RemoteShardCount = int32(s.shardCountA) + } - s.logger.Info("Started proxy", tag.NewStringTag("name", name), - tag.NewStringTag("inboundAddress", inboundAddress), - tag.NewStringTag("outboundAddress", outboundAddress), - tag.NewStringTag("muxAddress", muxAddress), - tag.NewStringTag("muxMode", string(muxMode)), - ) + s.proxyB1 = createProxy(s.logger, s.T(), "proxy-b-1", proxyB1Address, s.proxyB1Outbound, s.proxyB1Mux, s.clusterB, config.ServerMode, proxyBShardConfig, "proxy-node-b-1", "127.0.0.1", s.proxyB1MemberlistPort, nil, proxyAddressesB) + s.proxyB2 = createProxy(s.logger, s.T(), "proxy-b-2", proxyB2Address, s.proxyB2Outbound, s.proxyB2Mux, s.clusterB, config.ServerMode, proxyBShardConfig, "proxy-node-b-2", "127.0.0.1", s.proxyB2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyB1MemberlistPort)}, proxyAddressesB) - return proxy -} + var countA1, countA2, countB1, countB2, countPA1, countPA2 atomic.Int64 -type simpleConfigProvider struct { - cfg config.S2SProxyConfig -} + var err error + s.loadBalancerA, err = createLoadBalancer(s.logger, loadBalancerAPort, []string{s.proxyA1Outbound, s.proxyA2Outbound}, &countA1, &countA2) + s.NoError(err, "Failed to start load balancer A") + s.loadBalancerB, err = createLoadBalancer(s.logger, loadBalancerBPort, []string{s.proxyB1Mux, s.proxyB2Mux}, &countPA1, &countPA2) + s.NoError(err, "Failed to start load balancer B") + s.loadBalancerC, err = createLoadBalancer(s.logger, loadBalancerCPort, []string{s.proxyB1Outbound, s.proxyB2Outbound}, &countB1, &countB2) + s.NoError(err, "Failed to start load balancer C") -func (p *simpleConfigProvider) GetS2SProxyConfig() config.S2SProxyConfig { - return p.cfg + muxLoadBalancerBAddress := fmt.Sprintf("localhost:%s", loadBalancerBPort) + s.proxyA1 = createProxy(s.logger, s.T(), "proxy-a-1", 
proxyA1Address, s.proxyA1Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-1", "127.0.0.1", s.proxyA1MemberlistPort, nil, proxyAddressesA) + s.proxyA2 = createProxy(s.logger, s.T(), "proxy-a-2", proxyA2Address, s.proxyA2Outbound, muxLoadBalancerBAddress, s.clusterA, config.ClientMode, config.ShardCountConfig{}, "proxy-node-a-2", "127.0.0.1", s.proxyA2MemberlistPort, []string{fmt.Sprintf("127.0.0.1:%d", s.proxyA1MemberlistPort)}, proxyAddressesA) } -func (s *ReplicationTestSuite) configureRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, - proxyAddress string, -) { - _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( - context.Background(), - &adminservice.AddOrUpdateRemoteClusterRequest{ - FrontendAddress: proxyAddress, - EnableRemoteClusterConnection: true, - }, - ) - s.NoError(err, "Failed to configure remote cluster %s", remoteClusterName) - s.logger.Info("Configured remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("proxyAddress", proxyAddress), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) -} - -func (s *ReplicationTestSuite) deglobalizeNamespace(namespaceName string) { - if s.clusterA == nil { - return +func (s *ReplicationTestSuite) TearDownSuite() { + if s.namespace != "" && s.clusterA != nil { + s.deglobalizeNamespace(s.namespace) } - ctx := context.Background() - updateReq := &workflowservice.UpdateNamespaceRequest{ - Namespace: namespaceName, - ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ - ActiveClusterName: s.clusterA.ClusterName(), - Clusters: []*replicationpb.ClusterReplicationConfig{ - {ClusterName: s.clusterA.ClusterName()}, - }, - }, + if s.clusterA != nil && s.clusterB != nil { + removeRemoteCluster(s.logger, s.T(), s.clusterA, s.clusterB.ClusterName()) + removeRemoteCluster(s.logger, s.T(), s.clusterB, s.clusterA.ClusterName()) } - - _, err := s.clusterA.FrontendClient().UpdateNamespace(ctx, updateReq) - if err != nil { - s.logger.Warn("Failed to deglobalize namespace", tag.NewStringTag("namespace", namespaceName), tag.Error(err)) - return + if s.clusterA != nil { + s.NoError(s.clusterA.TearDownCluster()) + } + if s.clusterB != nil { + s.NoError(s.clusterB.TearDownCluster()) } - s.Eventually(func() bool { - for _, c := range []*testcore.TestCluster{s.clusterA, s.clusterB} { - if c == nil { - continue - } - descResp, err := c.FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ - Namespace: namespaceName, - }) - if err != nil || descResp == nil { - return false - } - clusters := descResp.ReplicationConfig.GetClusters() - if len(clusters) != 1 { - return false - } - if clusters[0].GetClusterName() != s.clusterA.ClusterName() { - return false - } + if s.setupMode == SetupModeSimple { + if s.proxyA != nil { + s.proxyA.Stop() } - return true - }, 10*time.Second, 200*time.Millisecond, "Namespace deglobalization not propagated") - - s.logger.Info("Deglobalized namespace", tag.NewStringTag("namespace", namespaceName)) + if s.proxyB != nil { + s.proxyB.Stop() + } + } else { + if s.loadBalancerA != nil { + s.loadBalancerA.Stop() + } + if s.loadBalancerB != nil { + s.loadBalancerB.Stop() + } + if s.loadBalancerC != nil { + s.loadBalancerC.Stop() + } + if s.proxyA1 != nil { + s.proxyA1.Stop() + } + if s.proxyA2 != nil { + s.proxyA2.Stop() + } + if s.proxyB1 != nil { + s.proxyB1.Stop() + } + if s.proxyB2 != nil { + s.proxyB2.Stop() + } + } } -func (s *ReplicationTestSuite) 
removeRemoteCluster( - cluster *testcore.TestCluster, - remoteClusterName string, -) { - _, err := cluster.AdminClient().RemoveRemoteCluster( - context.Background(), - &adminservice.RemoveRemoteClusterRequest{ - ClusterName: remoteClusterName, - }, - ) - s.NoError(err, "Failed to remove remote cluster %s", remoteClusterName) - s.logger.Info("Removed remote cluster", - tag.NewStringTag("remoteClusterName", remoteClusterName), - tag.NewStringTag("clusterName", cluster.ClusterName()), - ) +func (s *ReplicationTestSuite) SetupTest() { + s.workflows = nil + + if s.namespace != "" { + s.ensureNamespaceActive(s.clusterA.ClusterName()) + } } func (s *ReplicationTestSuite) createGlobalNamespace() string { @@ -503,32 +513,30 @@ func (s *ReplicationTestSuite) generateWorkflowsWithLoad(workflowsPerPair int) [ // Vice versa: if sourceShardCount=2 and targetShardCount=4: // - sourceShard 1 (0-based: 0) can only map to targetShard 1 or 3 (0-based: 0 or 2) // - sourceShard 2 (0-based: 1) can only map to targetShard 2 or 4 (0-based: 1 or 3) +// +// When using routing mode, all pairs are valid because intra-proxy routing can handle arbitrary mappings. func (s *ReplicationTestSuite) isValidShardPair(sourceShard int32, targetShard int32) bool { - // If shard counts are equal, source and target shards must match - // (same hash function with same shard count produces identical shard assignment) + // In routing mode, all pairs are valid because intra-proxy routing can handle arbitrary mappings + if s.shardCountConfigB.Mode == config.ShardCountRouting { + return true + } + if s.shardCountA == s.shardCountB { return sourceShard == targetShard } - // Convert to 0-based for modulo arithmetic sourceShard0 := sourceShard - 1 targetShard0 := targetShard - 1 - // Case 1: targetShardCount divides sourceShardCount (e.g., 4 -> 2) - // Source shard x maps to target shard (x % targetShardCount) if s.shardCountA%s.shardCountB == 0 { expectedTarget := sourceShard0 % int32(s.shardCountB) return targetShard0 == expectedTarget } - // Case 2: sourceShardCount divides targetShardCount (e.g., 2 -> 4) - // Source shard x can map to target shards in set {x, x+sourceShardCount, x+2*sourceShardCount, ...} - // where all values are < targetShardCount if s.shardCountB%s.shardCountA == 0 { return targetShard0%int32(s.shardCountA) == sourceShard0 } - // No divisibility relationship, all pairs are possible (though may be hard to find) return true } @@ -559,25 +567,6 @@ func (s *ReplicationTestSuite) findWorkflowIDForShardPairWithIndex(sourceShard i return "" } -func (s *ReplicationTestSuite) waitForReplicationReady() { - time.Sleep(1 * time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - for _, cluster := range []*testcore.TestCluster{s.clusterA, s.clusterB} { - s.Eventually(func() bool { - _, err := cluster.HistoryClient().GetReplicationStatus( - ctx, - &historyservice.GetReplicationStatusRequest{}, - ) - return err == nil - }, 5*time.Second, 200*time.Millisecond, "Replication infrastructure not ready") - } - - time.Sleep(1 * time.Second) -} - func (s *ReplicationTestSuite) waitForClusterSynced() { s.waitForClusterConnected(s.clusterA, s.clusterB.ClusterName()) s.waitForClusterConnected(s.clusterB, s.clusterA.ClusterName()) @@ -593,6 +582,11 @@ func (s *ReplicationTestSuite) waitForClusterConnected( ) s.Eventually(func() bool { + s.logger.Info("Checking replication status for clusters to sync", + tag.NewStringTag("source", sourceCluster.ClusterName()), + tag.NewStringTag("target", 
targetClusterName), + ) + resp, err := sourceCluster.HistoryClient().GetReplicationStatus( context.Background(), &historyservice.GetReplicationStatusRequest{}, @@ -601,27 +595,36 @@ func (s *ReplicationTestSuite) waitForClusterConnected( s.logger.Debug("GetReplicationStatus failed", tag.Error(err)) return false } + s.logger.Info("GetReplicationStatus succeeded", + tag.NewStringTag("source", sourceCluster.ClusterName()), + tag.NewStringTag("target", targetClusterName), + tag.NewStringTag("resp", fmt.Sprintf("%v", resp)), + ) if len(resp.Shards) == 0 { return false } for _, shard := range resp.Shards { + s.logger.Info("Checking shard", + tag.NewInt32("shardId", shard.ShardId), + tag.NewInt64("maxReplicationTaskId", shard.MaxReplicationTaskId), + tag.NewStringTag("shardLocalTime", fmt.Sprintf("%v", shard.ShardLocalTime.AsTime())), + tag.NewStringTag("remoteClusters", fmt.Sprintf("%v", shard.RemoteClusters)), + ) if shard.MaxReplicationTaskId <= 0 { continue } + s.NotNil(shard.ShardLocalTime) + s.WithinRange(shard.ShardLocalTime.AsTime(), s.startTime, time.Now()) + remoteInfo, ok := shard.RemoteClusters[targetClusterName] if !ok || remoteInfo == nil { return false } if remoteInfo.AckedTaskId < shard.MaxReplicationTaskId { - s.logger.Debug("Replication not synced", - tag.ShardID(shard.ShardId), - tag.NewInt64("maxTaskId", shard.MaxReplicationTaskId), - tag.NewInt64("ackedTaskId", remoteInfo.AckedTaskId), - ) return false } } @@ -673,8 +676,6 @@ func (s *ReplicationTestSuite) TestReplication() { } } - // TODO: make some progress on the workflows - s.waitForClusterSynced() clientB := s.clusterB.FrontendClient() @@ -685,7 +686,6 @@ func (s *ReplicationTestSuite) TestReplication() { s.failoverNamespace(ctx, s.namespace, s.clusterB.ClusterName()) for _, wf := range s.workflows { - // TODO: continue the workflows instead of just terminating them s.completeWorkflow(ctx, clientB, wf) } @@ -815,6 +815,53 @@ func (s *ReplicationTestSuite) failoverNamespace( s.logger.Info("Namespace failover completed", tag.NewStringTag("namespace", namespaceName), tag.NewStringTag("targetCluster", targetCluster)) } +func (s *ReplicationTestSuite) deglobalizeNamespace(namespaceName string) { + if s.clusterA == nil { + return + } + + ctx := context.Background() + updateReq := &workflowservice.UpdateNamespaceRequest{ + Namespace: namespaceName, + ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ + ActiveClusterName: s.clusterA.ClusterName(), + Clusters: []*replicationpb.ClusterReplicationConfig{ + {ClusterName: s.clusterA.ClusterName()}, + }, + }, + } + + _, err := s.clusterA.FrontendClient().UpdateNamespace(ctx, updateReq) + if err != nil { + s.logger.Warn("Failed to deglobalize namespace", tag.NewStringTag("namespace", namespaceName), tag.Error(err)) + return + } + + s.Eventually(func() bool { + for _, c := range []*testcore.TestCluster{s.clusterA, s.clusterB} { + if c == nil { + continue + } + descResp, err := c.FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ + Namespace: namespaceName, + }) + if err != nil || descResp == nil { + return false + } + clusters := descResp.ReplicationConfig.GetClusters() + if len(clusters) != 1 { + return false + } + if clusters[0].GetClusterName() != s.clusterA.ClusterName() { + return false + } + } + return true + }, 10*time.Second, 200*time.Millisecond, "Namespace deglobalization not propagated") + + s.logger.Info("Deglobalized namespace", tag.NewStringTag("namespace", namespaceName)) +} + func (s *ReplicationTestSuite) 
ensureNamespaceActive(targetCluster string) { descResp, err := s.clusterA.FrontendClient().DescribeNamespace(context.Background(), &workflowservice.DescribeNamespaceRequest{ Namespace: s.namespace, diff --git a/proxy/test/tcp_proxy.go b/proxy/test/tcp_proxy.go new file mode 100644 index 00000000..bbdd811c --- /dev/null +++ b/proxy/test/tcp_proxy.go @@ -0,0 +1,176 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" +) + +type ( + UpstreamServer struct { + Address string + conns atomic.Int64 + } + + Upstream struct { + Servers []*UpstreamServer + mu sync.RWMutex + } + + ProxyRule struct { + ListenPort string + Upstream *Upstream + } + + TCPProxy struct { + rules []*ProxyRule + logger log.Logger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + servers []net.Listener + } +) + +func NewUpstream(servers []string) *Upstream { + upstreamServers := make([]*UpstreamServer, len(servers)) + for i, addr := range servers { + upstreamServers[i] = &UpstreamServer{Address: addr} + } + return &Upstream{Servers: upstreamServers} +} + +func (u *Upstream) selectLeastConn() *UpstreamServer { + u.mu.RLock() + defer u.mu.RUnlock() + + if len(u.Servers) == 0 { + return nil + } + + selected := u.Servers[0] + minConns := selected.conns.Load() + + for i := 1; i < len(u.Servers); i++ { + conns := u.Servers[i].conns.Load() + if conns < minConns { + minConns = conns + selected = u.Servers[i] + } + } + + return selected +} + +func (u *Upstream) incrementConn(server *UpstreamServer) { + server.conns.Add(1) +} + +func (u *Upstream) decrementConn(server *UpstreamServer) { + server.conns.Add(-1) +} + +func NewTCPProxy(logger log.Logger, rules []*ProxyRule) *TCPProxy { + ctx, cancel := context.WithCancel(context.Background()) + return &TCPProxy{ + rules: rules, + logger: logger, + ctx: ctx, + cancel: cancel, + } +} + +func (p *TCPProxy) Start() error { + for _, rule := range p.rules { + listener, err := net.Listen("tcp", ":"+rule.ListenPort) + if err != nil { + p.Stop() + return fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) + } + p.servers = append(p.servers, listener) + + p.wg.Add(1) + go p.handleListener(listener, rule) + } + + return nil +} + +func (p *TCPProxy) Stop() { + p.cancel() + for _, server := range p.servers { + _ = server.Close() + } + p.wg.Wait() +} + +func (p *TCPProxy) handleListener(listener net.Listener, rule *ProxyRule) { + defer p.wg.Done() + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + clientConn, err := listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + p.logger.Warn("failed to accept connection", tag.Error(err)) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(clientConn, rule) + } +} + +func (p *TCPProxy) handleConnection(clientConn net.Conn, rule *ProxyRule) { + defer p.wg.Done() + defer func() { _ = clientConn.Close() }() + + upstream := rule.Upstream.selectLeastConn() + if upstream == nil { + p.logger.Error("no upstream servers available") + return + } + + rule.Upstream.incrementConn(upstream) + defer rule.Upstream.decrementConn(upstream) + + serverConn, err := net.DialTimeout("tcp", upstream.Address, 5*time.Second) + if err != nil { + p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.Address), tag.Error(err)) + return + } + defer func() { _ = serverConn.Close() }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + 
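+ // Copy client -> upstream; when this copy ends (client closed or errored), closing serverConn unblocks the upstream -> client copy in the goroutine below.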
defer wg.Done() + _, _ = io.Copy(serverConn, clientConn) + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, serverConn) + _ = clientConn.Close() + }() + + wg.Wait() +} diff --git a/proxy/test/tcp_proxy_test.go b/proxy/test/tcp_proxy_test.go new file mode 100644 index 00000000..f0aaff84 --- /dev/null +++ b/proxy/test/tcp_proxy_test.go @@ -0,0 +1,171 @@ +package proxy + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/log" +) + +func TestTCPProxy(t *testing.T) { + logger := log.NewTestLogger() + + server1Addr := GetLocalhostAddress() + server2Addr := GetLocalhostAddress() + server3Addr := GetLocalhostAddress() + server4Addr := GetLocalhostAddress() + server5Addr := GetLocalhostAddress() + server6Addr := GetLocalhostAddress() + + echoServer1 := startEchoServer(t, server1Addr) + echoServer2 := startEchoServer(t, server2Addr) + echoServer3 := startEchoServer(t, server3Addr) + echoServer4 := startEchoServer(t, server4Addr) + echoServer5 := startEchoServer(t, server5Addr) + echoServer6 := startEchoServer(t, server6Addr) + + defer func() { _ = echoServer1.Close() }() + defer func() { _ = echoServer2.Close() }() + defer func() { _ = echoServer3.Close() }() + defer func() { _ = echoServer4.Close() }() + defer func() { _ = echoServer5.Close() }() + defer func() { _ = echoServer6.Close() }() + + proxyAddr1 := GetLocalhostAddress() + proxyAddr2 := GetLocalhostAddress() + proxyAddr3 := GetLocalhostAddress() + + _, proxyPort1, _ := net.SplitHostPort(proxyAddr1) + _, proxyPort2, _ := net.SplitHostPort(proxyAddr2) + _, proxyPort3, _ := net.SplitHostPort(proxyAddr3) + + rules := []*ProxyRule{ + { + ListenPort: proxyPort1, + Upstream: NewUpstream([]string{server1Addr, server2Addr}), + }, + { + ListenPort: proxyPort2, + Upstream: NewUpstream([]string{server3Addr, server4Addr}), + }, + { + ListenPort: proxyPort3, + Upstream: NewUpstream([]string{server5Addr, server6Addr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Test proxy on port 1 + testProxyConnection(t, proxyAddr1, "test message 1") + + // Test proxy on port 2 + testProxyConnection(t, proxyAddr2, "test message 2") + + // Test proxy on port 3 + testProxyConnection(t, proxyAddr3, "test message 3") +} + +func testProxyConnection(t *testing.T, proxyAddr, message string) { + conn, err := net.DialTimeout("tcp", proxyAddr, 5*time.Second) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + _, err = conn.Write([]byte(message)) + require.NoError(t, err) + + buf := make([]byte, len(message)) + _, err = io.ReadFull(conn, buf) + require.NoError(t, err) + require.Equal(t, message, string(buf)) +} + +func startEchoServer(t *testing.T, addr string) net.Listener { + listener, err := net.Listen("tcp", addr) + require.NoError(t, err) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer func() { _ = c.Close() }() + _, _ = io.Copy(c, c) + }(conn) + } + }() + + return listener +} + +func TestTCPProxyLeastConn(t *testing.T) { + logger := log.NewTestLogger() + + // Create two echo servers + server1Addr := GetLocalhostAddress() + server2Addr := GetLocalhostAddress() + server1 := startEchoServer(t, server1Addr) + server2 := startEchoServer(t, server2Addr) + defer func() { _ = server1.Close() }() + defer func() { _ = server2.Close() }() + + // Create proxy with two upstreams + proxyAddr := 
GetLocalhostAddress() + _, proxyPort, _ := net.SplitHostPort(proxyAddr) + rules := []*ProxyRule{ + { + ListenPort: proxyPort, + Upstream: NewUpstream([]string{server1Addr, server2Addr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + defer proxy.Stop() + + // Make multiple connections to verify load balancing + for i := 0; i < 10; i++ { + testProxyConnection(t, proxyAddr, "test") + time.Sleep(10 * time.Millisecond) + } +} + +func TestTCPProxyContextCancellation(t *testing.T) { + logger := log.NewTestLogger() + + serverAddr := GetLocalhostAddress() + server := startEchoServer(t, serverAddr) + defer func() { _ = server.Close() }() + + proxyAddr := GetLocalhostAddress() + _, proxyPort, _ := net.SplitHostPort(proxyAddr) + rules := []*ProxyRule{ + { + ListenPort: proxyPort, + Upstream: NewUpstream([]string{serverAddr}), + }, + } + + proxy := NewTCPProxy(logger, rules) + err := proxy.Start() + require.NoError(t, err) + + // Verify it's working + testProxyConnection(t, proxyAddr, "test") + + // Stop the proxy + proxy.Stop() + + // Verify new connections fail + _, err = net.DialTimeout("tcp", proxyAddr, 100*time.Millisecond) + require.Error(t, err) +} diff --git a/proxy/test/test_common.go b/proxy/test/test_common.go new file mode 100644 index 00000000..30f8189a --- /dev/null +++ b/proxy/test/test_common.go @@ -0,0 +1,496 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/log/tag" + "go.temporal.io/server/tests/testcore" + + "github.com/temporalio/s2s-proxy/config" + s2sproxy "github.com/temporalio/s2s-proxy/proxy" +) + +type simpleConfigProvider struct { + cfg config.S2SProxyConfig +} + +func (p *simpleConfigProvider) GetS2SProxyConfig() config.S2SProxyConfig { + return p.cfg +} + +type trackingUpstreamServer struct { + address string + conns atomic.Int64 + count1 *atomic.Int64 + count2 *atomic.Int64 +} + +type trackingUpstream struct { + servers []*trackingUpstreamServer + mu sync.RWMutex +} + +func (u *trackingUpstream) selectLeastConn() *trackingUpstreamServer { + u.mu.RLock() + defer u.mu.RUnlock() + + if len(u.servers) == 0 { + return nil + } + + selected := u.servers[0] + minConns := selected.conns.Load() + + for i := 1; i < len(u.servers); i++ { + conns := u.servers[i].conns.Load() + if conns < minConns { + minConns = conns + selected = u.servers[i] + } + } + + if selected != nil { + if selected == u.servers[0] { + selected.count1.Add(1) + } else if len(u.servers) > 1 && selected == u.servers[1] { + selected.count2.Add(1) + } + } + + return selected +} + +func (u *trackingUpstream) incrementConn(server *trackingUpstreamServer) { + server.conns.Add(1) +} + +func (u *trackingUpstream) decrementConn(server *trackingUpstreamServer) { + server.conns.Add(-1) +} + +type trackingTCPProxy struct { + rules []*trackingProxyRule + logger log.Logger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + servers []net.Listener +} + +type trackingProxyRule struct { + ListenPort string + Upstream *trackingUpstream +} + +func (p *trackingTCPProxy) Start() error { + for _, rule := range p.rules { + listener, err := net.Listen("tcp", ":"+rule.ListenPort) + if err != nil { + p.Stop() + return 
fmt.Errorf("failed to listen on port %s: %w", rule.ListenPort, err) + } + p.servers = append(p.servers, listener) + + p.wg.Add(1) + go p.handleListener(listener, rule) + } + + return nil +} + +func (p *trackingTCPProxy) Stop() { + p.logger.Info("Stopping tracking TCP proxy") + p.cancel() + for _, server := range p.servers { + p.logger.Info("Closing server", tag.NewStringTag("server", server.Addr().String())) + _ = server.Close() + } + p.logger.Info("Waiting for goroutines to finish") + p.wg.Wait() + p.logger.Info("Tracking TCP proxy stopped") +} + +func (p *trackingTCPProxy) handleListener(listener net.Listener, rule *trackingProxyRule) { + defer p.wg.Done() + + for { + select { + case <-p.ctx.Done(): + return + default: + } + + clientConn, err := listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + p.logger.Warn("failed to accept connection", tag.Error(err)) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(clientConn, rule) + } +} + +func (p *trackingTCPProxy) handleConnection(clientConn net.Conn, rule *trackingProxyRule) { + defer p.wg.Done() + defer func() { _ = clientConn.Close() }() + + select { + case <-p.ctx.Done(): + return + default: + } + + upstream := rule.Upstream.selectLeastConn() + if upstream == nil { + p.logger.Error("no upstream servers available") + return + } + + rule.Upstream.incrementConn(upstream) + defer rule.Upstream.decrementConn(upstream) + + serverConn, err := net.DialTimeout("tcp", upstream.address, 5*time.Second) + if err != nil { + p.logger.Warn("failed to connect to upstream", tag.NewStringTag("upstream", upstream.address), tag.Error(err)) + return + } + defer func() { _ = serverConn.Close() }() + + var wg sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + <-p.ctx.Done() + _ = clientConn.Close() + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(serverConn, clientConn) + _ = serverConn.Close() + }() + + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, serverConn) + _ = clientConn.Close() + }() + + wg.Wait() +} + +func createLoadBalancer( + logger log.Logger, + listenPort string, + upstreams []string, + count1 *atomic.Int64, + count2 *atomic.Int64, +) (*trackingTCPProxy, error) { + trackingServers := make([]*trackingUpstreamServer, len(upstreams)) + for i, addr := range upstreams { + trackingServers[i] = &trackingUpstreamServer{ + address: addr, + count1: count1, + count2: count2, + } + } + + trackingUpstream := &trackingUpstream{ + servers: trackingServers, + } + + rules := []*trackingProxyRule{ + { + ListenPort: listenPort, + Upstream: trackingUpstream, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + trackingProxy := &trackingTCPProxy{ + rules: rules, + logger: logger, + ctx: ctx, + cancel: cancel, + } + + err := trackingProxy.Start() + if err != nil { + return nil, err + } + + return trackingProxy, nil +} + +func createCluster( + logger log.Logger, + t testingT, + clusterName string, + numShards int, + initialFailoverVersion int64, + numHistoryHosts int, +) *testcore.TestCluster { + clusterSuffix := common.GenerateRandomString(8) + fullClusterName := fmt.Sprintf("%s-%s", clusterName, clusterSuffix) + + clusterConfig := &testcore.TestClusterConfig{ + ClusterMetadata: cluster.Config{ + EnableGlobalNamespace: true, + FailoverVersionIncrement: 10, + MasterClusterName: fullClusterName, + CurrentClusterName: fullClusterName, + ClusterInformation: map[string]cluster.ClusterInformation{ + fullClusterName: { + Enabled: true, + 
InitialFailoverVersion: initialFailoverVersion, + }, + }, + }, + HistoryConfig: testcore.HistoryConfig{ + NumHistoryShards: int32(numShards), + NumHistoryHosts: numHistoryHosts, + }, + DynamicConfigOverrides: map[dynamicconfig.Key]interface{}{ + dynamicconfig.NamespaceCacheRefreshInterval.Key(): time.Second, + dynamicconfig.EnableReplicationStream.Key(): true, + dynamicconfig.EnableReplicationTaskBatching.Key(): true, + }, + } + + testClusterFactory := testcore.NewTestClusterFactory() + logger = log.With(logger, tag.NewStringTag("clusterName", clusterName)) + + testT := getTestingT(t) + cluster, err := testClusterFactory.NewCluster(testT, clusterConfig, logger) + if err != nil { + t.Fatalf("Failed to create cluster %s: %v", clusterName, err) + } + + return cluster +} + +func createProxy( + logger log.Logger, + t testingT, + name string, + inboundAddress string, + outboundAddress string, + muxAddress string, + cluster *testcore.TestCluster, + muxMode config.MuxMode, + shardCountConfig config.ShardCountConfig, + nodeName string, + memberlistBindAddr string, + memberlistBindPort int, + memberlistJoinAddrs []string, + proxyAddresses map[string]string, +) *s2sproxy.Proxy { + var muxConnectionType config.ConnectionType + var muxAddressInfo config.TCPTLSInfo + if muxMode == config.ServerMode { + muxConnectionType = config.ConnTypeMuxServer + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } else { + muxConnectionType = config.ConnTypeMuxClient + muxAddressInfo = config.TCPTLSInfo{ + ConnectionString: muxAddress, + } + } + + cfg := &config.S2SProxyConfig{ + ClusterConnections: []config.ClusterConnConfig{ + { + Name: name, + LocalServer: config.ClusterDefinition{ + Connection: config.TransportInfo{ + ConnectionType: config.ConnTypeTCP, + TcpClient: config.TCPTLSInfo{ + ConnectionString: cluster.Host().FrontendGRPCAddress(), + }, + TcpServer: config.TCPTLSInfo{ + ConnectionString: outboundAddress, + }, + }, + }, + RemoteServer: config.ClusterDefinition{ + Connection: config.TransportInfo{ + ConnectionType: muxConnectionType, + MuxCount: 1, + MuxAddressInfo: muxAddressInfo, + }, + }, + ShardCountConfig: shardCountConfig, + }, + }, + } + + if nodeName != "" && memberlistBindAddr != "" { + cfg.ClusterConnections[0].MemberlistConfig = &config.MemberlistConfig{ + Enabled: true, + NodeName: nodeName, + BindAddr: memberlistBindAddr, + BindPort: memberlistBindPort, + JoinAddrs: memberlistJoinAddrs, + ProxyAddresses: proxyAddresses, + TCPOnly: true, + } + } + + configProvider := &simpleConfigProvider{cfg: *cfg} + proxy := s2sproxy.NewProxy(configProvider, logger) + if proxy == nil { + t.Fatalf("Failed to create proxy %s", name) + } + + err := proxy.Start() + if err != nil { + t.Fatalf("Failed to start proxy %s: %v", name, err) + } + + logger.Info("Started proxy", tag.NewStringTag("name", name), + tag.NewStringTag("inboundAddress", inboundAddress), + tag.NewStringTag("outboundAddress", outboundAddress), + tag.NewStringTag("muxAddress", muxAddress), + tag.NewStringTag("muxMode", string(muxMode)), + tag.NewStringTag("nodeName", nodeName), + ) + + return proxy +} + +func configureRemoteCluster( + logger log.Logger, + t testingT, + cluster *testcore.TestCluster, + remoteClusterName string, + proxyAddress string, +) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + _, err := cluster.AdminClient().AddOrUpdateRemoteCluster( + ctx, + &adminservice.AddOrUpdateRemoteClusterRequest{ + FrontendAddress: proxyAddress, + EnableRemoteClusterConnection: 
true, + }, + ) + if err != nil { + t.Fatalf("Failed to configure remote cluster %s: %v", remoteClusterName, err) + } + logger.Info("Configured remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("proxyAddress", proxyAddress), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func removeRemoteCluster( + logger log.Logger, + t testingT, + cluster *testcore.TestCluster, + remoteClusterName string, +) { + _, err := cluster.AdminClient().RemoveRemoteCluster( + context.Background(), + &adminservice.RemoveRemoteClusterRequest{ + ClusterName: remoteClusterName, + }, + ) + if err != nil { + t.Fatalf("Failed to remove remote cluster %s: %v", remoteClusterName, err) + } + logger.Info("Removed remote cluster", + tag.NewStringTag("remoteClusterName", remoteClusterName), + tag.NewStringTag("clusterName", cluster.ClusterName()), + ) +} + +func waitForReplicationReady( + logger log.Logger, + t testingT, + clusters ...*testcore.TestCluster, +) { + time.Sleep(1 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + for _, cluster := range clusters { + ready := false + for i := 0; i < 25; i++ { + _, err := cluster.HistoryClient().GetReplicationStatus( + ctx, + &historyservice.GetReplicationStatusRequest{}, + ) + if err == nil { + ready = true + break + } + time.Sleep(200 * time.Millisecond) + } + if !ready { + t.Fatalf("Replication infrastructure not ready for cluster %s", cluster.ClusterName()) + } + } + + time.Sleep(1 * time.Second) +} + +type testingT interface { + Helper() + Fatalf(format string, args ...interface{}) +} + +func getTestingT(t testingT) *testing.T { + if testT, ok := t.(*testing.T); ok { + return testT + } + if suiteT, ok := t.(interface{ T() *testing.T }); ok { + return suiteT.T() + } + panic("testingT must be *testing.T or have T() method") +} + +// GetFreePort returns an available TCP port by listening on localhost:0. +// This is useful for tests that need to allocate ports dynamically. 
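+// Note that the listener is closed before the port is returned, so another process could in rare cases claim the port before the caller binds it.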
+func GetFreePort() int { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(fmt.Sprintf("failed to get free port: %v", err)) + } + defer func() { + if err := l.Close(); err != nil { + fmt.Printf("Failed to close listener: %v\n", err) + } + }() + return l.Addr().(*net.TCPAddr).Port +} + +// GetLocalhostAddress returns a localhost address with a free port +func GetLocalhostAddress() string { + return fmt.Sprintf("localhost:%d", GetFreePort()) +} diff --git a/proxy/test/wiring_test.go b/proxy/test/wiring_test.go index 768b3919..55e58fc5 100644 --- a/proxy/test/wiring_test.go +++ b/proxy/test/wiring_test.go @@ -41,31 +41,46 @@ type ( ) var ( - // Create some believable echo server configs - echoServerInfo = endtoendtest.ClusterInfo{ + logger log.Logger +) + +func getEchoServerInfo() endtoendtest.ClusterInfo { + echoServerAddress := GetLocalhostAddress() + serverProxyInboundAddress := GetLocalhostAddress() + serverProxyOutboundAddress := GetLocalhostAddress() + prometheusAddress := GetLocalhostAddress() + healthCheckAddress := GetLocalhostAddress() + return endtoendtest.ClusterInfo{ ServerAddress: echoServerAddress, ClusterShardID: serverClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: "localhost:7266", - inbound: "localhost:7366", - outbound: "localhost:7466", - prometheus: "localhost:7468", - healthCheck: "localhost:7479", + echoServer: echoServerAddress, + inbound: serverProxyInboundAddress, + outbound: serverProxyOutboundAddress, + prometheus: prometheusAddress, + healthCheck: healthCheckAddress, }), } - echoClientInfo = endtoendtest.ClusterInfo{ +} + +func getEchoClientInfo() endtoendtest.ClusterInfo { + echoClientAddress := GetLocalhostAddress() + clientProxyInboundAddress := GetLocalhostAddress() + clientProxyOutboundAddress := GetLocalhostAddress() + prometheusAddress := GetLocalhostAddress() + healthCheckAddress := GetLocalhostAddress() + return endtoendtest.ClusterInfo{ ServerAddress: echoClientAddress, ClusterShardID: clientClusterShard, S2sProxyConfig: makeS2SConfig(s2sAddresses{ - echoServer: "localhost:8266", - inbound: "localhost:8366", - outbound: "localhost:8466", - prometheus: "localhost:7467", - healthCheck: "localhost:7478", + echoServer: echoClientAddress, + inbound: clientProxyInboundAddress, + outbound: clientProxyOutboundAddress, + prometheus: prometheusAddress, + healthCheck: healthCheckAddress, }), } - logger log.Logger -) +} type hangupAdminServer struct { adminservice.UnimplementedAdminServiceServer @@ -129,6 +144,9 @@ func TestEOFFromServer(t *testing.T) { } func TestWiringWithEchoService(t *testing.T) { + echoServerInfo := getEchoServerInfo() + echoClientInfo := getEchoClientInfo() + echoServer := endtoendtest.NewEchoServer(echoServerInfo, echoClientInfo, "EchoServer", logger, nil) echoClient := endtoendtest.NewEchoServer(echoClientInfo, echoServerInfo, "EchoClient", logger, nil) echoServer.Start() diff --git a/transport/mux/multi_mux_manager.go b/transport/mux/multi_mux_manager.go index 536ee34a..8c9312e1 100644 --- a/transport/mux/multi_mux_manager.go +++ b/transport/mux/multi_mux_manager.go @@ -53,6 +53,8 @@ type ( CanAcceptConnections() bool Describe() string Name() string + // GetMuxConnections returns a snapshot of active mux connections + GetMuxConnections() map[string]session.ManagedMuxSession } MuxProviderBuilder func(AddNewMux, context.Context) (MuxProvider, error) ) @@ -202,3 +204,14 @@ func (m *multiMuxManager) Describe() string { func (m *multiMuxManager) Name() string { return m.name } + +func (m 
*multiMuxManager) GetMuxConnections() map[string]session.ManagedMuxSession { + m.muxesLock.RLock() + defer m.muxesLock.RUnlock() + // Return a copy to avoid holding the lock + result := make(map[string]session.ManagedMuxSession, len(m.muxes)) + for k, v := range m.muxes { + result[k] = v + } + return result +} diff --git a/transport/mux/multi_mux_manager_test.go b/transport/mux/multi_mux_manager_test.go index 87105828..391ceac2 100644 --- a/transport/mux/multi_mux_manager_test.go +++ b/transport/mux/multi_mux_manager_test.go @@ -47,12 +47,14 @@ func TestMultiMuxManager(t *testing.T) { require.False(t, muxesOnPipes.clientMM.CanAcceptConnections(), "All connections should have been consumed") // Close connections. We should see both sides fire disconnectFn + require.Eventually(t, func() bool { return clientConns.Load() != nil }, 2*time.Second, 10*time.Millisecond, "clientConns should be set") for _, v := range *clientConns.Load() { v.Close() } clientEvent = proxyassert.RequireCh(t, muxesOnPipes.clientEvents, 2*time.Second, "Client connection failed to disconnect!\nclientMux:%s", muxesOnPipes.clientMM.Describe()) require.Equal(t, "closed", clientEvent.eventType) require.Same(t, clientSession, clientEvent.session) + require.Eventually(t, func() bool { return serverConns.Load() != nil }, 2*time.Second, 10*time.Millisecond, "serverConns should be set") for _, v := range *serverConns.Load() { v.Close() } diff --git a/transport/mux/session/managed_mux_session.go b/transport/mux/session/managed_mux_session.go index b9a582d9..2d87dbd2 100644 --- a/transport/mux/session/managed_mux_session.go +++ b/transport/mux/session/managed_mux_session.go @@ -38,6 +38,8 @@ type ( Open() (net.Conn, error) State() *MuxSessionInfo Describe() string + // GetConnectionInfo returns the local and remote addresses of the underlying connection + GetConnectionInfo() (localAddr, remoteAddr net.Addr) } ) @@ -142,3 +144,13 @@ func (s *muxSession) Addr() net.Addr { func (s *muxSession) Describe() string { return fmt.Sprintf("[muxSession %s, state=%v, address=%s]", s.id, s.state.Load(), s.conn.RemoteAddr().String()) } + +func (s *muxSession) GetConnectionInfo() (localAddr, remoteAddr net.Addr) { + if s.session != nil { + return s.session.LocalAddr(), s.session.RemoteAddr() + } + if s.conn != nil { + return s.conn.LocalAddr(), s.conn.RemoteAddr() + } + return nil, nil +}
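
Illustrative sketch (not part of the diff above): one way the new GetMuxConnections and GetConnectionInfo accessors could be consumed, for example by a connection-debugging endpoint. The debugsketch package name, the muxConnectionLister interface, and the writeMuxConnections helper are hypothetical names introduced only for this example; only GetMuxConnections, GetConnectionInfo, Name, and the session package come from the changes above.

package debugsketch

import (
	"fmt"
	"io"

	"github.com/temporalio/s2s-proxy/transport/mux/session"
)

// muxConnectionLister is a narrow, hypothetical view of a mux manager that
// exposes the accessors added in this change.
type muxConnectionLister interface {
	Name() string
	GetMuxConnections() map[string]session.ManagedMuxSession
}

// writeMuxConnections writes one line per active mux session, reporting the
// manager name, the session key, and the local/remote addresses of the
// session's underlying connection.
func writeMuxConnections(w io.Writer, m muxConnectionLister) {
	for id, s := range m.GetMuxConnections() {
		local, remote := s.GetConnectionInfo()
		fmt.Fprintf(w, "%s/%s local=%v remote=%v\n", m.Name(), id, local, remote)
	}
}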