diff --git a/core/services/gateway/config/config.go b/core/services/gateway/config/config.go index bc717413011..59e3f3c14b3 100644 --- a/core/services/gateway/config/config.go +++ b/core/services/gateway/config/config.go @@ -63,6 +63,16 @@ type ShardedDONConfig struct { F int Shards []Shard } + +// ShardDONID returns the donID for a given shard +func ShardDONID(donName string, shardIdx int) string { + if shardIdx == 0 { + // NOTE: special case for backward compatibility - shard 0 doesn't have an index suffix + return donName + } + return fmt.Sprintf("%s_%d", donName, shardIdx) +} + type Shard struct { Nodes []NodeConfig } diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index 30a1e823358..d05578fd628 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -72,7 +72,7 @@ func (m *connectionManager) Name() string { return m.lggr.Name() } type donConnectionManager struct { donConfig *config.DONConfig nodes map[string]*nodeState - handler handlers.Handler + handlers map[string]handlers.Handler // service name -> handler closeWait sync.WaitGroup shutdownCh services.StopChan gMetrics *monitoring.GatewayMetrics @@ -102,30 +102,43 @@ func NewConnectionManager(gwConfig *config.GatewayConfig, clock clockwork.Clock, if ok { return nil, fmt.Errorf("duplicate DON ID %s", donConfig.DonId) } - nodes := make(map[string]*nodeState) - for _, nodeConfig := range donConfig.Members { - nodeAddress := strings.ToLower(nodeConfig.Address) - _, ok := nodes[nodeAddress] - if ok { - return nil, fmt.Errorf("duplicate node address %s in DON %s", nodeAddress, donConfig.DonId) - } - connWrapper := network.NewWSConnectionWrapper(lggr) - if connWrapper == nil { - return nil, fmt.Errorf("error creating WSConnectionWrapper for node %s", nodeAddress) - } - nodes[nodeAddress] = &nodeState{ - name: nodeConfig.Name, - conn: connWrapper, - } + nodes, err := buildNodeStates(donConfig.Members, donConfig.DonId, lggr) + if err != nil { + return nil, err } dons[donConfig.DonId] = &donConnectionManager{ donConfig: &donConfig, nodes: nodes, + handlers: make(map[string]handlers.Handler), shutdownCh: make(chan struct{}), gMetrics: gMetrics, lggr: logger.Named(lggr, "DONConnectionManager."+donConfig.DonId), } } + for _, shardedDON := range gwConfig.ShardedDONs { + for shardIdx, shard := range shardedDON.Shards { + donID := config.ShardDONID(shardedDON.DonName, shardIdx) + if _, ok := dons[donID]; ok { + return nil, fmt.Errorf("duplicate DON ID %s", donID) + } + nodes, err := buildNodeStates(shard.Nodes, donID, lggr) + if err != nil { + return nil, err + } + dons[donID] = &donConnectionManager{ + donConfig: &config.DONConfig{ + DonId: donID, + F: shardedDON.F, + Members: shard.Nodes, + }, + nodes: nodes, + handlers: make(map[string]handlers.Handler), + shutdownCh: make(chan struct{}), + gMetrics: gMetrics, + lggr: logger.Named(lggr, "DONConnectionManager."+donID), + } + } + } connMgr := &connectionManager{ config: &gwConfig.ConnectionManagerConfig, dons: dons, @@ -142,6 +155,25 @@ func NewConnectionManager(gwConfig *config.GatewayConfig, clock clockwork.Clock, return connMgr, nil } +func buildNodeStates(members []config.NodeConfig, donID string, lggr logger.Logger) (map[string]*nodeState, error) { + nodes := make(map[string]*nodeState) + for _, nodeConfig := range members { + nodeAddress := strings.ToLower(nodeConfig.Address) + if _, ok := nodes[nodeAddress]; ok { + return nil, fmt.Errorf("duplicate node address %s in DON %s", nodeAddress, donID) + } + connWrapper := network.NewWSConnectionWrapper(lggr) + if connWrapper == nil { + return nil, fmt.Errorf("error creating WSConnectionWrapper for node %s", nodeAddress) + } + nodes[nodeAddress] = &nodeState{ + name: nodeConfig.Name, + conn: connWrapper, + } + } + return nodes, nil +} + func (m *connectionManager) DONConnectionManager(donId string) *donConnectionManager { return m.dons[donId] } @@ -263,8 +295,22 @@ func (m *connectionManager) GetPort() int { return m.wsServer.GetPort() } -func (m *donConnectionManager) SetHandler(handler handlers.Handler) { - m.handler = handler +func (m *donConnectionManager) SetHandler(serviceName string, handler handlers.Handler) { + m.handlers[serviceName] = handler +} + +func (m *donConnectionManager) getHandler(method string) (handlers.Handler, error) { + if len(m.handlers) == 1 { + for _, h := range m.handlers { + return h, nil // supports legacy single-handler case + } + } + serviceName := strings.Split(method, ".")[0] + handler, ok := m.handlers[serviceName] + if !ok { + return nil, fmt.Errorf("no handler for service %q (method %q)", serviceName, method) + } + return handler, nil } func (m *donConnectionManager) SendToNode(ctx context.Context, nodeAddress string, req *jsonrpc.Request[json.RawMessage]) error { @@ -297,8 +343,13 @@ func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState m.lggr.Errorw("parse error when reading from node", "nodeAddress", nodeAddress, "err", err) break } + handler, err := m.getHandler(resp.Method) + if err != nil { + m.lggr.Errorw("no handler for node message", "nodeAddress", nodeAddress, "method", resp.Method, "err", err) + break + } startTime := time.Now() - err = m.handler.HandleNodeMessage(ctx, &resp, nodeAddress) + err = handler.HandleNodeMessage(ctx, &resp, nodeAddress) m.gMetrics.RecordNodeMsgHandlerInvocation(ctx, nodeAddress, nodeState.name, err == nil) m.gMetrics.RecordNodeMsgHandlerDuration(ctx, nodeAddress, nodeState.name, time.Since(startTime), err == nil) if err != nil { diff --git a/core/services/gateway/connectionmanager_test.go b/core/services/gateway/connectionmanager_test.go index 05d1186d4ce..befd592fadc 100644 --- a/core/services/gateway/connectionmanager_test.go +++ b/core/services/gateway/connectionmanager_test.go @@ -244,6 +244,176 @@ func TestConnectionManager_CleanStartClose(t *testing.T) { require.NoError(t, err) } +func TestConnectionManager_ShardedDONs_CreatesPerShardManagers(t *testing.T) { + t.Parallel() + + tomlConfig := ` +[nodeServerConfig] +Path = "/node" + +[[shardedDONs]] +DonName = "myDON" +F = 1 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "s0_n0" +Address = "0x0001020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s0_n1" +Address = "0x0002020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s0_n2" +Address = "0x0003020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s0_n3" +Address = "0x0004020304050607080900010203040506070809" + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "s1_n0" +Address = "0x0005020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s1_n1" +Address = "0x0006020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s1_n2" +Address = "0x0007020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "s1_n3" +Address = "0x0008020304050607080900010203040506070809" +` + + cfg := parseTOMLConfig(t, tomlConfig) + mgr := newConnectionManager(t, cfg, clockwork.NewFakeClock()) + + require.NotNil(t, mgr.DONConnectionManager(config.ShardDONID("myDON", 0)), "shard 0 connection manager should exist") + require.NotNil(t, mgr.DONConnectionManager(config.ShardDONID("myDON", 1)), "shard 1 connection manager should exist") + require.Nil(t, mgr.DONConnectionManager("myDON_2"), "shard 2 should not exist") +} + +func TestConnectionManager_ShardedDONs_MultipleDONs(t *testing.T) { + t.Parallel() + + tomlConfig := ` +[nodeServerConfig] +Path = "/node" + +[[shardedDONs]] +DonName = "donA" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "a_n0" +Address = "0x0001020304050607080900010203040506070809" + +[[shardedDONs]] +DonName = "donB" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "b_n0" +Address = "0x0002020304050607080900010203040506070809" +` + + cfg := parseTOMLConfig(t, tomlConfig) + mgr := newConnectionManager(t, cfg, clockwork.NewFakeClock()) + + require.NotNil(t, mgr.DONConnectionManager(config.ShardDONID("donA", 0))) + require.NotNil(t, mgr.DONConnectionManager(config.ShardDONID("donB", 0))) +} + +func TestConnectionManager_ShardedDONs_DuplicateNodeAddress(t *testing.T) { + t.Parallel() + + tomlConfig := ` +[nodeServerConfig] +Path = "/node" + +[[shardedDONs]] +DonName = "myDON" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "n0" +Address = "0x0001020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "n1" +Address = "0x0001020304050607080900010203040506070809" +` + + cfg := parseTOMLConfig(t, tomlConfig) + lggr := logger.Test(t) + gMetrics, err := monitoring.NewGatewayMetrics() + require.NoError(t, err) + _, err = gateway.NewConnectionManager(cfg, clockwork.NewFakeClock(), gMetrics, lggr, limits.Factory{Logger: lggr}) + require.Error(t, err) + require.Contains(t, err.Error(), "duplicate node address") +} + +func TestConnectionManager_ShardedDONs_SendToNode(t *testing.T) { + t.Parallel() + + tomlConfig := ` +[nodeServerConfig] +Path = "/node" + +[[shardedDONs]] +DonName = "myDON" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "n0" +Address = "0x0001020304050607080900010203040506070809" +` + + cfg := parseTOMLConfig(t, tomlConfig) + mgr := newConnectionManager(t, cfg, clockwork.NewFakeClock()) + + donMgr := mgr.DONConnectionManager(config.ShardDONID("myDON", 0)) + require.NotNil(t, donMgr) + + err := donMgr.SendToNode(testutils.Context(t), "0x0001020304050607080900010203040506070809", nil) + require.Error(t, err, "nil request should fail") + + message := &jsonrpc.Request[json.RawMessage]{} + err = donMgr.SendToNode(testutils.Context(t), "0xdeadbeef", message) + require.Error(t, err, "unknown node should fail") +} + +func TestConnectionManager_ShardedDONs_StartClose(t *testing.T) { + t.Parallel() + + tomlConfig := ` +[nodeServerConfig] +Path = "/node" +[connectionManagerConfig] +HeartbeatIntervalSec = 1 + +[[shardedDONs]] +DonName = "myDON" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "n0" +Address = "0x0001020304050607080900010203040506070809" +` + + cfg := parseTOMLConfig(t, tomlConfig) + mgr := newConnectionManager(t, cfg, clockwork.NewFakeClock()) + + err := mgr.Start(testutils.Context(t)) + require.NoError(t, err) + + err = mgr.Close() + require.NoError(t, err) +} + func newConnectionManager(t *testing.T, gwConfig *config.GatewayConfig, clock clockwork.Clock) gateway.ConnectionManager { lggr := logger.Test(t) gMetrics, err := monitoring.NewGatewayMetrics() diff --git a/core/services/gateway/gateway.go b/core/services/gateway/gateway.go index 641698d5779..d39ab616d58 100644 --- a/core/services/gateway/gateway.go +++ b/core/services/gateway/gateway.go @@ -141,7 +141,7 @@ func setupFromNewConfig( var shardConnMgrs []handlers.DON for shardIdx := range donCfg.Shards { - donID := fmt.Sprintf("%s_%d", donName, shardIdx) + donID := config.ShardDONID(donName, shardIdx) donConnMgr := connMgr.DONConnectionManager(donID) if donConnMgr == nil { return nil, fmt.Errorf("connection manager for DON %s shard %d not found", donName, shardIdx) @@ -158,15 +158,15 @@ func setupFromNewConfig( serviceToMultiHandler[svc.ServiceName] = handler - // Set (multi)handler on all associated DON connection managers + // Set (multi)handler on all associated DON connection managers, keyed by service name for i, donName := range svc.DONs { for shardIdx := range shardsConnMgrs[i] { - donID := fmt.Sprintf("%s_%d", donName, shardIdx) + donID := config.ShardDONID(donName, shardIdx) donConnMgr := connMgr.DONConnectionManager(donID) if donConnMgr == nil { return nil, fmt.Errorf("connection manager for DON %s shard %d not found", donName, shardIdx) } - donConnMgr.SetHandler(handler) + donConnMgr.SetHandler(svc.ServiceName, handler) } } @@ -235,7 +235,7 @@ func setupFromLegacyConfig( } } - donConnMgr.SetHandler(handler) + donConnMgr.SetHandler("", handler) } return handlerMap, serviceNameToDonID, nil diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index aaf2ece4bbc..a919096bd35 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -255,6 +255,158 @@ Name = "dummy" require.NoError(t, cfg.Validate()) } +func TestGateway_NewGatewayFromConfig_NewStyleConfig(t *testing.T) { + t.Parallel() + + tomlConfig := buildConfig(` +[[shardedDONs]] +DonName = "donA" +F = 1 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "donA_s0_n0" +Address = "0x0001020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s0_n1" +Address = "0x0002020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s0_n2" +Address = "0x0003020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s0_n3" +Address = "0x0004020304050607080900010203040506070809" + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "donA_s1_n0" +Address = "0x0005020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s1_n1" +Address = "0x0006020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s1_n2" +Address = "0x0007020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donA_s1_n3" +Address = "0x0008020304050607080900010203040506070809" + +[[shardedDONs]] +DonName = "donB" +F = 1 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "donB_s0_n0" +Address = "0x0011020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donB_s0_n1" +Address = "0x0012020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donB_s0_n2" +Address = "0x0013020304050607080900010203040506070809" +[[shardedDONs.Shards.Nodes]] +Name = "donB_s0_n3" +Address = "0x0014020304050607080900010203040506070809" + +[[services]] +ServiceName = "workflows" +DONs = ["donA"] + +[[services.Handlers]] +Name = "dummy" + +[[services]] +ServiceName = "vault" +DONs = ["donB"] + +[[services.Handlers]] +Name = "dummy" +`) + + lggr := logger.Test(t) + cfg := parseTOMLConfig(t, tomlConfig) + require.NoError(t, cfg.Validate()) + + gatewayObj, err := gateway.NewGatewayFromConfig(cfg, newGatewayHandler(t), lggr, limits.Factory{Logger: lggr}) + require.NoError(t, err) + require.NotNil(t, gatewayObj) +} + +func TestGateway_NewGatewayFromConfig_NewStyleConfig_UserRouting(t *testing.T) { + t.Parallel() + + tomlConfig := buildConfig(` +[[shardedDONs]] +DonName = "donA" +F = 0 + +[[shardedDONs.Shards]] +[[shardedDONs.Shards.Nodes]] +Name = "donA_s0_n0" +Address = "0x0001020304050607080900010203040506070809" + +[[services]] +ServiceName = "svcA" +DONs = ["donA"] + +[[services.Handlers]] +Name = "dummy" +ServiceName = "svcA" + +[[services.Handlers]] +Name = "dummy2" +ServiceName = "svcA" +`) + + newServiceHandler := func(method string) *handlermocks.Handler { + h := handlermocks.NewHandler(t) + h.On("Methods").Return([]string{method}) + h.On("HandleJSONRPCUserMessage", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + req := args.Get(1).(jsonrpc.Request[json.RawMessage]) + cb := args.Get(2).(handlers.Callback) + rm := json.RawMessage(`{"result":"OK"}`) + resp, err := json.Marshal(&jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, ID: req.ID, Method: req.Method, Result: &rm, + }) + require.NoError(t, err) + require.NoError(t, cb.SendResponse(handlers.UserCallbackPayload{RawResponse: resp, ErrorCode: api.NoError})) + }).Maybe() + return h + } + + handler1 := newServiceHandler("svcA.action1") + handler2 := newServiceHandler("svcA.action2") + factory := &handlerFactory{handlers: map[string]handlers.Handler{ + "dummy": handler1, + "dummy2": handler2, + }} + + lggr := logger.Test(t) + cfg := parseTOMLConfig(t, tomlConfig) + require.NoError(t, cfg.Validate()) + + gatewayObj, err := gateway.NewGatewayFromConfig(cfg, factory, lggr, limits.Factory{Logger: lggr}) + require.NoError(t, err) + + ctx := testutils.Context(t) + + req := newJSONRpcRequest(t, "r1", "svcA.action1", []byte(`{}`)) + response, statusCode := gatewayObj.ProcessRequest(ctx, req, "") + require.Equal(t, 200, statusCode, string(response)) + requireJSONRPCResult(t, "svcA.action1", response, "r1", `{"result":"OK"}`) + + req = newJSONRpcRequest(t, "r2", "svcA.action2", []byte(`{}`)) + response, statusCode = gatewayObj.ProcessRequest(ctx, req, "") + require.Equal(t, 200, statusCode, string(response)) + requireJSONRPCResult(t, "svcA.action2", response, "r2", `{"result":"OK"}`) + + req = newJSONRpcRequest(t, "r3", "unknown.method", []byte(`{}`)) + response, statusCode = gatewayObj.ProcessRequest(ctx, req, "") + require.Equal(t, 400, statusCode) + requireJSONRPCError(t, response, "r3", jsonrpc.ErrInvalidRequest, "Service name not found: unknown") +} + func TestGateway_CleanStartAndClose(t *testing.T) { t.Parallel() diff --git a/core/services/gateway/handlers/handler.dummy_test.go b/core/services/gateway/handlers/handler.dummy_test.go index beb39ddcf56..b59ef00d035 100644 --- a/core/services/gateway/handlers/handler.dummy_test.go +++ b/core/services/gateway/handlers/handler.dummy_test.go @@ -27,7 +27,7 @@ type testConnManager struct { sendCounter int } -func (m *testConnManager) SetHandler(handler handlers.Handler) { +func (m *testConnManager) SetHandler(_ string, handler handlers.Handler) { m.handler = handler } @@ -49,7 +49,7 @@ func TestDummyHandler_BasicFlow(t *testing.T) { connMgr := testConnManager{} handler, err := handlers.NewDummyHandler(&config, &connMgr, logger.Test(t)) require.NoError(t, err) - connMgr.SetHandler(handler) + connMgr.SetHandler("", handler) ctx := testutils.Context(t)