Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,30 @@ import (
func ServiceTag(sv string) tag.ZapTag {
return tag.NewStringTag("service", sv)
}

// GCD calculates the greatest common divisor of a and b using the
// Euclidean algorithm. By convention it returns 0 when either input is 0.
//
// https://en.wikipedia.org/wiki/Euclidean_algorithm#Implementations
func GCD(a, b int32) int32 {
	if a == 0 || b == 0 {
		return 0
	}
	// No need to order the arguments: when a < b, the first iteration
	// swaps them, since a%b == a in that case.
	for b != 0 {
		a, b = b, a%b
	}
	return a
}

// LCM calculates the least common multiple of a and b.
// It returns 0 when either input is 0.
//
// https://en.wikipedia.org/wiki/Least_common_multiple#Using_the_greatest_common_divisor
func LCM(a, b int32) int32 {
	if a == 0 || b == 0 {
		return 0
	}
	// Divide before multiplying to avoid int32 overflow in a*b.
	// GCD(a, b) evenly divides a, so a/GCD(a, b)*b is exact.
	return a / GCD(a, b) * b
}
144 changes: 144 additions & 0 deletions common/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package common

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
)

// TestGCD checks GCD against exhaustive small inputs plus a few larger
// pairs, and verifies commutativity by running every case in both
// argument orders.
func TestGCD(t *testing.T) {
	cases := []struct{ a, b, exp int32 }{
		{0, 0, 0}, {1, 0, 0}, {0, 1, 0},

		{1, 1, 1}, {2, 1, 1}, {3, 1, 1}, {4, 1, 1}, {5, 1, 1},
		{6, 1, 1}, {7, 1, 1}, {8, 1, 1}, {9, 1, 1},

		{2, 2, 2}, {3, 2, 1}, {4, 2, 2}, {5, 2, 1},
		{6, 2, 2}, {7, 2, 1}, {8, 2, 2}, {9, 2, 1},

		{3, 3, 3}, {4, 3, 1}, {5, 3, 1}, {6, 3, 3},
		{7, 3, 1}, {8, 3, 1}, {9, 3, 3},

		{4, 4, 4}, {5, 4, 1}, {6, 4, 2},
		{7, 4, 1}, {8, 4, 4}, {9, 4, 1},

		{5, 5, 5}, {6, 5, 1}, {7, 5, 1}, {8, 5, 1}, {9, 5, 1},

		{36, 10, 2}, {36, 11, 1}, {36, 12, 12},
		{36, 14, 2}, {36, 15, 3}, {36, 16, 4},

		{4000, 1024, 32}, {4000, 512, 32},
	}

	for _, tc := range cases {
		t.Run(fmt.Sprintf("a=%d, b=%d", tc.a, tc.b), func(t *testing.T) {
			require.Equal(t, tc.exp, GCD(tc.a, tc.b))
		})

		// GCD(a, b) == GCD(b, a)
		t.Run(fmt.Sprintf("b=%d, a=%d", tc.b, tc.a), func(t *testing.T) {
			require.Equal(t, tc.exp, GCD(tc.b, tc.a))
		})
	}
}

// TestLCM checks LCM against exhaustive small inputs plus a few larger
// pairs, and verifies commutativity by running every case in both
// argument orders.
func TestLCM(t *testing.T) {
	cases := []struct{ a, b, exp int32 }{
		{0, 0, 0}, {1, 0, 0}, {0, 1, 0},

		{1, 1, 1}, {2, 1, 2}, {3, 1, 3}, {4, 1, 4}, {5, 1, 5},
		{6, 1, 6}, {7, 1, 7}, {8, 1, 8}, {9, 1, 9},

		{2, 2, 2}, {3, 2, 6}, {4, 2, 4}, {5, 2, 10},
		{6, 2, 6}, {7, 2, 14}, {8, 2, 8}, {9, 2, 18},

		{3, 3, 3}, {4, 3, 12}, {5, 3, 15}, {6, 3, 6},
		{7, 3, 21}, {8, 3, 24}, {9, 3, 9},

		{4, 4, 4}, {5, 4, 20}, {6, 4, 12},
		{7, 4, 28}, {8, 4, 8}, {9, 4, 36},

		{36, 10, 180}, {36, 11, 396}, {36, 12, 36},

		{4000, 1024, 128000}, {4000, 512, 64000}, {4000, 256, 32000},
	}

	for _, tc := range cases {
		t.Run(fmt.Sprintf("a=%d, b=%d", tc.a, tc.b), func(t *testing.T) {
			require.Equal(t, tc.exp, LCM(tc.a, tc.b))
		})

		// LCM(a, b) == LCM(b, a)
		t.Run(fmt.Sprintf("b=%d, a=%d", tc.b, tc.a), func(t *testing.T) {
			require.Equal(t, tc.exp, LCM(tc.b, tc.a))
		})
	}
}
14 changes: 14 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ const (
ServerMode MuxMode = "server" // server of underly tcp connection in mux mode.
)

