From 3fb32e884e88d138765295e1210113dfeb8d882d Mon Sep 17 00:00:00 2001 From: Bolek Kulbabinski <1416262+bolekk@users.noreply.github.com> Date: Fri, 27 Feb 2026 14:32:56 -0800 Subject: [PATCH] Gateway multi handler conversions --- .../workflow_metadata_aggregator.go | 50 +++-- core/services/gateway/handler_factory.go | 2 +- .../handlers/capabilities/v2/http_handler.go | 23 +- .../capabilities/v2/http_handler_test.go | 81 ++++--- .../capabilities/v2/http_trigger_handler.go | 49 +++-- .../v2/http_trigger_handler_test.go | 51 +++-- .../capabilities/v2/metrics/metrics.go | 11 +- .../capabilities/v2/metrics/metrics_test.go | 10 +- .../v2/workflow_metadata_handler.go | 83 +++++-- .../v2/workflow_metadata_handler_test.go | 18 +- .../services/gateway/handlers/shard_router.go | 104 +++++++++ .../gateway/handlers/shard_router_test.go | 203 ++++++++++++++++++ 12 files changed, 555 insertions(+), 130 deletions(-) create mode 100644 core/services/gateway/handlers/shard_router.go create mode 100644 core/services/gateway/handlers/shard_router_test.go diff --git a/core/services/gateway/common/aggregation/workflow_metadata_aggregator.go b/core/services/gateway/common/aggregation/workflow_metadata_aggregator.go index 58894d458d3..ca26496bfce 100644 --- a/core/services/gateway/common/aggregation/workflow_metadata_aggregator.go +++ b/core/services/gateway/common/aggregation/workflow_metadata_aggregator.go @@ -140,40 +140,64 @@ func (agg *WorkflowMetadataAggregator) Collect(obs *gateway_common.WorkflowMetad return nil } +// AggregatedWorkflow pairs a workflow's metadata with the set of node addresses that reported it. +type AggregatedWorkflow struct { + Metadata gateway_common.WorkflowMetadata + Reporters StringSet +} + // Aggregate returns the aggregated workflow metadata for workflows that have reached the threshold. // Results are sorted chronologically by sequence number (newest first, oldest last). func (agg *WorkflowMetadataAggregator) Aggregate() ([]gateway_common.WorkflowMetadata, error) { + results, err := agg.AggregateWithReporters() + if err != nil { + return nil, err + } + aggregated := make([]gateway_common.WorkflowMetadata, len(results)) + for i, r := range results { + aggregated[i] = r.Metadata + } + return aggregated, nil +} + +// AggregateWithReporters returns aggregated workflow metadata together with the set of +// node addresses that reported each workflow. This allows callers to determine which +// shard owns a workflow based on the reporting nodes. +func (agg *WorkflowMetadataAggregator) AggregateWithReporters() ([]AggregatedWorkflow, error) { agg.mu.RLock() defer agg.mu.RUnlock() - type aggregatedObs struct { - metadata gateway_common.WorkflowMetadata + type sortable struct { + result AggregatedWorkflow sequence uint64 } - var toSort []aggregatedObs + var toSort []sortable for _, nodeObs := range agg.observations { if len(nodeObs.nodes) >= agg.threshold { - toSort = append(toSort, aggregatedObs{ - metadata: *nodeObs.observation, + reportersCopy := make(StringSet, len(nodeObs.nodes)) + for addr := range nodeObs.nodes { + reportersCopy.Add(addr) + } + toSort = append(toSort, sortable{ + result: AggregatedWorkflow{ + Metadata: *nodeObs.observation, + Reporters: reportersCopy, + }, sequence: nodeObs.sequence, }) } } - // Sort chronologically (newest first) so that workflows that were registered most recently - // takes precedence sort.Slice(toSort, func(i, j int) bool { return toSort[i].sequence > toSort[j].sequence }) - // Extract just the metadata - aggregated := make([]gateway_common.WorkflowMetadata, len(toSort)) - for i, obs := range toSort { - aggregated[i] = obs.metadata + results := make([]AggregatedWorkflow, len(toSort)) + for i, s := range toSort { + results[i] = s.result } - - return aggregated, nil + return results, nil } type NodeObservations struct { diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 76172b3dc9b..cb32c297a39 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -83,7 +83,7 @@ func (hf *handlerFactory) NewHandler( case WebAPICapabilitiesType: return capabilities.NewHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr) case HTTPCapabilityType: - return v2.NewGatewayHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr, hf.lf) + return v2.NewGatewayHandler(handlerConfig, shardedDONs, shardsConnMgrs, hf.httpClient, hf.lggr, hf.lf) case VaultHandlerType: requestAuthorizer := vaultcap.NewRequestAuthorizer(hf.lggr, hf.workflowRegistrySyncer) return vault.NewHandler(handlerConfig, donConfig, don, hf.capabilitiesRegistry, requestAuthorizer, hf.lggr, clockwork.NewRealClock(), hf.lf) diff --git a/core/services/gateway/handlers/capabilities/v2/http_handler.go b/core/services/gateway/handlers/capabilities/v2/http_handler.go index 17820849e9a..5d4daa8f6e1 100644 --- a/core/services/gateway/handlers/capabilities/v2/http_handler.go +++ b/core/services/gateway/handlers/capabilities/v2/http_handler.go @@ -110,9 +110,14 @@ type RetryConfig struct { Multiplier float64 `json:"multiplier"` } -func NewGatewayHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, httpClient network.HTTPClient, lggr logger.Logger, lf limits.Factory) (*gatewayHandler, error) { +func NewGatewayHandler(handlerConfig json.RawMessage, shardedDONs []config.ShardedDONConfig, shardsConnMgrs [][]handlers.DON, httpClient network.HTTPClient, lggr logger.Logger, lf limits.Factory) (*gatewayHandler, error) { + shardRouter, err := handlers.NewShardRouter(shardedDONs, shardsConnMgrs) + if err != nil { + return nil, fmt.Errorf("failed to create shard router: %w", err) + } + var cfg ServiceConfig - err := json.Unmarshal(handlerConfig, &cfg) + err = json.Unmarshal(handlerConfig, &cfg) if err != nil { return nil, err } @@ -126,24 +131,24 @@ func NewGatewayHandler(handlerConfig json.RawMessage, donConfig *config.DONConfi return nil, fmt.Errorf("failed to create user rate limiter: %w", err) } - metrics, err := metrics.NewMetrics(donConfig) + m, err := metrics.NewMetrics(shardRouter.AllMembers()) if err != nil { return nil, fmt.Errorf("failed to initialize metrics: %w", err) } - metadataHandler := NewWorkflowMetadataHandler(lggr, cfg, don, donConfig, metrics) - triggerHandler := NewHTTPTriggerHandler(lggr, cfg, donConfig, don, metadataHandler, userRateLimiter, metrics) + metadataHandler := NewWorkflowMetadataHandler(lggr, cfg, shardRouter, m) + triggerHandler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, userRateLimiter, m) return &gatewayHandler{ config: cfg, - don: don, - lggr: logger.With(logger.Named(lggr, handlerName), "donId", donConfig.DonId), + don: shardRouter, + lggr: logger.With(logger.Named(lggr, handlerName), "donId", shardRouter.DonID()), httpClient: httpClient, nodeRateLimiter: nodeRateLimiter, stopCh: make(services.StopChan), - responseCache: newResponseCache(lggr, cfg.OutboundRequestCacheTTLMs, metrics), + responseCache: newResponseCache(lggr, cfg.OutboundRequestCacheTTLMs, m), triggerHandler: triggerHandler, metadataHandler: metadataHandler, - metrics: metrics, + metrics: m, }, nil } diff --git a/core/services/gateway/handlers/capabilities/v2/http_handler_test.go b/core/services/gateway/handlers/capabilities/v2/http_handler_test.go index 202d09932d3..abe5f0f96bf 100644 --- a/core/services/gateway/handlers/capabilities/v2/http_handler_test.go +++ b/core/services/gateway/handlers/capabilities/v2/http_handler_test.go @@ -20,12 +20,24 @@ import ( gateway_common "github.com/smartcontractkit/chainlink-common/pkg/types/gateway" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" triggermocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities/v2/mocks" handlermocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" httpmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network/mocks" ) +// toShardedArgs converts legacy DONConfig + DON into the sharded format expected by NewGatewayHandler. +func toShardedArgs(donCfg *config.DONConfig, don handlers.DON) ([]config.ShardedDONConfig, [][]handlers.DON) { + shardedDONs := []config.ShardedDONConfig{{ + DonName: donCfg.DonId, + F: donCfg.F, + Shards: []config.Shard{{Nodes: donCfg.Members}}, + }} + connMgrs := [][]handlers.DON{{don}} + return shardedDONs, connMgrs +} + func TestNewGatewayHandler(t *testing.T) { t.Run("successful creation", func(t *testing.T) { cfg := serviceCfg() @@ -39,7 +51,8 @@ func TestNewGatewayHandler(t *testing.T) { mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(configBytes, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(configBytes, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.NoError(t, err) require.NotNil(t, handler) require.NotNil(t, handler.responseCache) @@ -54,7 +67,8 @@ func TestNewGatewayHandler(t *testing.T) { mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(invalidConfig, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(invalidConfig, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.Error(t, err) require.Nil(t, handler) }) @@ -74,7 +88,8 @@ func TestNewGatewayHandler(t *testing.T) { mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(configBytes, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(configBytes, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.Error(t, err) require.Nil(t, handler) }) @@ -97,7 +112,8 @@ func TestNewGatewayHandler(t *testing.T) { mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(configBytes, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(configBytes, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.NoError(t, err) require.NotNil(t, handler) require.Equal(t, defaultCleanUpPeriodMs, handler.config.CleanUpPeriodMs) // Default value @@ -105,10 +121,9 @@ func TestNewGatewayHandler(t *testing.T) { } func TestHandleNodeMessage(t *testing.T) { - handler := createTestHandler(t) + handler, mockDon := createTestHandler(t) t.Run("successful node message handling", func(t *testing.T) { - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) // Prepare outbound request @@ -149,7 +164,6 @@ func TestHandleNodeMessage(t *testing.T) { }) t.Run("successful node message handling with MultiHeaders", func(t *testing.T) { - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) // Prepare outbound request @@ -247,7 +261,6 @@ func TestHandleNodeMessage(t *testing.T) { Result: &rawRequest, } - mockDon := handler.don.(*handlermocks.DON) // First call: should fetch from HTTP client and cache the response httpResp := &network.HTTPResponse{ StatusCode: 200, @@ -293,7 +306,6 @@ func TestHandleNodeMessage(t *testing.T) { Result: &rawRequest, } - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) httpResp := &network.HTTPResponse{ StatusCode: 500, @@ -345,7 +357,7 @@ func TestHandleNodeMessage(t *testing.T) { } func TestServiceLifecycle(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) t.Run("start and stop", func(t *testing.T) { ctx := testutils.Context(t) @@ -367,7 +379,7 @@ func TestHandleNodeMessage_RoutesToTriggerHandler(t *testing.T) { // This test covers the case where the response ID does not contain a "/" // and should be routed to the triggerHandler.HandleNodeTriggerResponse. mockTriggerHandler := triggermocks.NewHTTPTriggerHandler(t) - handler := createTestHandler(t) + handler, _ := createTestHandler(t) handler.triggerHandler = mockTriggerHandler rawRes := json.RawMessage([]byte(`{}`)) @@ -388,7 +400,7 @@ func TestHandleNodeMessage_RoutesToTriggerHandler(t *testing.T) { } func TestHandleNodeMessage_UnsupportedMethod(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) rawRes := json.RawMessage([]byte(`{}`)) resp := &jsonrpc.Response[json.RawMessage]{ ID: "unsupportedMethod/123", @@ -402,7 +414,7 @@ func TestHandleNodeMessage_UnsupportedMethod(t *testing.T) { } func TestHandleNodeMessage_EmptyID(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) rawRes := json.RawMessage([]byte(`{}`)) resp := &jsonrpc.Response[json.RawMessage]{ ID: "", @@ -458,7 +470,8 @@ func TestGatewayHandler_Start_CallsDeleteExpired(t *testing.T) { mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(configBytes, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(configBytes, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.NoError(t, err) require.NotNil(t, handler) mockCache := newMockResponseCache() @@ -491,7 +504,7 @@ func serviceCfg() ServiceConfig { return WithDefaults(cfg) } -func createTestHandler(t *testing.T) *gatewayHandler { +func createTestHandler(t *testing.T) (*gatewayHandler, *handlermocks.DON) { cfg := serviceCfg() return createTestHandlerWithConfig(t, cfg) } @@ -504,22 +517,27 @@ func verifyBackwardCompatibility(t *testing.T, headers map[string]string, multiH } } -func createTestHandlerWithConfig(t *testing.T, cfg ServiceConfig) *gatewayHandler { +func createTestHandlerWithConfig(t *testing.T, cfg ServiceConfig) (*gatewayHandler, *handlermocks.DON) { configBytes, err := json.Marshal(cfg) require.NoError(t, err) donConfig := &config.DONConfig{ DonId: "test-don", + Members: []config.NodeConfig{ + {Name: "node1", Address: "node1"}, + {Name: "node2", Address: "node2"}, + }, } mockDon := handlermocks.NewDON(t) mockHTTPClient := httpmocks.NewHTTPClient(t) lggr := logger.Test(t) - handler, err := NewGatewayHandler(configBytes, donConfig, mockDon, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) + shardedDONs, connMgrs := toShardedArgs(donConfig, mockDon) + handler, err := NewGatewayHandler(configBytes, shardedDONs, connMgrs, mockHTTPClient, lggr, limits.Factory{Logger: lggr}) require.NoError(t, err) require.NotNil(t, handler) - return handler + return handler, mockDon } func TestCreateHTTPRequestCallback(t *testing.T) { @@ -542,7 +560,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { } t.Run("successful HTTP request with latency measurement", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) expectedResp := &network.HTTPResponse{ @@ -565,7 +583,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { }) t.Run("HTTP send error sets IsExternalEndpointError to true", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) mockHTTPClient.EXPECT().Send(mock.Anything, mock.Anything).Return(nil, network.ErrHTTPSend) @@ -584,7 +602,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { }) t.Run("response with MultiHeaders is passed through correctly", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) expectedResp := &network.HTTPResponse{ @@ -634,7 +652,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { }) t.Run("response with empty MultiHeaders still sets Headers", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) expectedResp := &network.HTTPResponse{ @@ -661,7 +679,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { }) t.Run("HTTP read error sets IsExternalEndpointError to true", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) mockHTTPClient.EXPECT().Send(mock.Anything, mock.Anything).Return(nil, network.ErrHTTPRead) @@ -680,7 +698,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { }) t.Run("other errors set IsExternalEndpointError to false", func(t *testing.T) { - handler := createTestHandler(t) + handler, _ := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) genericError := errors.New("some other network error") @@ -701,8 +719,7 @@ func TestCreateHTTPRequestCallback(t *testing.T) { } func TestMakeOutgoingRequest_SendResponseUsesIndependentContext(t *testing.T) { - handler := createTestHandler(t) - mockDon := handler.don.(*handlermocks.DON) + handler, mockDon := createTestHandler(t) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) outboundReq := gateway_common.OutboundHTTPRequest{ @@ -744,7 +761,7 @@ func TestMakeOutgoingRequest_SendResponseUsesIndependentContext(t *testing.T) { // TestMakeOutgoingRequestCachingBehavior tests the specific caching logic in makeOutgoingRequest func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { t.Run("MaxAgeMs=0 and Store=true calls Set", func(t *testing.T) { - handler := createTestHandler(t) + handler, mockDon := createTestHandler(t) mockCache := newMockResponseCache() handler.responseCache = mockCache @@ -766,7 +783,6 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { Result: &rawRequest, } - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) httpResp := &network.HTTPResponse{ StatusCode: 200, @@ -788,7 +804,7 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { }) t.Run("MaxAgeMs=0 and Store=false does not call Set", func(t *testing.T) { - handler := createTestHandler(t) + handler, mockDon := createTestHandler(t) mockCache := newMockResponseCache() handler.responseCache = mockCache @@ -810,7 +826,6 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { Result: &rawRequest, } - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) httpResp := &network.HTTPResponse{ StatusCode: 200, @@ -832,7 +847,7 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { }) t.Run("MaxAgeMs>0 calls CachedFetch", func(t *testing.T) { - handler := createTestHandler(t) + handler, mockDon := createTestHandler(t) mockCache := newMockResponseCache() handler.responseCache = mockCache @@ -854,7 +869,6 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { Result: &rawRequest, } - mockDon := handler.don.(*handlermocks.DON) mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) httpResp := &network.HTTPResponse{ StatusCode: 200, @@ -878,7 +892,7 @@ func TestMakeOutgoingRequestCachingBehavior(t *testing.T) { // setupRateLimitingTest creates common test setup for rate limiting tests func setupRateLimitingTest(t *testing.T, cfg ServiceConfig) (*gatewayHandler, *jsonrpc.Response[json.RawMessage], *httpmocks.HTTPClient, *handlermocks.DON) { - handler := createTestHandlerWithConfig(t, cfg) + handler, mockDon := createTestHandlerWithConfig(t, cfg) outboundReq := gateway_common.OutboundHTTPRequest{ Method: "GET", @@ -896,7 +910,6 @@ func setupRateLimitingTest(t *testing.T, cfg ServiceConfig) (*gatewayHandler, *j } mockHTTPClient := handler.httpClient.(*httpmocks.HTTPClient) - mockDon := handler.don.(*handlermocks.DON) return handler, resp, mockHTTPClient, mockDon } diff --git a/core/services/gateway/handlers/capabilities/v2/http_trigger_handler.go b/core/services/gateway/handlers/capabilities/v2/http_trigger_handler.go index 3bc88ffddba..b764f1602c5 100644 --- a/core/services/gateway/handlers/capabilities/v2/http_trigger_handler.go +++ b/core/services/gateway/handlers/capabilities/v2/http_trigger_handler.go @@ -23,7 +23,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/platform" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common/aggregation" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities/v2/metrics" "github.com/smartcontractkit/chainlink/v2/core/services/job" @@ -51,8 +50,7 @@ type savedCallback struct { type httpTriggerHandler struct { services.StateMachine config ServiceConfig - don handlers.DON - donConfig *config.DONConfig + shards *handlers.ShardRouter lggr logger.Logger callbacksMu sync.Mutex callbacks map[string]savedCallback // requestID -> savedCallback @@ -69,17 +67,16 @@ type HTTPTriggerHandler interface { HandleNodeTriggerResponse(ctx context.Context, resp *jsonrpc.Response[json.RawMessage], nodeAddr string) error } -func NewHTTPTriggerHandler(lggr logger.Logger, cfg ServiceConfig, donConfig *config.DONConfig, don handlers.DON, workflowMetadataHandler *WorkflowMetadataHandler, userRateLimiter limits.RateLimiter, metrics *metrics.Metrics) *httpTriggerHandler { +func NewHTTPTriggerHandler(lggr logger.Logger, cfg ServiceConfig, shards *handlers.ShardRouter, workflowMetadataHandler *WorkflowMetadataHandler, userRateLimiter limits.RateLimiter, m *metrics.Metrics) *httpTriggerHandler { return &httpTriggerHandler{ lggr: logger.Named(lggr, "RequestCallbacks"), callbacks: make(map[string]savedCallback), config: cfg, - don: don, - donConfig: donConfig, + shards: shards, stopCh: make(services.StopChan), workflowMetadataHandler: workflowMetadataHandler, userRateLimiter: userRateLimiter, - metrics: metrics, + metrics: m, } } @@ -94,6 +91,12 @@ func (h *httpTriggerHandler) HandleUserTriggerRequest(ctx context.Context, req * return err } + shard, found := h.workflowMetadataHandler.GetWorkflowShard(workflowID) + if !found { + h.handleUserError(ctx, req.ID, jsonrpc.ErrInternal, "could not determine shard for workflow", callback) + return fmt.Errorf("shard not found for workflow %s", workflowID) + } + key, err := h.authorizeRequest(ctx, workflowID, req, callback) if err != nil { return err @@ -116,12 +119,12 @@ func (h *httpTriggerHandler) HandleUserTriggerRequest(ctx context.Context, req * return errors.New("error marshaling trigger request: " + err.Error()) } - doneCh, err := h.setupCallback(ctx, req.ID, callback, requestStartTime) + doneCh, err := h.setupCallback(ctx, req.ID, callback, requestStartTime, shard) if err != nil { return err } - return h.sendWithRetries(ctx, executionID, reqWithKey, doneCh) + return h.sendWithRetries(ctx, executionID, reqWithKey, doneCh, shard) } func (h *httpTriggerHandler) validatedTriggerRequest(ctx context.Context, req *jsonrpc.Request[json.RawMessage], callback handlers.Callback) (*jsonrpc.Request[gateway_common.HTTPTriggerRequest], error) { @@ -378,7 +381,7 @@ func (h *httpTriggerHandler) checkRateLimit(ctx context.Context, workflowID, req return nil } -func (h *httpTriggerHandler) setupCallback(ctx context.Context, requestID string, callback handlers.Callback, requestStartTime time.Time) (<-chan struct{}, error) { +func (h *httpTriggerHandler) setupCallback(ctx context.Context, requestID string, callback handlers.Callback, requestStartTime time.Time, shard handlers.ShardInfo) (<-chan struct{}, error) { h.callbacksMu.Lock() defer h.callbacksMu.Unlock() @@ -387,8 +390,8 @@ func (h *httpTriggerHandler) setupCallback(ctx context.Context, requestID string return nil, fmt.Errorf("in-flight request ID: %s", requestID) } - // (N+F)//2 + 1 threshold where N = number of nodes, F = number of faulty nodes - threshold := (len(h.donConfig.Members)+h.donConfig.F)/2 + 1 + // (N+F)//2 + 1 threshold where N = shard member count, F = shard's faulty node count + threshold := (len(shard.Members)+shard.F)/2 + 1 agg, err := aggregation.NewIdenticalNodeResponseAggregator(threshold) if err != nil { return nil, errors.New("failed to create response aggregator: " + err.Error()) @@ -543,15 +546,14 @@ func (h *httpTriggerHandler) handleUserError(ctx context.Context, requestID stri } } -// sendWithRetries attempts to send the request to all DON members, +// sendWithRetries attempts to send the request to all members of the owning shard, // retrying failed nodes until either all succeed or the max trigger request duration is reached. // doneCh is closed when the callback has been responded to (quorum reached), allowing immediate termination. -func (h *httpTriggerHandler) sendWithRetries(ctx context.Context, executionID string, req *jsonrpc.Request[json.RawMessage], doneCh <-chan struct{}) error { +func (h *httpTriggerHandler) sendWithRetries(ctx context.Context, executionID string, req *jsonrpc.Request[json.RawMessage], doneCh <-chan struct{}, shard handlers.ShardInfo) error { if doneCh == nil { return errors.New("doneCh cannot be nil") } - // Create a context that will be cancelled when the max request duration is reached maxDuration := time.Duration(h.config.MaxTriggerRequestDurationMs) * time.Millisecond ctxWithTimeout, cancel := context.WithTimeout(ctx, maxDuration) defer cancel() @@ -565,16 +567,15 @@ func (h *httpTriggerHandler) sendWithRetries(ctx context.Context, executionID st } for { - // Retry sending to nodes that haven't received the message allNodesSucceeded := true var combinedErr error - for _, member := range h.donConfig.Members { + for _, member := range shard.Members { if successfulNodes[member.Address] { continue } h.metrics.IncrementTriggerCapabilityRequestCount(ctx, member.Address, gateway_common.MethodWorkflowExecute, h.lggr) - err := h.don.SendToNode(ctxWithTimeout, member.Address, req) + err := shard.DON.SendToNode(ctxWithTimeout, member.Address, req) if err != nil { allNodesSucceeded = false h.metrics.IncrementTriggerCapabilityRequestFailures(ctx, member.Address, gateway_common.MethodWorkflowExecute, h.lggr) @@ -584,22 +585,20 @@ func (h *httpTriggerHandler) sendWithRetries(ctx context.Context, executionID st "executionID", executionID, "error", err) } else { - // Mark this node as successful successfulNodes[member.Address] = true } } if allNodesSucceeded { - h.lggr.Infow("Successfully sent trigger request to all nodes", + h.lggr.Infow("Successfully sent trigger request to all shard nodes", "executionID", executionID, - "nodeCount", len(h.donConfig.Members)) + "nodeCount", len(shard.Members)) return nil } - // Not all nodes succeeded, wait and retry h.lggr.Debugw("Retrying failed nodes for trigger request", "executionID", executionID, - "failedCount", len(h.donConfig.Members)-len(successfulNodes), + "failedCount", len(shard.Members)-len(successfulNodes), "errors", combinedErr) select { @@ -608,13 +607,13 @@ func (h *httpTriggerHandler) sendWithRetries(ctx context.Context, executionID st "executionID", executionID, "requestID", req.ID, "successNodes", len(successfulNodes), - "totalNodes", len(h.donConfig.Members)) + "totalNodes", len(shard.Members)) return nil case <-time.After(b.Duration()): continue case <-ctxWithTimeout.Done(): return fmt.Errorf("request retry time exceeded, some nodes may not have received the request: executionID=%s, successNodes=%d, totalNodes=%d", - executionID, len(successfulNodes), len(h.donConfig.Members)) + executionID, len(successfulNodes), len(shard.Members)) } } } diff --git a/core/services/gateway/handlers/capabilities/v2/http_trigger_handler_test.go b/core/services/gateway/handlers/capabilities/v2/http_trigger_handler_test.go index 42e5e8dde74..14c587e5a52 100644 --- a/core/services/gateway/handlers/capabilities/v2/http_trigger_handler_test.go +++ b/core/services/gateway/handlers/capabilities/v2/http_trigger_handler_test.go @@ -35,11 +35,23 @@ const ( ) func createTestMetrics(t *testing.T, donConfig *config.DONConfig) *metrics.Metrics { - m, err := metrics.NewMetrics(donConfig) + m, err := metrics.NewMetrics(donConfig.Members) require.NoError(t, err) return m } +func testShardRouterFromConfig(t *testing.T, donConfig *config.DONConfig, don handlers.DON) *handlers.ShardRouter { + t.Helper() + shardedDONs := []config.ShardedDONConfig{{ + DonName: donConfig.DonId, + F: donConfig.F, + Shards: []config.Shard{{Nodes: donConfig.Members}}, + }} + router, err := handlers.NewShardRouter(shardedDONs, [][]handlers.DON{{don}}) + require.NoError(t, err) + return router +} + func requireUserErrorSent(t *testing.T, payload handlers.UserCallbackPayload, errorCode int64) { require.NotEmpty(t, payload.RawResponse) require.Equal(t, api.FromJSONRPCErrorCode(errorCode), payload.ErrorCode) @@ -526,6 +538,7 @@ func registerWorkflow(_ *testing.T, handler *httpTriggerHandler, workflowID stri workflowName: "test-workflow", workflowTag: "v1.0", } + handler.workflowMetadataHandler.workflowToShard[workflowID] = 0 } func TestHttpTriggerHandler_ReapExpiredCallbacks(t *testing.T) { @@ -670,10 +683,11 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_Retries(t *testing.T) { } mockDon := handlermocks.NewDON(t) - metadataHandler := createTestMetadataHandler(t) + shardRouter := testShardRouterFromConfig(t, donConfig, mockDon) + metadataHandler := createTestMetadataHandlerWithDON(t, shardRouter) userRateLimiter := createTestUserRateLimiter() testMetrics := createTestMetrics(t, donConfig) - handler := NewHTTPTriggerHandler(lggr, cfg, donConfig, mockDon, metadataHandler, userRateLimiter, testMetrics) + handler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, userRateLimiter, testMetrics) privateKey := createTestPrivateKey(t) registerWorkflow(t, handler, workflowID, privateKey) @@ -737,6 +751,7 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_JWTAuthorization(t *testing workflowName: "test-workflow", workflowTag: "v1.0", } + handler.workflowMetadataHandler.workflowToShard[workflowID] = 0 t.Run("successful JWT authorization", func(t *testing.T) { callback := hc.NewCallback() @@ -885,6 +900,7 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_WorkflowLookup(t *testing.T } handler.workflowMetadataHandler.workflowIDToRef[workflowID] = workflowRef handler.workflowMetadataHandler.workflowRefToID[workflowRef] = workflowID + handler.workflowMetadataHandler.workflowToShard[workflowID] = 0 t.Run("successful workflow lookup by name", func(t *testing.T) { callback := hc.NewCallback() @@ -1528,8 +1544,11 @@ func createTestJWTToken(t *testing.T, req *jsonrpc.Request[json.RawMessage], pri } func createTestMetadataHandler(t *testing.T) *WorkflowMetadataHandler { + return createTestMetadataHandlerWithDON(t, nil) +} + +func createTestMetadataHandlerWithDON(t *testing.T, shardRouter *handlers.ShardRouter) *WorkflowMetadataHandler { lggr := logger.Test(t) - mockDon := handlermocks.NewDON(t) donConfig := &config.DONConfig{ F: 1, Members: []config.NodeConfig{ @@ -1538,9 +1557,13 @@ func createTestMetadataHandler(t *testing.T) *WorkflowMetadataHandler { {Address: "node3"}, }, } + if shardRouter == nil { + mockDon := handlermocks.NewDON(t) + shardRouter = testShardRouterFromConfig(t, donConfig, mockDon) + } cfg := WithDefaults(ServiceConfig{}) testMetrics := createTestMetrics(t, donConfig) - return NewWorkflowMetadataHandler(lggr, cfg, mockDon, donConfig, testMetrics) + return NewWorkflowMetadataHandler(lggr, cfg, shardRouter, testMetrics) } func createTestUserRateLimiter() limits.RateLimiter { @@ -1566,12 +1589,13 @@ func createTestTriggerHandlerWithConfig(t *testing.T, cfg ServiceConfig) (*httpT }, } mockDon := handlermocks.NewDON(t) + shardRouter := testShardRouterFromConfig(t, donConfig, mockDon) lggr := logger.Test(t) - metadataHandler := createTestMetadataHandler(t) + metadataHandler := createTestMetadataHandlerWithDON(t, shardRouter) userRateLimiter := createTestUserRateLimiter() testMetrics := createTestMetrics(t, donConfig) - handler := NewHTTPTriggerHandler(lggr, cfg, donConfig, mockDon, metadataHandler, userRateLimiter, testMetrics) + handler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, userRateLimiter, testMetrics) return handler, mockDon } @@ -1592,13 +1616,14 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_RateLimiting(t *testing.T) } mockDon := handlermocks.NewDON(t) + shardRouter := testShardRouterFromConfig(t, donConfig, mockDon) lggr := logger.Test(t) - metadataHandler := createTestMetadataHandler(t) + metadataHandler := createTestMetadataHandlerWithDON(t, shardRouter) testMetrics := createTestMetrics(t, donConfig) t.Run("successful rate limit check with CRE context", func(t *testing.T) { userRateLimiter := createTestUserRateLimiter() // Unlimited - handler := NewHTTPTriggerHandler(lggr, cfg, donConfig, mockDon, metadataHandler, userRateLimiter, testMetrics) + handler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, userRateLimiter, testMetrics) privateKey := createTestPrivateKey(t) workflowID := "0x1234567890abcdef1234567890abcdef12345678901234567890abcdef123456" @@ -1642,9 +1667,8 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_RateLimiting(t *testing.T) }) t.Run("rate limit exceeded returns proper error", func(t *testing.T) { - // Create a rate limiter with very restrictive limits restrictiveRateLimiter := limits.WorkflowRateLimiter(1, 0) - handler := NewHTTPTriggerHandler(lggr, cfg, donConfig, mockDon, metadataHandler, restrictiveRateLimiter, testMetrics) + handler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, restrictiveRateLimiter, testMetrics) privateKey := createTestPrivateKey(t) workflowID := "0x1234567890abcdef1234567890abcdef12345678901234567890abcdef123456" @@ -1705,10 +1729,11 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest_StopsRetriesOnQuorum(t *tes } mockDon := handlermocks.NewDON(t) - metadataHandler := createTestMetadataHandler(t) + shardRouter := testShardRouterFromConfig(t, donConfig, mockDon) + metadataHandler := createTestMetadataHandlerWithDON(t, shardRouter) userRateLimiter := createTestUserRateLimiter() testMetrics := createTestMetrics(t, donConfig) - handler := NewHTTPTriggerHandler(lggr, cfg, donConfig, mockDon, metadataHandler, userRateLimiter, testMetrics) + handler := NewHTTPTriggerHandler(lggr, cfg, shardRouter, metadataHandler, userRateLimiter, testMetrics) privateKey := createTestPrivateKey(t) registerWorkflow(t, handler, workflowID, privateKey) diff --git a/core/services/gateway/handlers/capabilities/v2/metrics/metrics.go b/core/services/gateway/handlers/capabilities/v2/metrics/metrics.go index 1a608365312..194cdff310c 100644 --- a/core/services/gateway/handlers/capabilities/v2/metrics/metrics.go +++ b/core/services/gateway/handlers/capabilities/v2/metrics/metrics.go @@ -77,8 +77,9 @@ type Metrics struct { nodeAddressToNodeName map[string]string } -// NewMetrics creates a new instance of Metrics with all metrics initialized -func NewMetrics(donConfig *config.DONConfig) (*Metrics, error) { +// NewMetrics creates a new instance of Metrics with all metrics initialized. +// allMembers contains all node configs across all shards for address-to-name resolution. +func NewMetrics(allMembers []config.NodeConfig) (*Metrics, error) { meter := beholder.GetMeter() common, err := newCommonMetrics(meter) @@ -97,10 +98,8 @@ func NewMetrics(donConfig *config.DONConfig) (*Metrics, error) { } nodeAddressToNodeName := make(map[string]string) - if donConfig != nil { - for _, member := range donConfig.Members { - nodeAddressToNodeName[member.Address] = member.Name - } + for _, member := range allMembers { + nodeAddressToNodeName[member.Address] = member.Name } return &Metrics{ diff --git a/core/services/gateway/handlers/capabilities/v2/metrics/metrics_test.go b/core/services/gateway/handlers/capabilities/v2/metrics/metrics_test.go index 6e19b92674c..7c00a8a598b 100644 --- a/core/services/gateway/handlers/capabilities/v2/metrics/metrics_test.go +++ b/core/services/gateway/handlers/capabilities/v2/metrics/metrics_test.go @@ -11,13 +11,11 @@ import ( func TestNewMetrics(t *testing.T) { t.Parallel() - donConfig := &config.DONConfig{ - Members: []config.NodeConfig{ - {Address: "0xnode1", Name: "node1"}, - {Address: "0xnode2", Name: "node2"}, - }, + members := []config.NodeConfig{ + {Address: "0xnode1", Name: "node1"}, + {Address: "0xnode2", Name: "node2"}, } - metrics, err := NewMetrics(donConfig) + metrics, err := NewMetrics(members) require.NoError(t, err) require.NotNil(t, metrics) require.NotNil(t, metrics.action) diff --git a/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler.go b/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler.go index f02a4da7f4c..af14dff432b 100644 --- a/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler.go +++ b/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler.go @@ -14,7 +14,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types/gateway" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common/aggregation" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities/v2/metrics" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -42,10 +41,10 @@ type WorkflowMetadataHandler struct { authorizedKeys map[string]map[gateway.AuthorizedKey]struct{} // map of workflow ID to authorized keys workflowRefToID map[workflowReference]string // map of workflow reference to workflow ID workflowIDToRef map[string]workflowReference // map of workflow ID to workflow reference + workflowToShard map[string]int // map of workflow ID to shard index agg *aggregation.WorkflowMetadataAggregator config ServiceConfig - don handlers.DON - donConfig *config.DONConfig + shards *handlers.ShardRouter stopCh services.StopChan metrics *metrics.Metrics jwtCache *jwtReplayCache // JWT replay protection cache @@ -54,20 +53,21 @@ type WorkflowMetadataHandler struct { } // NewWorkflowMetadataHandler creates a new WorkflowMetadataHandler. -func NewWorkflowMetadataHandler(lggr logger.Logger, cfg ServiceConfig, don handlers.DON, donConfig *config.DONConfig, metrics *metrics.Metrics) *WorkflowMetadataHandler { - // f+1 identical responses from workflow are needed for workflow metadata to be registered - threshold := donConfig.F + 1 +func NewWorkflowMetadataHandler(lggr logger.Logger, cfg ServiceConfig, shards *handlers.ShardRouter, m *metrics.Metrics) *WorkflowMetadataHandler { + // f+1 identical responses from workflow are needed for workflow metadata to be registered. + // Use the F value from the first shard (all shards of the same DON share the same F). + threshold := shards.Shard(0).F + 1 return &WorkflowMetadataHandler{ lggr: logger.Named(lggr, "HTTPTriggerWorkflowMetadataHandler"), authorizedKeys: make(map[string]map[gateway.AuthorizedKey]struct{}), workflowRefToID: make(map[workflowReference]string), workflowIDToRef: make(map[string]workflowReference), - agg: aggregation.NewWorkflowMetadataAggregator(lggr, threshold, time.Duration(cfg.CleanUpPeriodMs)*time.Millisecond, metrics), - don: don, - donConfig: donConfig, + workflowToShard: make(map[string]int), + agg: aggregation.NewWorkflowMetadataAggregator(lggr, threshold, time.Duration(cfg.CleanUpPeriodMs)*time.Millisecond, m), + shards: shards, config: cfg, stopCh: make(services.StopChan), - metrics: metrics, + metrics: m, jwtCache: newJWTReplayCache(time.Duration(cfg.JWTReplayPeriodMs) * time.Millisecond), } } @@ -105,7 +105,7 @@ func (h *WorkflowMetadataHandler) Authorize(workflowID string, token string, req // syncMetadata aggregates the authorized keys and workflow selectors from the WorkflowMetadataAggregator and updates the local cache. // Should be called periodically to keep the authorized keys up to date. func (h *WorkflowMetadataHandler) syncMetadata() { - metadata, err := h.agg.Aggregate() + results, err := h.agg.AggregateWithReporters() if err != nil { h.lggr.Errorw("Failed to aggregate auth data", "error", err) return @@ -113,7 +113,9 @@ func (h *WorkflowMetadataHandler) syncMetadata() { authorizedKeys := make(map[string]map[gateway.AuthorizedKey]struct{}) workflowRefToID := make(map[workflowReference]string) workflowIDToRef := make(map[string]workflowReference) - for _, data := range metadata { + workflowToShard := make(map[string]int) + for _, result := range results { + data := result.Metadata workflowRef := workflowReference{ workflowOwner: data.WorkflowSelector.WorkflowOwner, workflowName: data.WorkflowSelector.WorkflowName, @@ -137,6 +139,11 @@ func (h *WorkflowMetadataHandler) syncMetadata() { for _, key := range data.AuthorizedKeys { authorizedKeys[data.WorkflowSelector.WorkflowID][key] = struct{}{} } + + shardIdx := h.resolveShardFromReporters(result.Reporters) + if shardIdx >= 0 { + workflowToShard[data.WorkflowSelector.WorkflowID] = shardIdx + } } h.mu.Lock() defer h.mu.Unlock() @@ -145,7 +152,6 @@ func (h *WorkflowMetadataHandler) syncMetadata() { latencyMs := time.Since(h.startTime).Milliseconds() h.metrics.RecordMetadataSyncStartupLatency(context.Background(), latencyMs, h.lggr) } - // Log all registered workflow IDs workflowIDs := make([]string, 0, len(workflowIDToRef)) for workflowID := range workflowIDToRef { workflowIDs = append(workflowIDs, workflowID) @@ -155,11 +161,32 @@ func (h *WorkflowMetadataHandler) syncMetadata() { h.authorizedKeys = authorizedKeys h.workflowRefToID = workflowRefToID h.workflowIDToRef = workflowIDToRef + h.workflowToShard = workflowToShard h.metrics.RecordLoadedMetadataSize(context.Background(), int64(len(h.workflowIDToRef)), h.lggr) } -// sendMetadataPullRequest sends a request to all nodes in the DON to pull the latest metadata. -// no retries are performed, as the caller is expected to poll periodically. +// resolveShardFromReporters determines which shard a workflow belongs to by finding the +// shard that the majority of reporting nodes belong to. +func (h *WorkflowMetadataHandler) resolveShardFromReporters(reporters aggregation.StringSet) int { + counts := make(map[int]int) + for addr := range reporters { + if idx, ok := h.shards.ShardIndexForNode(addr); ok { + counts[idx]++ + } + } + bestIdx := -1 + bestCount := 0 + for idx, count := range counts { + if count > bestCount { + bestCount = count + bestIdx = idx + } + } + return bestIdx +} + +// sendMetadataPullRequest sends a request to all nodes across all shards to pull the latest metadata. +// No retries are performed, as the caller is expected to poll periodically. func (h *WorkflowMetadataHandler) sendMetadataPullRequest() error { timeout := time.Duration(h.config.MetadataPullRequestTimeoutMs) * time.Millisecond ctx, cancel := h.stopCh.CtxWithTimeout(timeout) @@ -171,12 +198,15 @@ func (h *WorkflowMetadataHandler) sendMetadataPullRequest() error { Method: gateway.MethodPullWorkflowMetadata, } var combinedErr error - for _, member := range h.donConfig.Members { - h.metrics.IncrementTriggerCapabilityRequestCount(ctx, member.Address, gateway.MethodPullWorkflowMetadata, h.lggr) - err := h.don.SendToNode(ctx, member.Address, req) - if err != nil { - h.metrics.IncrementTriggerCapabilityRequestFailures(ctx, member.Address, gateway.MethodPullWorkflowMetadata, h.lggr) - combinedErr = errors.Join(combinedErr, fmt.Errorf("failed to send pull request to node %s: %w", member.Address, err)) + for i := 0; i < h.shards.NumShards(); i++ { + shard := h.shards.Shard(i) + for _, member := range shard.Members { + h.metrics.IncrementTriggerCapabilityRequestCount(ctx, member.Address, gateway.MethodPullWorkflowMetadata, h.lggr) + err := shard.DON.SendToNode(ctx, member.Address, req) + if err != nil { + h.metrics.IncrementTriggerCapabilityRequestFailures(ctx, member.Address, gateway.MethodPullWorkflowMetadata, h.lggr) + combinedErr = errors.Join(combinedErr, fmt.Errorf("failed to send pull request to node %s: %w", member.Address, err)) + } } } return combinedErr @@ -319,6 +349,17 @@ func (h *WorkflowMetadataHandler) GetWorkflowReference(workflowID string) (workf return workflowRef, exists } +// GetWorkflowShard returns the ShardInfo for the shard that owns the given workflow. +func (h *WorkflowMetadataHandler) GetWorkflowShard(workflowID string) (handlers.ShardInfo, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + shardIdx, exists := h.workflowToShard[workflowID] + if !exists { + return handlers.ShardInfo{}, false + } + return h.shards.Shard(shardIdx), true +} + func (h *WorkflowMetadataHandler) Close() error { return h.StopOnce("WorkflowMetadataHandler", func() error { h.lggr.Info("Stopping HTTP Trigger Metadata Handler") diff --git a/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler_test.go b/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler_test.go index 19b9f085d1c..889ff813ecf 100644 --- a/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler_test.go +++ b/core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler_test.go @@ -17,6 +17,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/workflows" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities/v2/metrics" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -49,13 +50,26 @@ func createTestWorkflowMetadataHandler(t *testing.T) (*WorkflowMetadataHandler, }, } + shardRouter := testShardRouter(t, donConfig, mockDon) cfg := WithDefaults(ServiceConfig{}) - testMetrics, err := metrics.NewMetrics(donConfig) + testMetrics, err := metrics.NewMetrics(donConfig.Members) require.NoError(t, err) - handler := NewWorkflowMetadataHandler(lggr, cfg, mockDon, donConfig, testMetrics) + handler := NewWorkflowMetadataHandler(lggr, cfg, shardRouter, testMetrics) return handler, mockDon, donConfig } +func testShardRouter(t *testing.T, donConfig *config.DONConfig, don handlers.DON) *handlers.ShardRouter { + t.Helper() + shardedDONs := []config.ShardedDONConfig{{ + DonName: "test", + F: donConfig.F, + Shards: []config.Shard{{Nodes: donConfig.Members}}, + }} + router, err := handlers.NewShardRouter(shardedDONs, [][]handlers.DON{{don}}) + require.NoError(t, err) + return router +} + func TestSyncMetadata(t *testing.T) { handler, _, _ := createTestWorkflowMetadataHandler(t) diff --git a/core/services/gateway/handlers/shard_router.go b/core/services/gateway/handlers/shard_router.go new file mode 100644 index 00000000000..5ca62288726 --- /dev/null +++ b/core/services/gateway/handlers/shard_router.go @@ -0,0 +1,104 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" +) + +// ShardInfo holds the topology and connection manager for a single shard. +type ShardInfo struct { + DON DON + Members []config.NodeConfig + F int +} + +// ShardRouter routes SendToNode calls to the correct shard's connection manager +// and exposes per-shard topology for shard-aware fan-out and aggregation. +type ShardRouter struct { + shards []ShardInfo + nodeToConnMgr map[string]DON + nodeToShard map[string]int // lowercase node address -> shard index +} + +// NewShardRouter creates a ShardRouter from sharded DON configs and their connection managers. +// All shards across all DONs are flattened into a single ordered list. +// shardsConnMgrs[donIdx][shardIdx] is the connection manager for DON donIdx, shard shardIdx. +func NewShardRouter(shardedDONs []config.ShardedDONConfig, shardsConnMgrs [][]DON) (*ShardRouter, error) { + nodeToConnMgr := make(map[string]DON) + nodeToShard := make(map[string]int) + var shards []ShardInfo + shardIndex := 0 + + for donIdx, donCfg := range shardedDONs { + if donIdx >= len(shardsConnMgrs) { + return nil, fmt.Errorf("missing connection managers for DON %s", donCfg.DonName) + } + for shardIdx, shard := range donCfg.Shards { + if shardIdx >= len(shardsConnMgrs[donIdx]) { + return nil, fmt.Errorf("missing connection manager for DON %s shard %d", donCfg.DonName, shardIdx) + } + connMgr := shardsConnMgrs[donIdx][shardIdx] + for _, node := range shard.Nodes { + addr := strings.ToLower(node.Address) + if _, exists := nodeToConnMgr[addr]; exists { + return nil, fmt.Errorf("duplicate node address %s across shards", addr) + } + nodeToConnMgr[addr] = connMgr + nodeToShard[addr] = shardIndex + } + shards = append(shards, ShardInfo{ + DON: connMgr, + Members: shard.Nodes, + F: donCfg.F, + }) + shardIndex++ + } + } + return &ShardRouter{ + shards: shards, + nodeToConnMgr: nodeToConnMgr, + nodeToShard: nodeToShard, + }, nil +} + +func (r *ShardRouter) SendToNode(ctx context.Context, nodeAddress string, req *jsonrpc.Request[json.RawMessage]) error { + connMgr, ok := r.nodeToConnMgr[strings.ToLower(nodeAddress)] + if !ok { + return fmt.Errorf("node %s not found in any shard", nodeAddress) + } + return connMgr.SendToNode(ctx, nodeAddress, req) +} + +func (r *ShardRouter) Shard(idx int) ShardInfo { + return r.shards[idx] +} + +func (r *ShardRouter) NumShards() int { + return len(r.shards) +} + +// ShardIndexForNode returns the shard index that the given node belongs to. +func (r *ShardRouter) ShardIndexForNode(nodeAddr string) (int, bool) { + idx, ok := r.nodeToShard[strings.ToLower(nodeAddr)] + return idx, ok +} + +// AllMembers returns all node configs across all shards (for metrics initialization). +func (r *ShardRouter) AllMembers() []config.NodeConfig { + var members []config.NodeConfig + for _, s := range r.shards { + members = append(members, s.Members...) + } + return members +} + +// DonID returns a summary ID for logging purposes. +func (r *ShardRouter) DonID() string { + return fmt.Sprintf("sharded(%d)", len(r.shards)) +} diff --git a/core/services/gateway/handlers/shard_router_test.go b/core/services/gateway/handlers/shard_router_test.go new file mode 100644 index 00000000000..dc2bfed7591 --- /dev/null +++ b/core/services/gateway/handlers/shard_router_test.go @@ -0,0 +1,203 @@ +package handlers_test + +import ( + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" + handlermocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" +) + +func TestShardRouter_SendToNode_RoutesToCorrectShard(t *testing.T) { + t.Parallel() + + shard0DON := handlermocks.NewDON(t) + shard1DON := handlermocks.NewDON(t) + + shardedDONs := []config.ShardedDONConfig{{ + DonName: "myDON", + F: 1, + Shards: []config.Shard{ + {Nodes: []config.NodeConfig{ + {Name: "s0n0", Address: "0xaaaa"}, + {Name: "s0n1", Address: "0xbbbb"}, + }}, + {Nodes: []config.NodeConfig{ + {Name: "s1n0", Address: "0xcccc"}, + {Name: "s1n1", Address: "0xdddd"}, + }}, + }, + }} + connMgrs := [][]handlers.DON{{shard0DON, shard1DON}} + + router, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.NoError(t, err) + + ctx := testutils.Context(t) + + shard0DON.EXPECT().SendToNode(mock.Anything, "0xaaaa", mock.Anything).Return(nil).Once() + require.NoError(t, router.SendToNode(ctx, "0xaaaa", nil)) + + shard1DON.EXPECT().SendToNode(mock.Anything, "0xcccc", mock.Anything).Return(nil).Once() + require.NoError(t, router.SendToNode(ctx, "0xcccc", nil)) +} + +func TestShardRouter_SendToNode_UnknownNodeReturnsError(t *testing.T) { + t.Parallel() + + shardedDONs := []config.ShardedDONConfig{{ + DonName: "myDON", + F: 0, + Shards: []config.Shard{ + {Nodes: []config.NodeConfig{{Name: "n0", Address: "0xaaaa"}}}, + }, + }} + connMgrs := [][]handlers.DON{{handlermocks.NewDON(t)}} + + router, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.NoError(t, err) + + err = router.SendToNode(testutils.Context(t), "0xunknown", nil) + require.ErrorContains(t, err, "not found in any shard") +} + +func TestShardRouter_DuplicateNodeAddressAcrossShardsErrors(t *testing.T) { + t.Parallel() + + shardedDONs := []config.ShardedDONConfig{{ + DonName: "myDON", + F: 0, + Shards: []config.Shard{ + {Nodes: []config.NodeConfig{{Name: "n0", Address: "0xaaaa"}}}, + {Nodes: []config.NodeConfig{{Name: "n1", Address: "0xaaaa"}}}, + }, + }} + connMgrs := [][]handlers.DON{{handlermocks.NewDON(t), handlermocks.NewDON(t)}} + + _, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.ErrorContains(t, err, "duplicate node address") +} + +func TestShardRouter_MultipleDONs(t *testing.T) { + t.Parallel() + + don1Shard0 := handlermocks.NewDON(t) + don2Shard0 := handlermocks.NewDON(t) + + shardedDONs := []config.ShardedDONConfig{ + { + DonName: "don1", + F: 0, + Shards: []config.Shard{{Nodes: []config.NodeConfig{{Name: "d1n0", Address: "0x1111"}}}}, + }, + { + DonName: "don2", + F: 0, + Shards: []config.Shard{{Nodes: []config.NodeConfig{{Name: "d2n0", Address: "0x2222"}}}}, + }, + } + connMgrs := [][]handlers.DON{{don1Shard0}, {don2Shard0}} + + router, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.NoError(t, err) + + ctx := testutils.Context(t) + + don1Shard0.EXPECT().SendToNode(mock.Anything, "0x1111", mock.Anything).Return(nil).Once() + require.NoError(t, router.SendToNode(ctx, "0x1111", nil)) + + don2Shard0.EXPECT().SendToNode(mock.Anything, "0x2222", mock.Anything).Return(nil).Once() + require.NoError(t, router.SendToNode(ctx, "0x2222", nil)) +} + +func TestShardRouter_CaseInsensitiveAddressLookup(t *testing.T) { + t.Parallel() + + mockDON := handlermocks.NewDON(t) + + shardedDONs := []config.ShardedDONConfig{{ + DonName: "myDON", + F: 0, + Shards: []config.Shard{{Nodes: []config.NodeConfig{{Name: "n0", Address: "0xAaBb"}}}}, + }} + connMgrs := [][]handlers.DON{{mockDON}} + + router, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.NoError(t, err) + + mockDON.EXPECT().SendToNode(mock.Anything, "0xaabb", mock.Anything).Return(nil).Once() + require.NoError(t, router.SendToNode(testutils.Context(t), "0xaabb", nil)) +} + +func TestShardRouter_ShardTopology(t *testing.T) { + t.Parallel() + + shard0DON := handlermocks.NewDON(t) + shard1DON := handlermocks.NewDON(t) + shard2DON := handlermocks.NewDON(t) + + shardedDONs := []config.ShardedDONConfig{ + { + DonName: "don1", + F: 1, + Shards: []config.Shard{ + {Nodes: []config.NodeConfig{ + {Name: "d1s0n0", Address: "0x1"}, + {Name: "d1s0n1", Address: "0x2"}, + }}, + {Nodes: []config.NodeConfig{ + {Name: "d1s1n0", Address: "0x3"}, + }}, + }, + }, + { + DonName: "don2", + F: 2, + Shards: []config.Shard{ + {Nodes: []config.NodeConfig{ + {Name: "d2s0n0", Address: "0x4"}, + }}, + }, + }, + } + connMgrs := [][]handlers.DON{{shard0DON, shard1DON}, {shard2DON}} + + router, err := handlers.NewShardRouter(shardedDONs, connMgrs) + require.NoError(t, err) + + require.Equal(t, 3, router.NumShards()) + + shard0 := router.Shard(0) + require.Len(t, shard0.Members, 2) + require.Equal(t, 1, shard0.F) + + shard1 := router.Shard(1) + require.Len(t, shard1.Members, 1) + require.Equal(t, 1, shard1.F) + + shard2 := router.Shard(2) + require.Len(t, shard2.Members, 1) + require.Equal(t, 2, shard2.F) + + allMembers := router.AllMembers() + require.Len(t, allMembers, 4) + + idx, ok := router.ShardIndexForNode("0x1") + require.True(t, ok) + require.Equal(t, 0, idx) + + idx, ok = router.ShardIndexForNode("0x3") + require.True(t, ok) + require.Equal(t, 1, idx) + + idx, ok = router.ShardIndexForNode("0x4") + require.True(t, ok) + require.Equal(t, 2, idx) + + _, ok = router.ShardIndexForNode("0xunknown") + require.False(t, ok) +}