diff --git a/common/common.go b/common/common.go
index 00a0f0f9..97a8bb7a 100644
--- a/common/common.go
+++ b/common/common.go
@@ -7,3 +7,30 @@ import (
 func ServiceTag(sv string) tag.ZapTag {
 	return tag.NewStringTag("service", sv)
 }
+
+// GCD calculates the greatest common divisor using the Euclidean algorithm.
+//
+// https://en.wikipedia.org/wiki/Euclidean_algorithm#Implementations
+func GCD(a, b int32) int32 {
+	if a == 0 || b == 0 {
+		return 0
+	}
+	if a > b {
+		a, b = b, a
+	}
+
+	for b != 0 {
+		a, b = b, a%b
+	}
+	return a
+}
+
+// LCM calculates the least common multiple of a and b.
+//
+// https://en.wikipedia.org/wiki/Least_common_multiple#Using_the_greatest_common_divisor
+func LCM(a, b int32) int32 {
+	if a == 0 || b == 0 {
+		return 0
+	}
+	return a / GCD(a, b) * b
+}
diff --git a/common/common_test.go b/common/common_test.go
new file mode 100644
index 00000000..e6fc3c2f
--- /dev/null
+++ b/common/common_test.go
@@ -0,0 +1,144 @@
+package common
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestGCD(t *testing.T) {
+	cases := []struct{ a, b, exp int32 }{
+		{0, 0, 0},
+		{1, 0, 0},
+		{0, 1, 0},
+
+		{1, 1, 1},
+		{2, 1, 1},
+		{3, 1, 1},
+		{4, 1, 1},
+		{5, 1, 1},
+		{6, 1, 1},
+		{7, 1, 1},
+		{8, 1, 1},
+		{9, 1, 1},
+
+		{2, 2, 2},
+		{3, 2, 1},
+		{4, 2, 2},
+		{5, 2, 1},
+		{6, 2, 2},
+		{7, 2, 1},
+		{8, 2, 2},
+		{9, 2, 1},
+
+		{3, 3, 3},
+		{4, 3, 1},
+		{5, 3, 1},
+		{6, 3, 3},
+		{7, 3, 1},
+		{8, 3, 1},
+		{9, 3, 3},
+
+		{4, 4, 4},
+		{5, 4, 1},
+		{6, 4, 2},
+		{7, 4, 1},
+		{8, 4, 4},
+		{9, 4, 1},
+
+		{5, 5, 5},
+		{6, 5, 1},
+		{7, 5, 1},
+		{8, 5, 1},
+		{9, 5, 1},
+
+		{36, 10, 2},
+		{36, 11, 1},
+		{36, 12, 12},
+		{36, 14, 2},
+		{36, 15, 3},
+		{36, 16, 4},
+
+		{4000, 1024, 32},
+		{4000, 512, 32},
+	}
+
+	for _, tc := range cases {
+		name := fmt.Sprintf("a=%d, b=%d", tc.a, tc.b)
+		t.Run(name, func(t *testing.T) {
+			require.Equal(t, tc.exp, GCD(tc.a, tc.b))
+		})
+
+		// GCD(a, b) == GCD(b, a)
+		name = fmt.Sprintf("b=%d, a=%d", tc.b, 
tc.a) + t.Run(name, func(t *testing.T) { + require.Equal(t, tc.exp, GCD(tc.b, tc.a)) + }) + } + +} + +func TestLCM(t *testing.T) { + cases := []struct{ a, b, exp int32 }{ + {0, 0, 0}, + {1, 0, 0}, + {0, 1, 0}, + + {1, 1, 1}, + {2, 1, 2}, + {3, 1, 3}, + {4, 1, 4}, + {5, 1, 5}, + {6, 1, 6}, + {7, 1, 7}, + {8, 1, 8}, + {9, 1, 9}, + + {2, 2, 2}, + {3, 2, 6}, + {4, 2, 4}, + {5, 2, 10}, + {6, 2, 6}, + {7, 2, 14}, + {8, 2, 8}, + {9, 2, 18}, + + {3, 3, 3}, + {4, 3, 12}, + {5, 3, 15}, + {6, 3, 6}, + {7, 3, 21}, + {8, 3, 24}, + {9, 3, 9}, + + {4, 4, 4}, + {5, 4, 20}, + {6, 4, 12}, + {7, 4, 28}, + {8, 4, 8}, + {9, 4, 36}, + + {36, 10, 180}, + {36, 11, 396}, + {36, 12, 36}, + + {4000, 1024, 128000}, + {4000, 512, 64000}, + {4000, 256, 32000}, + } + + for _, tc := range cases { + name := fmt.Sprintf("a=%d, b=%d", tc.a, tc.b) + t.Run(name, func(t *testing.T) { + require.Equal(t, tc.exp, LCM(tc.a, tc.b)) + }) + + // LCM(a, b) == LCM(b, a) + name = fmt.Sprintf("b=%d, a=%d", tc.b, tc.a) + t.Run(name, func(t *testing.T) { + require.Equal(t, tc.exp, LCM(tc.b, tc.a)) + }) + } + +} diff --git a/config/config.go b/config/config.go index 17ac69a0..22ba887f 100644 --- a/config/config.go +++ b/config/config.go @@ -30,6 +30,13 @@ const ( ServerMode MuxMode = "server" // server of underly tcp connection in mux mode. 
) +type ShardCountMode string + +const ( + ShardCountDefault ShardCountMode = "" + ShardCountLCM ShardCountMode = "lcm" +) + type HealthCheckProtocol string const ( @@ -92,6 +99,7 @@ type ( MuxTransports []MuxTransportConfig `yaml:"mux"` HealthCheck *HealthCheckConfig `yaml:"healthCheck"` NamespaceNameTranslation NamespaceNameTranslationConfig `yaml:"namespaceNameTranslation"` + ShardCountConfig ShardCountConfig `yaml:"shardCount"` Metrics *MetricsConfig `yaml:"metrics"` } @@ -99,6 +107,12 @@ type ( Mappings []NameMappingConfig `yaml:"mappings"` } + ShardCountConfig struct { + Mode ShardCountMode `yaml:"mode"` + LocalShardCount int32 `yaml:"localShardCount"` + RemoteShardCount int32 `yaml:"remoteShardCount"` + } + NameMappingConfig struct { LocalName string `yaml:"localName"` RemoteName string `yaml:"remoteName"` diff --git a/proxy/adminservice.go b/proxy/adminservice.go index 132ea646..5b92921b 100644 --- a/proxy/adminservice.go +++ b/proxy/adminservice.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "strconv" "sync" "github.com/temporalio/s2s-proxy/client" @@ -14,6 +15,7 @@ import ( "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/adminservice/v1" "go.temporal.io/server/client/history" + servercommon "go.temporal.io/server/common" "go.temporal.io/server/common/channel" "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" @@ -77,7 +79,17 @@ func (s *adminServiceProxyServer) DeleteWorkflowExecution(ctx context.Context, i } func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *adminservice.DescribeClusterRequest) (*adminservice.DescribeClusterResponse, error) { - return s.adminClient.DescribeCluster(ctx, in0) + resp, err := s.adminClient.DescribeCluster(ctx, in0) + if err != nil { + return resp, err + } + + if cfg := s.Config.ShardCountConfig; cfg.Mode == config.ShardCountLCM { + // Present a fake number of shards. In LCM mode, we present the least + // common multiple of both cluster shard counts. 
+		resp.HistoryShardCount = common.LCM(cfg.RemoteShardCount, cfg.LocalShardCount)
+	}
+	return resp, err
 }
 
 func (s *adminServiceProxyServer) DescribeDLQJob(ctx context.Context, in0 *adminservice.DescribeDLQJobRequest) (*adminservice.DescribeDLQJobResponse, error) {
@@ -213,7 +225,7 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages(
 	if !ok {
 		return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata")
 	}
-	targetClusterShardID, sourceClusterShardID, err := history.DecodeClusterShardMD(
+	clientShardID, serverShardID, err := history.DecodeClusterShardMD(
 		headers.NewGRPCHeaderGetter(targetStreamServer.Context()),
 	)
 	if err != nil {
@@ -221,13 +233,48 @@
 	logger := log.With(s.logger,
-		tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)),
-		tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID)))
+		tag.NewStringTag("client", ClusterShardIDtoString(clientShardID)),
+		tag.NewStringTag("server", ClusterShardIDtoString(serverShardID)),
+	)
 
 	logger.Info("AdminStreamReplicationMessages started.")
 	defer logger.Info("AdminStreamReplicationMessages stopped.")
 
-	// simply forwarding target metadata
+	if cfg := s.Config.ShardCountConfig; cfg.Mode == config.ShardCountLCM {
+		// Arbitrary shard count support.
+		//
+		// Temporal only supports shard counts where one shard count is an even multiple of the other.
+		// The trick in this mode is the proxy will present the Least Common Multiple of both cluster shard counts.
+		// Temporal establishes outbound replication streams to the proxy for all unique shard id pairs between
+		// itself and the proxy's shard count. Then the proxy directly forwards those streams along to the target
+		// cluster, remapping proxy stream shard ids to the target cluster shard ids. 
+ newClientShardID := history.ClusterShardID{ + ClusterID: clientShardID.ClusterID, + ShardID: serverShardID.ShardID, // proxy fake shard id + } + newServerShardID := history.ClusterShardID{ + ClusterID: serverShardID.ClusterID, + } + LCM := common.LCM(cfg.LocalShardCount, cfg.RemoteShardCount) + if s.IsInbound { + // Stream is going to local server. Remap shard id by local server shard count. + newServerShardID.ShardID = mapShardIDUnique(LCM, cfg.LocalShardCount, serverShardID.ShardID) + } else { + // Stream is going to remote server. Remap shard id by remote server shard count. + newServerShardID.ShardID = mapShardIDUnique(LCM, cfg.RemoteShardCount, serverShardID.ShardID) + } + + logger = log.With(logger, + tag.NewStringTag("newClient", ClusterShardIDtoString(newClientShardID)), + tag.NewStringTag("newServer", ClusterShardIDtoString(newServerShardID))) + + // Maybe there's a cleaner way. Trying to preserve any other metadata. + targetMetadata.Set(history.MetadataKeyClientClusterID, strconv.Itoa(int(newClientShardID.ClusterID))) + targetMetadata.Set(history.MetadataKeyClientShardID, strconv.Itoa(int(newClientShardID.ShardID))) + targetMetadata.Set(history.MetadataKeyServerClusterID, strconv.Itoa(int(newServerShardID.ClusterID))) + targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newServerShardID.ShardID))) + } + outgoingContext := metadata.NewOutgoingContext(targetStreamServer.Context(), targetMetadata) outgoingContext, cancel := context.WithCancel(outgoingContext) defer cancel() @@ -327,3 +374,12 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages( wg.Wait() return nil } + +func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 { + targetShardID := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID) + if len(targetShardID) != 1 { + panic(fmt.Sprintf("remapping shard count error: sourceShardCount=%d targetShardCount=%d sourceShardID=%d targetShardID=%v\n", + 
sourceShardCount, targetShardCount, sourceShardID, targetShardID)) + } + return targetShardID[0] +}