// ShardCountMode selects how the proxy handles history shard counts
// that differ between the local and remote clusters.
type ShardCountMode string

const (
// ShardCountDefault is the zero value: shard counts are passed
// through unchanged.
ShardCountDefault ShardCountMode = ""
// ShardCountLCM makes the proxy advertise the least common multiple
// of the local and remote shard counts and remap stream shard IDs
// accordingly (see ShardCountConfig and the adminservice handlers).
ShardCountLCM ShardCountMode = "lcm"
)

type HealthCheckProtocol string

const (
Expand Down Expand Up @@ -92,13 +99,20 @@ type (
MuxTransports []MuxTransportConfig `yaml:"mux"`
HealthCheck *HealthCheckConfig `yaml:"healthCheck"`
NamespaceNameTranslation NamespaceNameTranslationConfig `yaml:"namespaceNameTranslation"`
ShardCountConfig ShardCountConfig `yaml:"shardCount"`
Metrics *MetricsConfig `yaml:"metrics"`
}

NamespaceNameTranslationConfig struct {
Mappings []NameMappingConfig `yaml:"mappings"`
}

// ShardCountConfig configures shard-count translation between the
// local and remote clusters.
ShardCountConfig struct {
// Mode selects the translation strategy; empty means pass-through.
Mode ShardCountMode `yaml:"mode"`
// LocalShardCount is the history shard count of the local cluster.
LocalShardCount int32 `yaml:"localShardCount"`
// RemoteShardCount is the history shard count of the remote cluster.
RemoteShardCount int32 `yaml:"remoteShardCount"`
}

NameMappingConfig struct {
LocalName string `yaml:"localName"`
RemoteName string `yaml:"remoteName"`
Expand Down
66 changes: 61 additions & 5 deletions proxy/adminservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"strconv"
"sync"

"github.com/temporalio/s2s-proxy/client"
Expand All @@ -14,6 +15,7 @@ import (
"go.temporal.io/api/serviceerror"
"go.temporal.io/server/api/adminservice/v1"
"go.temporal.io/server/client/history"
servercommon "go.temporal.io/server/common"
"go.temporal.io/server/common/channel"
"go.temporal.io/server/common/headers"
"go.temporal.io/server/common/log"
Expand Down Expand Up @@ -77,7 +79,17 @@ func (s *adminServiceProxyServer) DeleteWorkflowExecution(ctx context.Context, i
}

// DescribeCluster forwards the request to the underlying admin service.
// In LCM shard-count mode it rewrites the reported shard count to the
// least common multiple of both clusters' shard counts.
func (s *adminServiceProxyServer) DescribeCluster(ctx context.Context, in0 *adminservice.DescribeClusterRequest) (*adminservice.DescribeClusterResponse, error) {
	resp, err := s.adminClient.DescribeCluster(ctx, in0)
	if err != nil {
		return resp, err
	}

	cfg := s.Config.ShardCountConfig
	if cfg.Mode == config.ShardCountLCM {
		// Present a fake number of shards. In LCM mode, we present the least
		// common multiple of both cluster shard counts.
		resp.HistoryShardCount = common.LCM(cfg.RemoteShardCount, cfg.LocalShardCount)
	}
	return resp, nil
}

func (s *adminServiceProxyServer) DescribeDLQJob(ctx context.Context, in0 *adminservice.DescribeDLQJobRequest) (*adminservice.DescribeDLQJobResponse, error) {
Expand Down Expand Up @@ -213,21 +225,56 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages(
if !ok {
return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata")
}
targetClusterShardID, sourceClusterShardID, err := history.DecodeClusterShardMD(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this was my fault when I bumped the temporal version, but the target/source ids here were flipped.

I changed them to match "client/server" terminology of the history.DecodeClusterShardMD function?

I'm fine either way.

clientShardID, serverShardID, err := history.DecodeClusterShardMD(
headers.NewGRPCHeaderGetter(targetStreamServer.Context()),
)
if err != nil {
return err
}

logger := log.With(s.logger,
tag.NewStringTag("source", ClusterShardIDtoString(sourceClusterShardID)),
tag.NewStringTag("target", ClusterShardIDtoString(targetClusterShardID)))
tag.NewStringTag("client", ClusterShardIDtoString(clientShardID)),
tag.NewStringTag("server", ClusterShardIDtoString(serverShardID)),
)

logger.Info("AdminStreamReplicationMessages started.")
defer logger.Info("AdminStreamReplicationMessages stopped.")

// simply forwarding target metadata
if cfg := s.Config.ShardCountConfig; cfg.Mode == config.ShardCountLCM {
// Arbitrary shard count support.
//
// Temporal only supports shard counts where one shard count is an even multiple of the other.
// The trick in this mode is the proxy will present the Least Common Multiple of both cluster shard counts.
// Temporal establishes outbound replication streams to the proxy for all unique shard id pairs between
// itself and the proxy's shard count. Then the proxy directly forwards those streams along to the target
// cluster, remapping proxy stream shard ids to the target cluster shard ids.
newClientShardID := history.ClusterShardID{
ClusterID: clientShardID.ClusterID,
ShardID: serverShardID.ShardID, // proxy fake shard id
}
newServerShardID := history.ClusterShardID{
ClusterID: serverShardID.ClusterID,
}
LCM := common.LCM(cfg.LocalShardCount, cfg.RemoteShardCount)
if s.IsInbound {
// Stream is going to local server. Remap shard id by local server shard count.
newServerShardID.ShardID = mapShardIDUnique(LCM, cfg.LocalShardCount, serverShardID.ShardID)
} else {
// Stream is going to remote server. Remap shard id by remote server shard count.
newServerShardID.ShardID = mapShardIDUnique(LCM, cfg.RemoteShardCount, serverShardID.ShardID)
}

logger = log.With(logger,
tag.NewStringTag("newClient", ClusterShardIDtoString(newClientShardID)),
tag.NewStringTag("newServer", ClusterShardIDtoString(newServerShardID)))

// Maybe there's a cleaner way. Trying to preserve any other metadata.
targetMetadata.Set(history.MetadataKeyClientClusterID, strconv.Itoa(int(newClientShardID.ClusterID)))
targetMetadata.Set(history.MetadataKeyClientShardID, strconv.Itoa(int(newClientShardID.ShardID)))
targetMetadata.Set(history.MetadataKeyServerClusterID, strconv.Itoa(int(newServerShardID.ClusterID)))
targetMetadata.Set(history.MetadataKeyServerShardID, strconv.Itoa(int(newServerShardID.ShardID)))
}

outgoingContext := metadata.NewOutgoingContext(targetStreamServer.Context(), targetMetadata)
outgoingContext, cancel := context.WithCancel(outgoingContext)
defer cancel()
Expand Down Expand Up @@ -327,3 +374,12 @@ func (s *adminServiceProxyServer) StreamWorkflowReplicationMessages(
wg.Wait()
return nil
}

// mapShardIDUnique remaps sourceShardID from a cluster with
// sourceShardCount shards onto a cluster with targetShardCount shards.
// The mapping must be one-to-one: it panics if MapShardID yields
// anything other than exactly one target shard, which indicates a
// shard-count configuration error rather than a recoverable condition.
func mapShardIDUnique(sourceShardCount, targetShardCount, sourceShardID int32) int32 {
	targetShardIDs := servercommon.MapShardID(sourceShardCount, targetShardCount, sourceShardID)
	if len(targetShardIDs) != 1 {
		// No trailing newline: panic output already ends the line.
		panic(fmt.Sprintf("remapping shard count error: sourceShardCount=%d targetShardCount=%d sourceShardID=%d targetShardID=%v",
			sourceShardCount, targetShardCount, sourceShardID, targetShardIDs))
	}
	return targetShardIDs[0]
}