diff --git a/pkg/vmcp/aggregator/aggregator.go b/pkg/vmcp/aggregator/aggregator.go index 80c931eb3..b919e291b 100644 --- a/pkg/vmcp/aggregator/aggregator.go +++ b/pkg/vmcp/aggregator/aggregator.go @@ -29,17 +29,34 @@ type BackendDiscoverer interface { // 1. Query: Fetch capabilities from each backend // 2. Conflict Resolution: Handle duplicate tool/resource/prompt names // 3. Merging: Create final unified capability view and routing table +// +//go:generate mockgen -destination=mocks/mock_interfaces.go -package=mocks -source=aggregator.go BackendDiscoverer Aggregator ConflictResolver ToolFilter ToolOverride type Aggregator interface { // QueryCapabilities queries a backend for its MCP capabilities. // Returns the raw capabilities (tools, resources, prompts) from the backend. QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) + // QueryAllCapabilities queries all backends for their capabilities in parallel. + // Handles backend failures gracefully (logs and continues with remaining backends). + QueryAllCapabilities(ctx context.Context, backends []vmcp.Backend) (map[string]*BackendCapabilities, error) + // ResolveConflicts applies conflict resolution strategy to handle // duplicate capability names across backends. ResolveConflicts(ctx context.Context, capabilities map[string]*BackendCapabilities) (*ResolvedCapabilities, error) // MergeCapabilities creates the final unified capability view and routing table. - MergeCapabilities(ctx context.Context, resolved *ResolvedCapabilities) (*AggregatedCapabilities, error) + // Uses the backend registry to populate full BackendTarget information for routing. + MergeCapabilities( + ctx context.Context, + resolved *ResolvedCapabilities, + registry vmcp.BackendRegistry, + ) (*AggregatedCapabilities, error) + + // AggregateCapabilities is a convenience method that performs the full aggregation pipeline: + // 1. Query all backends + // 2. Resolve conflicts + // 3. Merge into final view + AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) } // BackendCapabilities contains the raw capabilities from a single backend. diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go new file mode 100644 index 000000000..069456b7c --- /dev/null +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -0,0 +1,124 @@ +package aggregator + +import ( + "context" + "fmt" + + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/workloads" +) + +// cliBackendDiscoverer discovers backend MCP servers from Docker/Podman workloads in a group. +// This is the CLI version of BackendDiscoverer that uses the workloads.Manager. +type cliBackendDiscoverer struct { + workloadsManager workloads.Manager + groupsManager groups.Manager +} + +// NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. +// It discovers workloads from Docker/Podman containers managed by ToolHive. +func NewCLIBackendDiscoverer(workloadsManager workloads.Manager, groupsManager groups.Manager) BackendDiscoverer { + return &cliBackendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + } +} + +// Discover finds all backend workloads in the specified group. +// Returns all accessible backends with their health status marked based on workload status. +// The groupRef is the group name (e.g., "engineering-team"). +func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + workload, err := d.workloadsManager.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", name) + continue + } + + // Map workload status to backend health status + healthStatus := mapWorkloadStatusToHealth(workload.Status) + + // Convert core.Workload to vmcp.Backend + backend := vmcp.Backend{ + ID: name, + Name: name, + BaseURL: workload.URL, + TransportType: workload.TransportType.String(), + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["group"] = groupRef + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + backends = append(backends, backend) + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} + +// mapWorkloadStatusToHealth converts a workload status to a backend health status. +func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { + switch status { + case rt.WorkloadStatusRunning: + return vmcp.BackendHealthy + case rt.WorkloadStatusUnhealthy: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: + return vmcp.BackendUnknown + default: + return vmcp.BackendUnknown + } +} diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go new file mode 100644 index 000000000..19e1de944 --- /dev/null +++ b/pkg/vmcp/aggregator/cli_discoverer_test.go @@ -0,0 +1,250 @@ +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/groups/mocks" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" + workloadmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" +) + +const testGroupName = "test-group" + +func TestCLIBackendDiscoverer_Discover(t *testing.T) { + t.Parallel() + + t.Run("successful discovery with multiple backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload1 := newTestWorkload("workload1", + withToolType("github"), + withLabels(map[string]string{"env": "prod"})) + + workload2 := newTestWorkload("workload2", + withURL("http://localhost:8081/mcp"), + withTransport(types.TransportTypeSSE), + withToolType("jira")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + assert.Equal(t, "workload2", backends[1].ID) + assert.Equal(t, "sse", backends[1].TransportType) + }) + + t.Run("discovers workloads with different statuses", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + runningWorkload := newTestWorkload("running-workload") + stoppedWorkload := newTestWorkload("stopped-workload", + withStatus(runtime.WorkloadStatusStopped), + withURL("http://localhost:8081/mcp"), + withTransport(types.TransportTypeSSE)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"running-workload", "stopped-workload"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "running-workload", backends[0].ID) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "stopped-workload", backends[1].ID) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + assert.Equal(t, "stopped", backends[1].Metadata["workload_status"]) + }) + + t.Run("filters out workloads without URL", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workloadWithURL := newTestWorkload("workload1") + workloadWithoutURL := newTestWorkload("workload2", withURL("")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + }) + + t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload1 := newTestWorkload("workload1", withURL("")) + workload2 := newTestWorkload("workload2", withStatus(runtime.WorkloadStatusStopped), withURL("")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("returns error when group does not exist", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), "nonexistent-group") + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("returns error when group check fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to check if group exists") + }) + + t.Run("returns empty list when group is empty", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), "empty-group") + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("discovers all workloads regardless of health status", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + stoppedWorkload := newTestWorkload("stopped1", withStatus(runtime.WorkloadStatusStopped)) + errorWorkload := newTestWorkload("error1", + withStatus(runtime.WorkloadStatusError), + withURL("http://localhost:8081/mcp"), + withTransport(types.TransportTypeSSE)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"stopped1", "error1"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, vmcp.BackendUnhealthy, backends[0].HealthStatus) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + }) + + t.Run("gracefully handles workload get failures", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + goodWorkload := newTestWorkload("good-workload") + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"good-workload", "failing-workload"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). + Return(core.Workload{}, errors.New("workload query failed")) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "good-workload", backends[0].ID) + }) +} diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go new file mode 100644 index 000000000..c00c8e885 --- /dev/null +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -0,0 +1,303 @@ +package aggregator + +import ( + "context" + "fmt" + "sync" + + "golang.org/x/sync/errgroup" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// defaultAggregator implements the Aggregator interface for capability aggregation. +// It queries backends in parallel, handles failures gracefully, and merges capabilities. +type defaultAggregator struct { + backendClient vmcp.BackendClient + // TODO: Add conflict resolver, tool filter, tool override +} + +// NewDefaultAggregator creates a new default aggregator implementation. +func NewDefaultAggregator(backendClient vmcp.BackendClient) Aggregator { + return &defaultAggregator{ + backendClient: backendClient, + } +} + +// QueryCapabilities queries a single backend for its MCP capabilities. +// Returns the raw capabilities (tools, resources, prompts) from the backend. +func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) { + logger.Debugf("Querying capabilities from backend %s", backend.ID) + + // Create a BackendTarget from the Backend + target := &vmcp.BackendTarget{ + WorkloadID: backend.ID, + WorkloadName: backend.Name, + BaseURL: backend.BaseURL, + TransportType: backend.TransportType, + HealthStatus: backend.HealthStatus, + Metadata: backend.Metadata, + } + + // Query capabilities using the backend client + capabilities, err := a.backendClient.ListCapabilities(ctx, target) + if err != nil { + return nil, fmt.Errorf("%w: %s: %v", ErrBackendQueryFailed, backend.ID, err) + } + + // Convert to BackendCapabilities + result := &BackendCapabilities{ + BackendID: backend.ID, + Tools: capabilities.Tools, + Resources: capabilities.Resources, + Prompts: capabilities.Prompts, + SupportsLogging: capabilities.SupportsLogging, + SupportsSampling: capabilities.SupportsSampling, + } + + logger.Debugf("Backend %s: %d tools, %d resources, %d prompts", + backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts)) + + return result, nil +} + +// QueryAllCapabilities queries all backends for their capabilities in parallel. +// Handles backend failures gracefully (logs and continues with remaining backends). +func (a *defaultAggregator) QueryAllCapabilities( + ctx context.Context, + backends []vmcp.Backend, +) (map[string]*BackendCapabilities, error) { + logger.Infof("Querying capabilities from %d backends", len(backends)) + + // Use errgroup for parallel queries with context cancellation + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(10) // Limit concurrent queries to avoid overwhelming backends + + // Thread-safe map for results + var mu sync.Mutex + capabilities := make(map[string]*BackendCapabilities) + + // Query each backend in parallel + for _, backend := range backends { + backend := backend // Capture loop variable + g.Go(func() error { + caps, err := a.QueryCapabilities(ctx, backend) + if err != nil { + // Log the error but continue with other backends + logger.Warnf("Failed to query backend %s: %v", backend.ID, err) + return nil // Don't fail the entire operation + } + + // Store result safely + mu.Lock() + capabilities[backend.ID] = caps + mu.Unlock() + + return nil + }) + } + + // Wait for all queries to complete + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("capability queries failed: %w", err) + } + + if len(capabilities) == 0 { + return nil, fmt.Errorf("no backends returned capabilities") + } + + logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends)) + return capabilities, nil +} + +// ResolveConflicts applies conflict resolution strategy to handle +// duplicate capability names across backends. +func (*defaultAggregator) ResolveConflicts( + _ context.Context, + capabilities map[string]*BackendCapabilities, +) (*ResolvedCapabilities, error) { + logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) + + // For Phase 1 (Issue #148), we'll implement basic conflict resolution + // Just collect all capabilities without resolving conflicts yet + // Conflict resolution will be implemented in a future phase + + resolved := &ResolvedCapabilities{ + Tools: make(map[string]*ResolvedTool), + Resources: []vmcp.Resource{}, + Prompts: []vmcp.Prompt{}, + } + + // Collect all tools (for now, without conflict resolution) + // Later, we'll add prefix/priority/manual strategies + for backendID, caps := range capabilities { + for _, tool := range caps.Tools { + // For now, just use the tool name as-is + // In future phases, we'll apply prefixing or priority rules + resolvedName := tool.Name + + // If there's a conflict, log a warning (but don't fail) + if existing, exists := resolved.Tools[resolvedName]; exists { + logger.Warnf("Tool name conflict: %s exists in both %s and %s (keeping first)", + resolvedName, existing.BackendID, backendID) + continue + } + + resolved.Tools[resolvedName] = &ResolvedTool{ + ResolvedName: resolvedName, + OriginalName: tool.Name, + Description: tool.Description, + InputSchema: tool.InputSchema, + BackendID: tool.BackendID, + // ConflictResolutionApplied will be set in future phases + } + } + + // Collect resources (URIs should be globally unique) + resolved.Resources = append(resolved.Resources, caps.Resources...) + + // Collect prompts + resolved.Prompts = append(resolved.Prompts, caps.Prompts...) + + // Aggregate logging/sampling support (OR logic - enabled if any backend supports) + resolved.SupportsLogging = resolved.SupportsLogging || caps.SupportsLogging + resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling + } + + logger.Debugf("Resolved %d unique tools, %d resources, %d prompts", + len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts)) + + return resolved, nil +} + +// MergeCapabilities creates the final unified capability view and routing table. +// Uses the backend registry to populate full BackendTarget information for routing. +func (*defaultAggregator) MergeCapabilities( + ctx context.Context, + resolved *ResolvedCapabilities, + registry vmcp.BackendRegistry, +) (*AggregatedCapabilities, error) { + logger.Debugf("Merging capabilities into final view") + + // Create routing table + routingTable := &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: make(map[string]*vmcp.BackendTarget), + Prompts: make(map[string]*vmcp.BackendTarget), + } + + // Convert resolved tools to final vmcp.Tool format + tools := make([]vmcp.Tool, 0, len(resolved.Tools)) + for _, resolvedTool := range resolved.Tools { + tools = append(tools, vmcp.Tool{ + Name: resolvedTool.ResolvedName, + Description: resolvedTool.Description, + InputSchema: resolvedTool.InputSchema, + BackendID: resolvedTool.BackendID, + }) + + // Look up full backend information from registry + backend := registry.Get(ctx, resolvedTool.BackendID) + if backend == nil { + logger.Warnf("Backend %s not found in registry for tool %s, creating minimal target", + resolvedTool.BackendID, resolvedTool.ResolvedName) + routingTable.Tools[resolvedTool.ResolvedName] = &vmcp.BackendTarget{ + WorkloadID: resolvedTool.BackendID, + } + } else { + // Use the backendToTarget helper from registry package + routingTable.Tools[resolvedTool.ResolvedName] = vmcp.BackendToTarget(backend) + } + } + + // Add resources to routing table + for _, resource := range resolved.Resources { + backend := registry.Get(ctx, resource.BackendID) + if backend == nil { + logger.Warnf("Backend %s not found in registry for resource %s, creating minimal target", + resource.BackendID, resource.URI) + routingTable.Resources[resource.URI] = &vmcp.BackendTarget{ + WorkloadID: resource.BackendID, + } + } else { + routingTable.Resources[resource.URI] = vmcp.BackendToTarget(backend) + } + } + + // Add prompts to routing table + for _, prompt := range resolved.Prompts { + backend := registry.Get(ctx, prompt.BackendID) + if backend == nil { + logger.Warnf("Backend %s not found in registry for prompt %s, creating minimal target", + prompt.BackendID, prompt.Name) + routingTable.Prompts[prompt.Name] = &vmcp.BackendTarget{ + WorkloadID: prompt.BackendID, + } + } else { + routingTable.Prompts[prompt.Name] = vmcp.BackendToTarget(backend) + } + } + + // Create final aggregated view + aggregated := &AggregatedCapabilities{ + Tools: tools, + Resources: resolved.Resources, + Prompts: resolved.Prompts, + SupportsLogging: resolved.SupportsLogging, + SupportsSampling: resolved.SupportsSampling, + RoutingTable: routingTable, + Metadata: &AggregationMetadata{ + BackendCount: 0, // Will be set by caller + ToolCount: len(tools), + ResourceCount: len(resolved.Resources), + PromptCount: len(resolved.Prompts), + ConflictsResolved: 0, // Will be tracked in future phases + }, + } + + logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts", + aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) + + return aggregated, nil +} + +// AggregateCapabilities is a convenience method that performs the full aggregation pipeline: +// 1. Create backend registry +// 2. Query all backends +// 3. Resolve conflicts +// 4. Merge into final view with full backend information +func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) { + logger.Infof("Starting capability aggregation for %d backends", len(backends)) + + // Step 1: Create registry from discovered backends + registry := vmcp.NewImmutableRegistry(backends) + logger.Debugf("Created backend registry with %d backends", registry.Count()) + + // Step 2: Query all backends + capabilities, err := a.QueryAllCapabilities(ctx, backends) + if err != nil { + return nil, fmt.Errorf("failed to query backends: %w", err) + } + + // Step 3: Resolve conflicts + resolved, err := a.ResolveConflicts(ctx, capabilities) + if err != nil { + return nil, fmt.Errorf("failed to resolve conflicts: %w", err) + } + + // Step 4: Merge into final view with full backend information + aggregated, err := a.MergeCapabilities(ctx, resolved, registry) + if err != nil { + return nil, fmt.Errorf("failed to merge capabilities: %w", err) + } + + // Update metadata with backend count + aggregated.Metadata.BackendCount = len(backends) + + logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts", + aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount, + aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) + + return aggregated, nil +} diff --git a/pkg/vmcp/aggregator/default_aggregator_test.go b/pkg/vmcp/aggregator/default_aggregator_test.go new file mode 100644 index 000000000..94045df1e --- /dev/null +++ b/pkg/vmcp/aggregator/default_aggregator_test.go @@ -0,0 +1,344 @@ +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +func TestDefaultAggregator_QueryCapabilities(t *testing.T) { + t.Parallel() + + t.Run("successful query", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backend := newTestBackend("backend1", withBackendName("Backend 1")) + + expectedCaps := newTestCapabilityList( + withTools(newTestTool("test_tool", "backend1")), + withResources(newTestResource("test://resource", "backend1")), + withPrompts(newTestPrompt("test_prompt", "backend1")), + withLogging(true)) + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(expectedCaps, nil) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.QueryCapabilities(context.Background(), backend) + + require.NoError(t, err) + assert.Equal(t, "backend1", result.BackendID) + require.Len(t, result.Tools, 1) + assert.Equal(t, "test_tool", result.Tools[0].Name) + assert.Len(t, result.Resources, 1) + assert.Len(t, result.Prompts, 1) + assert.True(t, result.SupportsLogging) + assert.False(t, result.SupportsSampling) + }) + + t.Run("backend query failure", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backend := newTestBackend("backend1", withBackendName("Backend 1")) + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection failed")) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.QueryCapabilities(context.Background(), backend) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "backend1") + }) +} + +func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { + t.Parallel() + + t.Run("query multiple backends successfully", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backends := []vmcp.Backend{ + newTestBackend("backend1", withBackendName("Backend 1")), + newTestBackend("backend2", withBackendName("Backend 2"), + withBackendURL("http://localhost:8081"), + withBackendTransport("sse")), + } + + caps1 := newTestCapabilityList(withTools(newTestTool("tool1", "backend1"))) + caps2 := newTestCapabilityList(withTools(newTestTool("tool2", "backend2"))) + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.QueryAllCapabilities(context.Background(), backends) + + require.NoError(t, err) + require.Len(t, result, 2) + assert.Contains(t, result, "backend1") + assert.Contains(t, result, "backend2") + }) + + t.Run("graceful handling of partial failures", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backends := []vmcp.Backend{ + newTestBackend("backend1"), + newTestBackend("backend2", withBackendURL("http://localhost:8081")), + } + + caps1 := newTestCapabilityList(withTools(newTestTool("tool1", "backend1"))) + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if target.WorkloadID == "backend1" { + return caps1, nil + } + return nil, errors.New("connection timeout") + }).Times(2) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.QueryAllCapabilities(context.Background(), backends) + + require.NoError(t, err) + require.Len(t, result, 1) + assert.Contains(t, result, "backend1") + assert.NotContains(t, result, "backend2") + }) + + t.Run("all backends fail", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). + Return(nil, errors.New("connection failed")) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.QueryAllCapabilities(context.Background(), backends) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "no backends returned capabilities") + }) +} + +func TestDefaultAggregator_ResolveConflicts(t *testing.T) { + t.Parallel() + + t.Run("basic conflict detection", func(t *testing.T) { + t.Parallel() + capabilities := map[string]*BackendCapabilities{ + "backend1": { + BackendID: "backend1", + Tools: []vmcp.Tool{ + {Name: "tool1", Description: "Tool 1 from backend1", BackendID: "backend1"}, + {Name: "shared_tool", Description: "Shared from backend1", BackendID: "backend1"}, + }, + }, + "backend2": { + BackendID: "backend2", + Tools: []vmcp.Tool{ + {Name: "tool2", Description: "Tool 2 from backend2", BackendID: "backend2"}, + {Name: "shared_tool", Description: "Shared from backend2", BackendID: "backend2"}, + }, + }, + } + + agg := NewDefaultAggregator(nil) + resolved, err := agg.ResolveConflicts(context.Background(), capabilities) + + require.NoError(t, err) + assert.NotNil(t, resolved) + // In Phase 1, we just collect tools - conflict is detected but first one wins + assert.Contains(t, resolved.Tools, "tool1") + assert.Contains(t, resolved.Tools, "tool2") + assert.Contains(t, resolved.Tools, "shared_tool") + // Shared tool should have one backend (whichever was encountered first in map iteration) + // Map iteration order is non-deterministic, so accept either backend + sharedToolBackend := resolved.Tools["shared_tool"].BackendID + assert.True(t, sharedToolBackend == "backend1" || sharedToolBackend == "backend2", + "shared_tool should belong to either backend1 or backend2, got: %s", sharedToolBackend) + }) + + t.Run("no conflicts", func(t *testing.T) { + t.Parallel() + capabilities := map[string]*BackendCapabilities{ + "backend1": { + BackendID: "backend1", + Tools: []vmcp.Tool{ + {Name: "unique1", BackendID: "backend1"}, + }, + }, + "backend2": { + BackendID: "backend2", + Tools: []vmcp.Tool{ + {Name: "unique2", BackendID: "backend2"}, + }, + }, + } + + agg := NewDefaultAggregator(nil) + resolved, err := agg.ResolveConflicts(context.Background(), capabilities) + + require.NoError(t, err) + assert.Len(t, resolved.Tools, 2) + assert.Contains(t, resolved.Tools, "unique1") + assert.Contains(t, resolved.Tools, "unique2") + }) +} + +func TestDefaultAggregator_MergeCapabilities(t *testing.T) { + t.Parallel() + + t.Run("merge resolved capabilities", func(t *testing.T) { + t.Parallel() + resolved := &ResolvedCapabilities{ + Tools: map[string]*ResolvedTool{ + "tool1": { + ResolvedName: "tool1", + OriginalName: "tool1", + Description: "Tool 1", + BackendID: "backend1", + }, + "tool2": { + ResolvedName: "tool2", + OriginalName: "tool2", + Description: "Tool 2", + BackendID: "backend2", + }, + }, + Resources: []vmcp.Resource{ + {URI: "test://resource1", BackendID: "backend1"}, + }, + Prompts: []vmcp.Prompt{ + {Name: "prompt1", BackendID: "backend1"}, + }, + SupportsLogging: true, + SupportsSampling: false, + } + + // Create registry with test backends + backends := []vmcp.Backend{ + { + ID: "backend1", + Name: "Backend 1", + BaseURL: "http://backend1:8080", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + }, + { + ID: "backend2", + Name: "Backend 2", + BaseURL: "http://backend2:8080", + TransportType: "sse", + HealthStatus: vmcp.BackendHealthy, + }, + } + registry := vmcp.NewImmutableRegistry(backends) + + agg := NewDefaultAggregator(nil) + aggregated, err := agg.MergeCapabilities(context.Background(), resolved, registry) + + require.NoError(t, err) + assert.Len(t, aggregated.Tools, 2) + assert.Len(t, aggregated.Resources, 1) + assert.Len(t, aggregated.Prompts, 1) + assert.True(t, aggregated.SupportsLogging) + assert.False(t, aggregated.SupportsSampling) + + // Check routing table + assert.NotNil(t, aggregated.RoutingTable) + assert.Contains(t, aggregated.RoutingTable.Tools, "tool1") + assert.Contains(t, aggregated.RoutingTable.Tools, "tool2") + assert.Contains(t, aggregated.RoutingTable.Resources, "test://resource1") + assert.Contains(t, aggregated.RoutingTable.Prompts, "prompt1") + + // Verify routing table has full backend information + tool1Target := aggregated.RoutingTable.Tools["tool1"] + assert.NotNil(t, tool1Target) + assert.Equal(t, "backend1", tool1Target.WorkloadID) + assert.Equal(t, "Backend 1", tool1Target.WorkloadName) + assert.Equal(t, "http://backend1:8080", tool1Target.BaseURL) + assert.Equal(t, "streamable-http", tool1Target.TransportType) + assert.Equal(t, vmcp.BackendHealthy, tool1Target.HealthStatus) + + tool2Target := aggregated.RoutingTable.Tools["tool2"] + assert.NotNil(t, tool2Target) + assert.Equal(t, "backend2", tool2Target.WorkloadID) + assert.Equal(t, "Backend 2", tool2Target.WorkloadName) + assert.Equal(t, "http://backend2:8080", tool2Target.BaseURL) + assert.Equal(t, "sse", tool2Target.TransportType) + + // Check metadata + assert.Equal(t, 2, aggregated.Metadata.ToolCount) + assert.Equal(t, 1, aggregated.Metadata.ResourceCount) + assert.Equal(t, 1, aggregated.Metadata.PromptCount) + }) +} + +func TestDefaultAggregator_AggregateCapabilities(t *testing.T) { + t.Parallel() + + t.Run("full aggregation pipeline", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + backends := []vmcp.Backend{ + newTestBackend("backend1", withBackendName("Backend 1")), + newTestBackend("backend2", withBackendName("Backend 2"), + withBackendURL("http://localhost:8081"), + withBackendTransport("sse")), + } + + caps1 := newTestCapabilityList( + withTools(newTestTool("tool1", "backend1")), + withResources(newTestResource("test://resource1", "backend1")), + withLogging(true)) + + caps2 := newTestCapabilityList( + withTools(newTestTool("tool2", "backend2")), + withSampling(true)) + + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) + mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) + + agg := NewDefaultAggregator(mockClient) + result, err := agg.AggregateCapabilities(context.Background(), backends) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.Len(t, result.Tools, 2) + assert.Len(t, result.Resources, 1) + assert.True(t, result.SupportsLogging) + assert.True(t, result.SupportsSampling) + assert.Equal(t, 2, result.Metadata.BackendCount) + assert.Equal(t, 2, result.Metadata.ToolCount) + assert.Equal(t, 1, result.Metadata.ResourceCount) + }) +} diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go new file mode 100644 index 000000000..b4cff44d4 --- /dev/null +++ b/pkg/vmcp/aggregator/discoverer.go @@ -0,0 +1,8 @@ +// Package aggregator provides platform-specific backend discovery implementations. +// +// This file serves as a navigation reference for backend discovery implementations: +// - CLI (Docker/Podman): see cli_discoverer.go +// - Kubernetes: see k8s_discoverer.go +// +// The BackendDiscoverer interface is defined in aggregator.go. +package aggregator diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go new file mode 100644 index 000000000..b9f61fbc0 --- /dev/null +++ b/pkg/vmcp/aggregator/k8s_discoverer.go @@ -0,0 +1,33 @@ +package aggregator + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// k8sBackendDiscoverer discovers backend MCP servers from Kubernetes pods/services in a group. +// This is the Kubernetes version of BackendDiscoverer (not implemented yet). +type k8sBackendDiscoverer struct { + // TODO: Add Kubernetes client and group CRD interfaces +} + +// NewK8sBackendDiscoverer creates a new Kubernetes-based backend discoverer. +// It discovers workloads from Kubernetes MCPServer resources managed by the operator. +func NewK8sBackendDiscoverer() BackendDiscoverer { + return &k8sBackendDiscoverer{} +} + +// Discover finds all backend workloads in the specified Kubernetes group. +// The groupRef is the MCPGroup name. +func (*k8sBackendDiscoverer) Discover(_ context.Context, _ string) ([]vmcp.Backend, error) { + // TODO: Implement Kubernetes backend discovery + // 1. Query MCPGroup CRD by name + // 2. List MCPServer resources with matching group label + // 3. Filter for ready/running MCPServers + // 4. Build service URLs (http://service-name.namespace.svc.cluster.local:port) + // 5. Extract transport type from MCPServer spec + // 6. Return vmcp.Backend list + return nil, fmt.Errorf("kubernetes backend discovery not yet implemented") +} diff --git a/pkg/vmcp/aggregator/mocks/mock_interfaces.go b/pkg/vmcp/aggregator/mocks/mock_interfaces.go new file mode 100644 index 000000000..c90fe21f8 --- /dev/null +++ b/pkg/vmcp/aggregator/mocks/mock_interfaces.go @@ -0,0 +1,274 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: aggregator.go +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_interfaces.go -package=mocks -source=aggregator.go BackendDiscoverer Aggregator ConflictResolver ToolFilter ToolOverride +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + aggregator "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + gomock "go.uber.org/mock/gomock" +) + +// MockBackendDiscoverer is a mock of BackendDiscoverer interface. +type MockBackendDiscoverer struct { + ctrl *gomock.Controller + recorder *MockBackendDiscovererMockRecorder + isgomock struct{} +} + +// MockBackendDiscovererMockRecorder is the mock recorder for MockBackendDiscoverer. +type MockBackendDiscovererMockRecorder struct { + mock *MockBackendDiscoverer +} + +// NewMockBackendDiscoverer creates a new mock instance. +func NewMockBackendDiscoverer(ctrl *gomock.Controller) *MockBackendDiscoverer { + mock := &MockBackendDiscoverer{ctrl: ctrl} + mock.recorder = &MockBackendDiscovererMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBackendDiscoverer) EXPECT() *MockBackendDiscovererMockRecorder { + return m.recorder +} + +// Discover mocks base method. +func (m *MockBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Discover", ctx, groupRef) + ret0, _ := ret[0].([]vmcp.Backend) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Discover indicates an expected call of Discover. +func (mr *MockBackendDiscovererMockRecorder) Discover(ctx, groupRef any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Discover", reflect.TypeOf((*MockBackendDiscoverer)(nil).Discover), ctx, groupRef) +} + +// MockAggregator is a mock of Aggregator interface. +type MockAggregator struct { + ctrl *gomock.Controller + recorder *MockAggregatorMockRecorder + isgomock struct{} +} + +// MockAggregatorMockRecorder is the mock recorder for MockAggregator. +type MockAggregatorMockRecorder struct { + mock *MockAggregator +} + +// NewMockAggregator creates a new mock instance. +func NewMockAggregator(ctrl *gomock.Controller) *MockAggregator { + mock := &MockAggregator{ctrl: ctrl} + mock.recorder = &MockAggregatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAggregator) EXPECT() *MockAggregatorMockRecorder { + return m.recorder +} + +// AggregateCapabilities mocks base method. +func (m *MockAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AggregateCapabilities", ctx, backends) + ret0, _ := ret[0].(*aggregator.AggregatedCapabilities) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AggregateCapabilities indicates an expected call of AggregateCapabilities. +func (mr *MockAggregatorMockRecorder) AggregateCapabilities(ctx, backends any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AggregateCapabilities", reflect.TypeOf((*MockAggregator)(nil).AggregateCapabilities), ctx, backends) +} + +// MergeCapabilities mocks base method. +func (m *MockAggregator) MergeCapabilities(ctx context.Context, resolved *aggregator.ResolvedCapabilities, registry vmcp.BackendRegistry) (*aggregator.AggregatedCapabilities, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MergeCapabilities", ctx, resolved, registry) + ret0, _ := ret[0].(*aggregator.AggregatedCapabilities) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MergeCapabilities indicates an expected call of MergeCapabilities. +func (mr *MockAggregatorMockRecorder) MergeCapabilities(ctx, resolved, registry any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MergeCapabilities", reflect.TypeOf((*MockAggregator)(nil).MergeCapabilities), ctx, resolved, registry) +} + +// QueryAllCapabilities mocks base method. +func (m *MockAggregator) QueryAllCapabilities(ctx context.Context, backends []vmcp.Backend) (map[string]*aggregator.BackendCapabilities, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryAllCapabilities", ctx, backends) + ret0, _ := ret[0].(map[string]*aggregator.BackendCapabilities) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryAllCapabilities indicates an expected call of QueryAllCapabilities. +func (mr *MockAggregatorMockRecorder) QueryAllCapabilities(ctx, backends any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryAllCapabilities", reflect.TypeOf((*MockAggregator)(nil).QueryAllCapabilities), ctx, backends) +} + +// QueryCapabilities mocks base method. +func (m *MockAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*aggregator.BackendCapabilities, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCapabilities", ctx, backend) + ret0, _ := ret[0].(*aggregator.BackendCapabilities) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryCapabilities indicates an expected call of QueryCapabilities. +func (mr *MockAggregatorMockRecorder) QueryCapabilities(ctx, backend any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCapabilities", reflect.TypeOf((*MockAggregator)(nil).QueryCapabilities), ctx, backend) +} + +// ResolveConflicts mocks base method. +func (m *MockAggregator) ResolveConflicts(ctx context.Context, capabilities map[string]*aggregator.BackendCapabilities) (*aggregator.ResolvedCapabilities, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveConflicts", ctx, capabilities) + ret0, _ := ret[0].(*aggregator.ResolvedCapabilities) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResolveConflicts indicates an expected call of ResolveConflicts. +func (mr *MockAggregatorMockRecorder) ResolveConflicts(ctx, capabilities any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveConflicts", reflect.TypeOf((*MockAggregator)(nil).ResolveConflicts), ctx, capabilities) +} + +// MockConflictResolver is a mock of ConflictResolver interface. +type MockConflictResolver struct { + ctrl *gomock.Controller + recorder *MockConflictResolverMockRecorder + isgomock struct{} +} + +// MockConflictResolverMockRecorder is the mock recorder for MockConflictResolver. +type MockConflictResolverMockRecorder struct { + mock *MockConflictResolver +} + +// NewMockConflictResolver creates a new mock instance. +func NewMockConflictResolver(ctrl *gomock.Controller) *MockConflictResolver { + mock := &MockConflictResolver{ctrl: ctrl} + mock.recorder = &MockConflictResolverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConflictResolver) EXPECT() *MockConflictResolverMockRecorder { + return m.recorder +} + +// ResolveToolConflicts mocks base method. +func (m *MockConflictResolver) ResolveToolConflicts(ctx context.Context, tools map[string][]vmcp.Tool) (map[string]*aggregator.ResolvedTool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveToolConflicts", ctx, tools) + ret0, _ := ret[0].(map[string]*aggregator.ResolvedTool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResolveToolConflicts indicates an expected call of ResolveToolConflicts. +func (mr *MockConflictResolverMockRecorder) ResolveToolConflicts(ctx, tools any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveToolConflicts", reflect.TypeOf((*MockConflictResolver)(nil).ResolveToolConflicts), ctx, tools) +} + +// MockToolFilter is a mock of ToolFilter interface. +type MockToolFilter struct { + ctrl *gomock.Controller + recorder *MockToolFilterMockRecorder + isgomock struct{} +} + +// MockToolFilterMockRecorder is the mock recorder for MockToolFilter. +type MockToolFilterMockRecorder struct { + mock *MockToolFilter +} + +// NewMockToolFilter creates a new mock instance. +func NewMockToolFilter(ctrl *gomock.Controller) *MockToolFilter { + mock := &MockToolFilter{ctrl: ctrl} + mock.recorder = &MockToolFilterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockToolFilter) EXPECT() *MockToolFilterMockRecorder { + return m.recorder +} + +// FilterTools mocks base method. +func (m *MockToolFilter) FilterTools(ctx context.Context, tools []vmcp.Tool) ([]vmcp.Tool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FilterTools", ctx, tools) + ret0, _ := ret[0].([]vmcp.Tool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FilterTools indicates an expected call of FilterTools. +func (mr *MockToolFilterMockRecorder) FilterTools(ctx, tools any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterTools", reflect.TypeOf((*MockToolFilter)(nil).FilterTools), ctx, tools) +} + +// MockToolOverride is a mock of ToolOverride interface. +type MockToolOverride struct { + ctrl *gomock.Controller + recorder *MockToolOverrideMockRecorder + isgomock struct{} +} + +// MockToolOverrideMockRecorder is the mock recorder for MockToolOverride. +type MockToolOverrideMockRecorder struct { + mock *MockToolOverride +} + +// NewMockToolOverride creates a new mock instance. +func NewMockToolOverride(ctrl *gomock.Controller) *MockToolOverride { + mock := &MockToolOverride{ctrl: ctrl} + mock.recorder = &MockToolOverrideMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockToolOverride) EXPECT() *MockToolOverrideMockRecorder { + return m.recorder +} + +// ApplyOverrides mocks base method. +func (m *MockToolOverride) ApplyOverrides(ctx context.Context, tools []vmcp.Tool) ([]vmcp.Tool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyOverrides", ctx, tools) + ret0, _ := ret[0].([]vmcp.Tool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyOverrides indicates an expected call of ApplyOverrides. +func (mr *MockToolOverrideMockRecorder) ApplyOverrides(ctx, tools any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyOverrides", reflect.TypeOf((*MockToolOverride)(nil).ApplyOverrides), ctx, tools) +} diff --git a/pkg/vmcp/aggregator/testhelpers_test.go b/pkg/vmcp/aggregator/testhelpers_test.go new file mode 100644 index 000000000..0b766c508 --- /dev/null +++ b/pkg/vmcp/aggregator/testhelpers_test.go @@ -0,0 +1,154 @@ +package aggregator + +import ( + "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// Test fixture builders to reduce verbosity in tests + +func newTestWorkload(name string, opts ...func(*core.Workload)) core.Workload { + w := core.Workload{ + Name: name, + Status: runtime.WorkloadStatusRunning, + URL: "http://localhost:8080/mcp", + TransportType: types.TransportTypeStreamableHTTP, + Group: testGroupName, + } + for _, opt := range opts { + opt(&w) + } + return w +} + +func withStatus(status runtime.WorkloadStatus) func(*core.Workload) { + return func(w *core.Workload) { + w.Status = status + } +} + +func withURL(url string) func(*core.Workload) { + return func(w *core.Workload) { + w.URL = url + } +} + +func withTransport(transport types.TransportType) func(*core.Workload) { + return func(w *core.Workload) { + w.TransportType = transport + } +} + +func withToolType(toolType string) func(*core.Workload) { + return func(w *core.Workload) { + w.ToolType = toolType + } +} + +func withLabels(labels map[string]string) func(*core.Workload) { + return func(w *core.Workload) { + w.Labels = labels + } +} + +func newTestBackend(id string, opts ...func(*vmcp.Backend)) vmcp.Backend { + b := vmcp.Backend{ + ID: id, + Name: id, + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + } + for _, opt := range opts { + opt(&b) + } + return b +} + +func withBackendURL(url string) func(*vmcp.Backend) { + return func(b *vmcp.Backend) { + b.BaseURL = url + } +} + +func withBackendTransport(transport string) func(*vmcp.Backend) { + return func(b *vmcp.Backend) { + b.TransportType = transport + } +} + +func withBackendName(name string) func(*vmcp.Backend) { + return func(b *vmcp.Backend) { + b.Name = name + } +} + +func newTestCapabilityList(opts ...func(*vmcp.CapabilityList)) *vmcp.CapabilityList { + caps := &vmcp.CapabilityList{ + Tools: []vmcp.Tool{}, + Resources: []vmcp.Resource{}, + Prompts: []vmcp.Prompt{}, + SupportsLogging: false, + SupportsSampling: false, + } + for _, opt := range opts { + opt(caps) + } + return caps +} + +func withTools(tools ...vmcp.Tool) func(*vmcp.CapabilityList) { + return func(c *vmcp.CapabilityList) { + c.Tools = tools + } +} + +func withResources(resources ...vmcp.Resource) func(*vmcp.CapabilityList) { + return func(c *vmcp.CapabilityList) { + c.Resources = resources + } +} + +func withPrompts(prompts ...vmcp.Prompt) func(*vmcp.CapabilityList) { + return func(c *vmcp.CapabilityList) { + c.Prompts = prompts + } +} + +func withLogging(enabled bool) func(*vmcp.CapabilityList) { + return func(c *vmcp.CapabilityList) { + c.SupportsLogging = enabled + } +} + +func withSampling(enabled bool) func(*vmcp.CapabilityList) { + return func(c *vmcp.CapabilityList) { + c.SupportsSampling = enabled + } +} + +func newTestTool(name, backendID string) vmcp.Tool { + return vmcp.Tool{ + Name: name, + Description: name + " description", + InputSchema: map[string]any{"type": "object"}, + BackendID: backendID, + } +} + +func newTestResource(uri, backendID string) vmcp.Resource { + return vmcp.Resource{ + URI: uri, + Name: uri, + BackendID: backendID, + } +} + +func newTestPrompt(name, backendID string) vmcp.Prompt { + return vmcp.Prompt{ + Name: name, + BackendID: backendID, + } +} diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go new file mode 100644 index 000000000..aaaf9cc59 --- /dev/null +++ b/pkg/vmcp/client/client.go @@ -0,0 +1,422 @@ +// Package client provides MCP protocol client implementation for communicating with backend servers. +// +// This package implements the BackendClient interface defined in the vmcp package, +// using the mark3labs/mcp-go SDK for protocol communication. +package client + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +const ( + // maxResponseSize is the maximum size in bytes for HTTP responses from backend MCP servers. + // This protects against DoS attacks via memory exhaustion from malicious or compromised backends. + // + // The MCP specification does not define size limits, so we enforce a reasonable limit + // to prevent unbounded memory allocation during JSON deserialization. + // + // Value: 100 MB + // Rationale: + // - Allows large tool outputs, resources, and capability lists + // - Prevents memory exhaustion (a single large response could OOM the process) + // - Applied at HTTP transport layer before JSON deserialization + // - Backends needing larger responses should use pagination or streaming + // + // Note: This limit is enforced per HTTP response, not per MCP request. + // A tools/list response with 1000 tools would be limited to 100MB total. + maxResponseSize = 100 * 1024 * 1024 // 100 MB +) + +// httpBackendClient implements vmcp.BackendClient using mark3labs/mcp-go HTTP client. +// It supports streamable-HTTP and SSE transports for backend MCP servers. +type httpBackendClient struct { + // clientFactory creates MCP clients for backends. + // Abstracted as a function to enable testing with mock clients. + clientFactory func(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) +} + +// NewHTTPBackendClient creates a new HTTP-based backend client. +// This client supports streamable-HTTP and SSE transports. +func NewHTTPBackendClient() vmcp.BackendClient { + return &httpBackendClient{ + clientFactory: defaultClientFactory, + } +} + +// roundTripperFunc is a function adapter for http.RoundTripper. +type roundTripperFunc func(*http.Request) (*http.Response, error) + +// RoundTrip implements http.RoundTripper interface. +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// defaultClientFactory creates mark3labs MCP clients for different transport types. +func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { + // Create HTTP client with response size limits for DoS protection + httpClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + return nil, err + } + // Wrap response body with size limit + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxResponseSize), + Closer: resp.Body, + } + return resp, nil + }), + } + + var c *client.Client + var err error + + switch target.TransportType { + case "streamable-http", "streamable": + c, err = client.NewStreamableHttpClient( + target.BaseURL, + transport.WithHTTPTimeout(0), + transport.WithContinuousListening(), + transport.WithHTTPBasicClient(httpClient), + // TODO: Add authentication header injection via WithHTTPHeaderFunc + // This will be implemented when we add OutgoingAuthenticator support + ) + if err != nil { + return nil, fmt.Errorf("failed to create streamable-http client: %w", err) + } + + case "sse": + c, err = client.NewSSEMCPClient( + target.BaseURL, + transport.WithHTTPClient(httpClient), + ) + if err != nil { + return nil, fmt.Errorf("failed to create SSE client: %w", err) + } + + default: + return nil, fmt.Errorf("%w: %s (supported: streamable-http, sse)", vmcp.ErrUnsupportedTransport, target.TransportType) + } + + // Start the client connection + if err := c.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start client connection: %w", err) + } + + // Initialize the MCP connection + if err := initializeClient(ctx, c); err != nil { + _ = c.Close() + return nil, fmt.Errorf("failed to initialize MCP connection: %w", err) + } + + return c, nil +} + +// initializeClient performs MCP protocol initialization handshake. +func initializeClient(ctx context.Context, c *client.Client) error { + _, err := c.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "toolhive-vmcp", + Version: "0.1.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Virtual MCP acts as a client to backends + Roots: &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: false, + }, + }, + }, + }) + return err +} + +// ListCapabilities queries a backend for its MCP capabilities. +// Returns tools, resources, and prompts exposed by the backend. +func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + logger.Debugf("Querying capabilities from backend %s (%s)", target.WorkloadName, target.BaseURL) + + // Create a client for this backend + c, err := h.clientFactory(ctx, target) + if err != nil { + return nil, fmt.Errorf("failed to create client for backend %s: %w", target.WorkloadID, err) + } + defer c.Close() + + // Query tools + toolsResp, err := c.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list tools from backend %s: %w", target.WorkloadID, err) + } + + // Query resources + resourcesResp, err := c.ListResources(ctx, mcp.ListResourcesRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list resources from backend %s: %w", target.WorkloadID, err) + } + + // Query prompts + promptsResp, err := c.ListPrompts(ctx, mcp.ListPromptsRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list prompts from backend %s: %w", target.WorkloadID, err) + } + + // Convert MCP types to vmcp types + capabilities := &vmcp.CapabilityList{ + Tools: make([]vmcp.Tool, len(toolsResp.Tools)), + Resources: make([]vmcp.Resource, len(resourcesResp.Resources)), + Prompts: make([]vmcp.Prompt, len(promptsResp.Prompts)), + } + + // Convert tools + for i, tool := range toolsResp.Tools { + // Convert ToolInputSchema to map[string]any + // The ToolInputSchema is a struct with Type, Properties, Required fields + inputSchema := map[string]any{ + "type": tool.InputSchema.Type, + } + if tool.InputSchema.Properties != nil { + inputSchema["properties"] = tool.InputSchema.Properties + } + if len(tool.InputSchema.Required) > 0 { + inputSchema["required"] = tool.InputSchema.Required + } + if tool.InputSchema.Defs != nil { + inputSchema["$defs"] = tool.InputSchema.Defs + } + + capabilities.Tools[i] = vmcp.Tool{ + Name: tool.Name, + Description: tool.Description, + InputSchema: inputSchema, + BackendID: target.WorkloadID, + } + } + + // Convert resources + for i, resource := range resourcesResp.Resources { + capabilities.Resources[i] = vmcp.Resource{ + URI: resource.URI, + Name: resource.Name, + Description: resource.Description, + MimeType: resource.MIMEType, + BackendID: target.WorkloadID, + } + } + + // Convert prompts + for i, prompt := range promptsResp.Prompts { + args := make([]vmcp.PromptArgument, len(prompt.Arguments)) + for j, arg := range prompt.Arguments { + args[j] = vmcp.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + } + } + + capabilities.Prompts[i] = vmcp.Prompt{ + Name: prompt.Name, + Description: prompt.Description, + Arguments: args, + BackendID: target.WorkloadID, + } + } + + // TODO: Query server capabilities to detect logging/sampling support + // This requires additional MCP protocol support for capabilities introspection + + logger.Debugf("Backend %s capabilities: %d tools, %d resources, %d prompts", + target.WorkloadName, len(capabilities.Tools), len(capabilities.Resources), len(capabilities.Prompts)) + + return capabilities, nil +} + +// CallTool invokes a tool on the backend MCP server. +func (h *httpBackendClient) CallTool( + ctx context.Context, + target *vmcp.BackendTarget, + toolName string, + arguments map[string]any, +) (map[string]any, error) { + logger.Debugf("Calling tool %s on backend %s", toolName, target.WorkloadName) + + // Create a client for this backend + c, err := h.clientFactory(ctx, target) + if err != nil { + return nil, fmt.Errorf("failed to create client for backend %s: %w", target.WorkloadID, err) + } + defer c.Close() + + // Call the tool + result, err := c.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + }) + if err != nil { + // Network/connection errors are operational errors + return nil, fmt.Errorf("%w: tool call failed on backend %s: %v", vmcp.ErrBackendUnavailable, target.WorkloadID, err) + } + + // Check if the tool call returned an error (MCP domain error) + if result.IsError { + // Extract error message from content for logging and forwarding + var errorMsg string + if len(result.Content) > 0 { + if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { + errorMsg = textContent.Text + } + } + if errorMsg == "" { + errorMsg = "unknown error" + } + logger.Warnf("Tool %s on backend %s returned error: %s", toolName, target.WorkloadID, errorMsg) + // Wrap with ErrToolExecutionFailed so router can forward transparently to client + return nil, fmt.Errorf("%w: %s on backend %s: %s", vmcp.ErrToolExecutionFailed, toolName, target.WorkloadID, errorMsg) + } + + // Convert result contents to a map + // MCP tools return an array of Content interface (TextContent, ImageContent, etc.) + resultMap := make(map[string]any) + if len(result.Content) > 0 { + textIndex := 0 + imageIndex := 0 + for i, content := range result.Content { + // Try to convert to TextContent + if textContent, ok := mcp.AsTextContent(content); ok { + key := "text" + if textIndex > 0 { + key = fmt.Sprintf("text_%d", textIndex) + } + resultMap[key] = textContent.Text + textIndex++ + } else if imageContent, ok := mcp.AsImageContent(content); ok { + // Convert to ImageContent + key := fmt.Sprintf("image_%d", imageIndex) + resultMap[key] = imageContent.Data + imageIndex++ + } else { + // Log unsupported content types for tracking + logger.Debugf("Unsupported content type at index %d from tool %s on backend %s: %T", + i, toolName, target.WorkloadID, content) + } + } + } + + return resultMap, nil +} + +// ReadResource retrieves a resource from the backend MCP server. +func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.BackendTarget, uri string) ([]byte, error) { + logger.Debugf("Reading resource %s from backend %s", uri, target.WorkloadName) + + // Create a client for this backend + c, err := h.clientFactory(ctx, target) + if err != nil { + return nil, fmt.Errorf("failed to create client for backend %s: %w", target.WorkloadID, err) + } + defer c.Close() + + // Read the resource + result, err := c.ReadResource(ctx, mcp.ReadResourceRequest{ + Params: mcp.ReadResourceParams{ + URI: uri, + }, + }) + if err != nil { + return nil, fmt.Errorf("resource read failed on backend %s: %w", target.WorkloadID, err) + } + + // Concatenate all resource contents + // MCP resources can have multiple contents (text or blob) + var data []byte + for _, content := range result.Contents { + // Try to convert to TextResourceContents + if textContent, ok := mcp.AsTextResourceContents(content); ok { + data = append(data, []byte(textContent.Text)...) + } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { + // Blob is base64-encoded per MCP spec, decode it to bytes + decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob) + if err != nil { + logger.Warnf("Failed to decode base64 blob from resource %s on backend %s: %v", + uri, target.WorkloadID, err) + // Append raw blob as fallback + data = append(data, []byte(blobContent.Blob)...) + } else { + data = append(data, decoded...) + } + } + } + + return data, nil +} + +// GetPrompt retrieves a prompt from the backend MCP server. +func (h *httpBackendClient) GetPrompt( + ctx context.Context, + target *vmcp.BackendTarget, + name string, + arguments map[string]any, +) (string, error) { + logger.Debugf("Getting prompt %s from backend %s", name, target.WorkloadName) + + // Create a client for this backend + c, err := h.clientFactory(ctx, target) + if err != nil { + return "", fmt.Errorf("failed to create client for backend %s: %w", target.WorkloadID, err) + } + defer c.Close() + + // Get the prompt + // Convert map[string]any to map[string]string + stringArgs := make(map[string]string) + for k, v := range arguments { + stringArgs[k] = fmt.Sprintf("%v", v) + } + + result, err := c.GetPrompt(ctx, mcp.GetPromptRequest{ + Params: mcp.GetPromptParams{ + Name: name, + Arguments: stringArgs, + }, + }) + if err != nil { + return "", fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) + } + + // Concatenate all prompt messages into a single string + // MCP prompts return messages with role and content (Content interface) + var prompt string + for _, msg := range result.Messages { + if msg.Role != "" { + prompt += fmt.Sprintf("[%s] ", msg.Role) + } + // Try to convert content to TextContent + if textContent, ok := mcp.AsTextContent(msg.Content); ok { + prompt += textContent.Text + "\n" + } + // TODO: Handle other content types (image, audio, resource) + } + + return prompt, nil +} diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go new file mode 100644 index 000000000..2a7619cb0 --- /dev/null +++ b/pkg/vmcp/client/client_test.go @@ -0,0 +1,191 @@ +package client + +import ( + "context" + "errors" + "testing" + + "github.com/mark3labs/mcp-go/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) { + t.Parallel() + + t.Run("handles client factory error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("factory error") + mockFactory := func(_ context.Context, _ *vmcp.BackendTarget) (*client.Client, error) { + return nil, expectedErr + } + + backendClient := &httpBackendClient{ + clientFactory: mockFactory, + } + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + } + + capabilities, err := backendClient.ListCapabilities(context.Background(), target) + + require.Error(t, err) + assert.Nil(t, capabilities) + assert.Contains(t, err.Error(), "failed to create client") + assert.Contains(t, err.Error(), "test-backend") + }) +} + +func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + transportType string + }{ + { + name: "stdio transport", + transportType: "stdio", + }, + { + name: "unknown transport", + transportType: "unknown-protocol", + }, + { + name: "empty transport", + transportType: "", + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: "http://localhost:8080", + TransportType: tc.transportType, + } + + _, err := defaultClientFactory(context.Background(), target) + + require.Error(t, err) + assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport) + assert.Contains(t, err.Error(), tc.transportType) + }) + } +} + +func TestHTTPBackendClient_CallTool_WithMockFactory(t *testing.T) { + t.Parallel() + + t.Run("handles client factory error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("connection failed") + mockFactory := func(_ context.Context, _ *vmcp.BackendTarget) (*client.Client, error) { + return nil, expectedErr + } + + backendClient := &httpBackendClient{ + clientFactory: mockFactory, + } + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + } + + result, err := backendClient.CallTool(context.Background(), target, "test_tool", map[string]any{}) + + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to create client") + }) +} + +func TestHTTPBackendClient_ReadResource_WithMockFactory(t *testing.T) { + t.Parallel() + + t.Run("handles client factory error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("connection failed") + mockFactory := func(_ context.Context, _ *vmcp.BackendTarget) (*client.Client, error) { + return nil, expectedErr + } + + backendClient := &httpBackendClient{ + clientFactory: mockFactory, + } + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + } + + data, err := backendClient.ReadResource(context.Background(), target, "test://resource") + + require.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "failed to create client") + }) +} + +func TestHTTPBackendClient_GetPrompt_WithMockFactory(t *testing.T) { + t.Parallel() + + t.Run("handles client factory error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("connection failed") + mockFactory := func(_ context.Context, _ *vmcp.BackendTarget) (*client.Client, error) { + return nil, expectedErr + } + + backendClient := &httpBackendClient{ + clientFactory: mockFactory, + } + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + } + + prompt, err := backendClient.GetPrompt(context.Background(), target, "test_prompt", map[string]any{"arg": "value"}) + + require.Error(t, err) + assert.Empty(t, prompt) + assert.Contains(t, err.Error(), "failed to create client") + }) +} + +func TestInitializeClient_ErrorHandling(t *testing.T) { + t.Parallel() + + // This test verifies that initializeClient properly propagates errors + // We can't easily test the success case without a real MCP server + // Integration tests will cover the success path + t.Run("error handling structure", func(t *testing.T) { + t.Parallel() + + // Verify that initializeClient exists and has the right signature + // The actual error handling is tested via integration tests + assert.NotNil(t, initializeClient) + }) +} diff --git a/pkg/vmcp/client/conversions_test.go b/pkg/vmcp/client/conversions_test.go new file mode 100644 index 000000000..97165ec55 --- /dev/null +++ b/pkg/vmcp/client/conversions_test.go @@ -0,0 +1,413 @@ +package client + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// These tests verify the critical type conversion logic in the backend client. +// Since we can't easily mock the mark3labs client, we test the conversion patterns +// that our code uses to transform MCP SDK types to vmcp domain types. + +func TestToolInputSchemaConversion(t *testing.T) { + t.Parallel() + + t.Run("converts basic tool schema", func(t *testing.T) { + t.Parallel() + + sdkTool := mcp.Tool{ + Name: "create_issue", + Description: "Create a GitHub issue", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "title": map[string]any{"type": "string", "description": "Issue title"}, + "body": map[string]any{"type": "string", "description": "Issue body"}, + }, + Required: []string{"title"}, + }, + } + + inputSchema := convertToolInputSchema(sdkTool.InputSchema) + + assert.Equal(t, "object", inputSchema["type"]) + assert.NotNil(t, inputSchema["properties"]) + assert.Equal(t, []string{"title"}, inputSchema["required"]) + + props := inputSchema["properties"].(map[string]any) + assert.Contains(t, props, "title") + assert.Contains(t, props, "body") + titleProp := props["title"].(map[string]any) + assert.Equal(t, "string", titleProp["type"]) + assert.Equal(t, "Issue title", titleProp["description"]) + }) + + t.Run("converts schema with $defs", func(t *testing.T) { + t.Parallel() + + sdkTool := mcp.Tool{ + Name: "complex_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "config": map[string]any{"$ref": "#/$defs/Config"}, + }, + Defs: map[string]any{ + "Config": map[string]any{ + "type": "object", + "properties": map[string]any{"enabled": map[string]any{"type": "boolean"}}, + }, + }, + }, + } + + inputSchema := convertToolInputSchema(sdkTool.InputSchema) + + assert.Contains(t, inputSchema, "$defs") + defs := inputSchema["$defs"].(map[string]any) + assert.Contains(t, defs, "Config") + }) + + t.Run("handles empty required array", func(t *testing.T) { + t.Parallel() + + sdkTool := mcp.Tool{ + Name: "optional_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{"optional_param": map[string]any{"type": "string"}}, + Required: []string{}, + }, + } + + inputSchema := convertToolInputSchema(sdkTool.InputSchema) + + assert.NotContains(t, inputSchema, "required") + }) +} + +func TestContentInterfaceHandling(t *testing.T) { + t.Parallel() + + t.Run("extracts text content correctly", func(t *testing.T) { + t.Parallel() + + toolResult := &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("First text result"), + mcp.NewTextContent("Second text result"), + }, + IsError: false, + } + + resultMap := convertContentToMap(toolResult.Content) + + assert.Equal(t, "First text result", resultMap["text"]) + assert.Equal(t, "Second text result", resultMap["text_1"]) + }) + + t.Run("extracts mixed content types", func(t *testing.T) { + t.Parallel() + + toolResult := &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Text content"), + mcp.NewImageContent("base64data", "image/png"), + mcp.NewTextContent("More text"), + }, + IsError: false, + } + + resultMap := convertContentToMap(toolResult.Content) + + assert.Equal(t, "Text content", resultMap["text"]) + assert.Equal(t, "More text", resultMap["text_1"]) + assert.Equal(t, "base64data", resultMap["image_0"]) + }) + + t.Run("handles error result correctly", func(t *testing.T) { + t.Parallel() + + toolResult := &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Error: something went wrong"), + }, + IsError: true, + } + + // Verify IsError is a boolean (not pointer) - from client.go:223 + assert.True(t, toolResult.IsError) + // Our code should check: if result.IsError { return error } + }) +} + +func TestResourceContentsHandling(t *testing.T) { + t.Parallel() + + t.Run("extracts text resource content", func(t *testing.T) { + t.Parallel() + + resourceResult := &mcp.ReadResourceResult{ + Contents: []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "test://resource", + MIMEType: "text/plain", + Text: "Resource text content", + }, + }, + } + + data := convertResourceContents(resourceResult.Contents) + assert.Equal(t, []byte("Resource text content"), data) + }) + + t.Run("extracts blob resource content", func(t *testing.T) { + t.Parallel() + + resourceResult := &mcp.ReadResourceResult{ + Contents: []mcp.ResourceContents{ + mcp.BlobResourceContents{ + URI: "test://binary", + MIMEType: "application/octet-stream", + Blob: "YmFzZTY0ZGF0YQ==", + }, + }, + } + + data := convertResourceContents(resourceResult.Contents) + assert.Equal(t, []byte("YmFzZTY0ZGF0YQ=="), data) + }) + + t.Run("concatenates multiple resource contents", func(t *testing.T) { + t.Parallel() + + resourceResult := &mcp.ReadResourceResult{ + Contents: []mcp.ResourceContents{ + mcp.TextResourceContents{URI: "test://multi", Text: "Part 1"}, + mcp.TextResourceContents{URI: "test://multi", Text: "Part 2"}, + }, + } + + data := convertResourceContents(resourceResult.Contents) + assert.Equal(t, []byte("Part 1Part 2"), data) + }) +} + +func TestPromptMessageHandling(t *testing.T) { + t.Parallel() + + t.Run("extracts prompt with single message", func(t *testing.T) { + t.Parallel() + + promptResult := &mcp.GetPromptResult{ + Description: "Test prompt", + Messages: []mcp.PromptMessage{ + {Role: "user", Content: mcp.NewTextContent("What is the weather?")}, + }, + } + + prompt := convertPromptMessages(promptResult.Messages) + assert.Equal(t, "[user] What is the weather?\n", prompt) + }) + + t.Run("concatenates multiple prompt messages", func(t *testing.T) { + t.Parallel() + + promptResult := &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + {Role: "system", Content: mcp.NewTextContent("You are a helpful assistant")}, + {Role: "user", Content: mcp.NewTextContent("Hello")}, + {Role: "assistant", Content: mcp.NewTextContent("Hi there!")}, + }, + } + + prompt := convertPromptMessages(promptResult.Messages) + expected := "[system] You are a helpful assistant\n[user] Hello\n[assistant] Hi there!\n" + assert.Equal(t, expected, prompt) + }) + + t.Run("handles prompt message without role", func(t *testing.T) { + t.Parallel() + + promptResult := &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + {Role: "", Content: mcp.NewTextContent("Message content")}, + }, + } + + prompt := convertPromptMessages(promptResult.Messages) + assert.Equal(t, "Message content\n", prompt) + }) +} + +func TestGetPromptArgumentsConversion(t *testing.T) { + t.Parallel() + + t.Run("converts map[string]any to map[string]string", func(t *testing.T) { + t.Parallel() + + arguments := map[string]any{ + "string_arg": "value", + "int_arg": 42, + "bool_arg": true, + "float_arg": 3.14, + } + + stringArgs := convertPromptArguments(arguments) + + assert.Equal(t, "value", stringArgs["string_arg"]) + assert.Equal(t, "42", stringArgs["int_arg"]) + assert.Equal(t, "true", stringArgs["bool_arg"]) + assert.Equal(t, "3.14", stringArgs["float_arg"]) + }) + + t.Run("handles nil and empty values", func(t *testing.T) { + t.Parallel() + + arguments := map[string]any{ + "nil_arg": nil, + "empty_arg": "", + } + + stringArgs := convertPromptArguments(arguments) + + assert.Equal(t, "", stringArgs["nil_arg"]) + assert.Equal(t, "", stringArgs["empty_arg"]) + }) +} + +func TestResourceMIMETypeField(t *testing.T) { + t.Parallel() + + t.Run("uses MIMEType not MimeType", func(t *testing.T) { + t.Parallel() + + // This verifies we're using the correct field name (from client.go:167) + sdkResource := mcp.Resource{ + URI: "test://resource", + Name: "Test Resource", + Description: "A test resource", + MIMEType: "application/json", // Note: MIMEType, not MimeType + } + + vmcpResource := vmcp.Resource{ + URI: sdkResource.URI, + Name: sdkResource.Name, + Description: sdkResource.Description, + MimeType: sdkResource.MIMEType, // Our conversion uses MIMEType + BackendID: "test-backend", + } + + assert.Equal(t, "application/json", vmcpResource.MimeType) + }) +} + +func TestMultipleContentItemsHandling(t *testing.T) { + t.Parallel() + + t.Run("handles tool result with many text items", func(t *testing.T) { + t.Parallel() + + toolResult := &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Result 1"), + mcp.NewTextContent("Result 2"), + mcp.NewTextContent("Result 3"), + mcp.NewTextContent("Result 4"), + mcp.NewTextContent("Result 5"), + }, + IsError: false, + } + + resultMap := convertContentToMap(toolResult.Content) + + assert.Equal(t, "Result 1", resultMap["text"]) + assert.Equal(t, "Result 2", resultMap["text_1"]) + assert.Equal(t, "Result 3", resultMap["text_2"]) + assert.Equal(t, "Result 4", resultMap["text_3"]) + assert.Equal(t, "Result 5", resultMap["text_4"]) + }) + + t.Run("handles tool result with many images", func(t *testing.T) { + t.Parallel() + + toolResult := &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewImageContent("data1", "image/png"), + mcp.NewImageContent("data2", "image/jpeg"), + mcp.NewImageContent("data3", "image/gif"), + }, + IsError: false, + } + + resultMap := convertContentToMap(toolResult.Content) + + assert.Equal(t, "data1", resultMap["image_0"]) + assert.Equal(t, "data2", resultMap["image_1"]) + assert.Equal(t, "data3", resultMap["image_2"]) + }) + + t.Run("handles empty content array", func(t *testing.T) { + t.Parallel() + + emptyContent := []mcp.Content{} + resultMap := convertContentToMap(emptyContent) + + assert.Empty(t, resultMap) + }) +} + +func TestPromptArgumentConversion(t *testing.T) { + t.Parallel() + + t.Run("converts prompt arguments correctly", func(t *testing.T) { + t.Parallel() + + // From client.go:174-183 + sdkPrompt := mcp.Prompt{ + Name: "test_prompt", + Description: "A test prompt", + Arguments: []mcp.PromptArgument{ + { + Name: "required_arg", + Description: "A required argument", + Required: true, + }, + { + Name: "optional_arg", + Description: "An optional argument", + Required: false, + }, + }, + } + + // Apply our conversion + args := make([]vmcp.PromptArgument, len(sdkPrompt.Arguments)) + for j, arg := range sdkPrompt.Arguments { + args[j] = vmcp.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + } + } + + vmcpPrompt := vmcp.Prompt{ + Name: sdkPrompt.Name, + Description: sdkPrompt.Description, + Arguments: args, + BackendID: "test-backend", + } + + // Verify conversion + require.Len(t, vmcpPrompt.Arguments, 2) + assert.Equal(t, "required_arg", vmcpPrompt.Arguments[0].Name) + assert.True(t, vmcpPrompt.Arguments[0].Required) + assert.Equal(t, "optional_arg", vmcpPrompt.Arguments[1].Name) + assert.False(t, vmcpPrompt.Arguments[1].Required) + }) +} diff --git a/pkg/vmcp/client/testhelpers_test.go b/pkg/vmcp/client/testhelpers_test.go new file mode 100644 index 000000000..55c12e4a2 --- /dev/null +++ b/pkg/vmcp/client/testhelpers_test.go @@ -0,0 +1,84 @@ +package client + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Helper functions to encapsulate conversion logic patterns + +// convertToolInputSchema simulates the conversion logic from client.go:138-151 +func convertToolInputSchema(schema mcp.ToolInputSchema) map[string]any { + inputSchema := map[string]any{ + "type": schema.Type, + } + if schema.Properties != nil { + inputSchema["properties"] = schema.Properties + } + if len(schema.Required) > 0 { + inputSchema["required"] = schema.Required + } + if schema.Defs != nil { + inputSchema["$defs"] = schema.Defs + } + return inputSchema +} + +// convertContentToMap simulates the conversion logic from client.go:228-250 +func convertContentToMap(contents []mcp.Content) map[string]any { + resultMap := make(map[string]any) + textIndex := 0 + imageIndex := 0 + for _, content := range contents { + if textContent, ok := mcp.AsTextContent(content); ok { + key := "text" + if textIndex > 0 { + key = fmt.Sprintf("text_%d", textIndex) + } + resultMap[key] = textContent.Text + textIndex++ + } else if imageContent, ok := mcp.AsImageContent(content); ok { + key := fmt.Sprintf("image_%d", imageIndex) + resultMap[key] = imageContent.Data + imageIndex++ + } + } + return resultMap +} + +// convertResourceContents simulates the conversion logic from client.go:276-289 +func convertResourceContents(contents []mcp.ResourceContents) []byte { + var data []byte + for _, content := range contents { + if textContent, ok := mcp.AsTextResourceContents(content); ok { + data = append(data, []byte(textContent.Text)...) + } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { + data = append(data, []byte(blobContent.Blob)...) + } + } + return data +} + +// convertPromptMessages simulates the conversion logic from client.go:315-327 +func convertPromptMessages(messages []mcp.PromptMessage) string { + var prompt string + for _, msg := range messages { + if msg.Role != "" { + prompt += "[" + string(msg.Role) + "] " + } + if textContent, ok := mcp.AsTextContent(msg.Content); ok { + prompt += textContent.Text + "\n" + } + } + return prompt +} + +// convertPromptArguments simulates the conversion logic from client.go:306-309 +func convertPromptArguments(arguments map[string]any) map[string]string { + stringArgs := make(map[string]string) + for k, v := range arguments { + stringArgs[k] = fmt.Sprintf("%v", v) + } + return stringArgs +} diff --git a/pkg/vmcp/errors.go b/pkg/vmcp/errors.go index 43f018457..fbad0cbd0 100644 --- a/pkg/vmcp/errors.go +++ b/pkg/vmcp/errors.go @@ -38,4 +38,20 @@ var ( // ErrInvalidInput indicates invalid input parameters. // Wrapping errors should specify which parameter is invalid and why. ErrInvalidInput = errors.New("invalid input") + + // ErrUnsupportedTransport indicates an unsupported MCP transport type. + // Wrapping errors should specify which transport type is not supported. + ErrUnsupportedTransport = errors.New("unsupported transport type") + + // ErrToolExecutionFailed indicates an MCP tool execution failed (domain error). + // This represents the tool running but returning an error result (IsError=true in MCP). + // These errors should be forwarded to the client transparently as the LLM needs to see them. + // Wrapping errors should include the tool name and error message from MCP. + ErrToolExecutionFailed = errors.New("tool execution failed") + + // ErrBackendUnavailable indicates a backend MCP server is unreachable (operational error). + // This represents infrastructure issues (network down, server not responding, etc.). + // These errors may be retried, circuit-broken, or handled differently from domain errors. + // Wrapping errors should include the backend ID and underlying cause. + ErrBackendUnavailable = errors.New("backend unavailable") ) diff --git a/pkg/vmcp/mocks/mock_backend_client.go b/pkg/vmcp/mocks/mock_backend_client.go new file mode 100644 index 000000000..37e882c5c --- /dev/null +++ b/pkg/vmcp/mocks/mock_backend_client.go @@ -0,0 +1,141 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: types.go +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + gomock "go.uber.org/mock/gomock" +) + +// MockHealthChecker is a mock of HealthChecker interface. +type MockHealthChecker struct { + ctrl *gomock.Controller + recorder *MockHealthCheckerMockRecorder + isgomock struct{} +} + +// MockHealthCheckerMockRecorder is the mock recorder for MockHealthChecker. +type MockHealthCheckerMockRecorder struct { + mock *MockHealthChecker +} + +// NewMockHealthChecker creates a new mock instance. +func NewMockHealthChecker(ctrl *gomock.Controller) *MockHealthChecker { + mock := &MockHealthChecker{ctrl: ctrl} + mock.recorder = &MockHealthCheckerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHealthChecker) EXPECT() *MockHealthCheckerMockRecorder { + return m.recorder +} + +// CheckHealth mocks base method. +func (m *MockHealthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckHealth", ctx, target) + ret0, _ := ret[0].(vmcp.BackendHealthStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckHealth indicates an expected call of CheckHealth. +func (mr *MockHealthCheckerMockRecorder) CheckHealth(ctx, target any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckHealth", reflect.TypeOf((*MockHealthChecker)(nil).CheckHealth), ctx, target) +} + +// MockBackendClient is a mock of BackendClient interface. +type MockBackendClient struct { + ctrl *gomock.Controller + recorder *MockBackendClientMockRecorder + isgomock struct{} +} + +// MockBackendClientMockRecorder is the mock recorder for MockBackendClient. +type MockBackendClientMockRecorder struct { + mock *MockBackendClient +} + +// NewMockBackendClient creates a new mock instance. +func NewMockBackendClient(ctrl *gomock.Controller) *MockBackendClient { + mock := &MockBackendClient{ctrl: ctrl} + mock.recorder = &MockBackendClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBackendClient) EXPECT() *MockBackendClientMockRecorder { + return m.recorder +} + +// CallTool mocks base method. +func (m *MockBackendClient) CallTool(ctx context.Context, target *vmcp.BackendTarget, toolName string, arguments map[string]any) (map[string]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CallTool", ctx, target, toolName, arguments) + ret0, _ := ret[0].(map[string]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CallTool indicates an expected call of CallTool. +func (mr *MockBackendClientMockRecorder) CallTool(ctx, target, toolName, arguments any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallTool", reflect.TypeOf((*MockBackendClient)(nil).CallTool), ctx, target, toolName, arguments) +} + +// GetPrompt mocks base method. +func (m *MockBackendClient) GetPrompt(ctx context.Context, target *vmcp.BackendTarget, name string, arguments map[string]any) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrompt", ctx, target, name, arguments) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrompt indicates an expected call of GetPrompt. +func (mr *MockBackendClientMockRecorder) GetPrompt(ctx, target, name, arguments any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrompt", reflect.TypeOf((*MockBackendClient)(nil).GetPrompt), ctx, target, name, arguments) +} + +// ListCapabilities mocks base method. +func (m *MockBackendClient) ListCapabilities(ctx context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListCapabilities", ctx, target) + ret0, _ := ret[0].(*vmcp.CapabilityList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListCapabilities indicates an expected call of ListCapabilities. +func (mr *MockBackendClientMockRecorder) ListCapabilities(ctx, target any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCapabilities", reflect.TypeOf((*MockBackendClient)(nil).ListCapabilities), ctx, target) +} + +// ReadResource mocks base method. +func (m *MockBackendClient) ReadResource(ctx context.Context, target *vmcp.BackendTarget, uri string) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadResource", ctx, target, uri) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadResource indicates an expected call of ReadResource. +func (mr *MockBackendClientMockRecorder) ReadResource(ctx, target, uri any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadResource", reflect.TypeOf((*MockBackendClient)(nil).ReadResource), ctx, target, uri) +} diff --git a/pkg/vmcp/mocks/mock_backend_registry.go b/pkg/vmcp/mocks/mock_backend_registry.go new file mode 100644 index 000000000..0bbe8b647 --- /dev/null +++ b/pkg/vmcp/mocks/mock_backend_registry.go @@ -0,0 +1,84 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: registry.go +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_backend_registry.go -package=mocks -source=registry.go BackendRegistry +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + gomock "go.uber.org/mock/gomock" +) + +// MockBackendRegistry is a mock of BackendRegistry interface. +type MockBackendRegistry struct { + ctrl *gomock.Controller + recorder *MockBackendRegistryMockRecorder + isgomock struct{} +} + +// MockBackendRegistryMockRecorder is the mock recorder for MockBackendRegistry. +type MockBackendRegistryMockRecorder struct { + mock *MockBackendRegistry +} + +// NewMockBackendRegistry creates a new mock instance. +func NewMockBackendRegistry(ctrl *gomock.Controller) *MockBackendRegistry { + mock := &MockBackendRegistry{ctrl: ctrl} + mock.recorder = &MockBackendRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBackendRegistry) EXPECT() *MockBackendRegistryMockRecorder { + return m.recorder +} + +// Count mocks base method. +func (m *MockBackendRegistry) Count() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Count") + ret0, _ := ret[0].(int) + return ret0 +} + +// Count indicates an expected call of Count. +func (mr *MockBackendRegistryMockRecorder) Count() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockBackendRegistry)(nil).Count)) +} + +// Get mocks base method. +func (m *MockBackendRegistry) Get(ctx context.Context, backendID string) *vmcp.Backend { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, backendID) + ret0, _ := ret[0].(*vmcp.Backend) + return ret0 +} + +// Get indicates an expected call of Get. +func (mr *MockBackendRegistryMockRecorder) Get(ctx, backendID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBackendRegistry)(nil).Get), ctx, backendID) +} + +// List mocks base method. +func (m *MockBackendRegistry) List(ctx context.Context) []vmcp.Backend { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx) + ret0, _ := ret[0].([]vmcp.Backend) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockBackendRegistryMockRecorder) List(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockBackendRegistry)(nil).List), ctx) +} diff --git a/pkg/vmcp/registry.go b/pkg/vmcp/registry.go new file mode 100644 index 000000000..d83c94e94 --- /dev/null +++ b/pkg/vmcp/registry.go @@ -0,0 +1,134 @@ +package vmcp + +import ( + "context" +) + +// BackendRegistry provides thread-safe access to discovered backends. +// This is a shared kernel interface used across vmcp bounded contexts +// (aggregator, router, health monitoring). +// +// The registry serves as the single source of truth for backend information +// during the lifecycle of a virtual MCP server instance. It supports both +// immutable (Phase 1) and mutable (future phases) implementations. +// +// Design Philosophy: +// - Phase 1: Immutable registry (backends discovered once, never change) +// - Future: Mutable registry with health monitoring and dynamic updates +// - Thread-safe for concurrent reads across all implementations +// - Implementations may support concurrent writes with appropriate locking +// +//go:generate mockgen -destination=mocks/mock_backend_registry.go -package=mocks -source=registry.go BackendRegistry +type BackendRegistry interface { + // Get retrieves a backend by ID. + // Returns nil if the backend is not found. + // This method is safe for concurrent reads. + // + // Example: + // backend := registry.Get(ctx, "github-mcp") + // if backend == nil { + // return fmt.Errorf("backend not found") + // } + Get(ctx context.Context, backendID string) *Backend + + // List returns all registered backends. + // The returned slice is a snapshot and safe to iterate without additional locking. + // Order is not guaranteed unless specified by the implementation. + // + // Example: + // backends := registry.List(ctx) + // for _, backend := range backends { + // fmt.Printf("Backend: %s\n", backend.Name) + // } + List(ctx context.Context) []Backend + + // Count returns the number of registered backends. + // This is more efficient than len(List()) for large registries. + Count() int +} + +// immutableRegistry is a Phase 1 implementation that stores a static list +// of backends discovered at startup. It's thread-safe for concurrent reads +// and never changes after construction. +// +// Use NewImmutableRegistry() to create instances. +type immutableRegistry struct { + // backends maps backend ID to backend information. + // This map is built once at construction and never modified. + backends map[string]Backend +} + +// NewImmutableRegistry creates a registry from a static list of backends. +// +// This implementation is used in Phase 1 where backends are discovered once +// at startup and don't change during the virtual MCP server's lifetime. +// The registry is thread-safe for concurrent reads. +// +// Parameters: +// - backends: List of discovered backends to register +// +// Returns: +// - BackendRegistry: An immutable registry instance +// +// Example: +// +// backends := discoverer.Discover(ctx, "engineering-team") +// registry := vmcp.NewImmutableRegistry(backends) +// backend := registry.Get(ctx, "github-mcp") +func NewImmutableRegistry(backends []Backend) BackendRegistry { + reg := &immutableRegistry{ + backends: make(map[string]Backend, len(backends)), + } + for _, b := range backends { + reg.backends[b.ID] = b + } + return reg +} + +// Get retrieves a backend by ID from the immutable registry. +// Returns nil if the backend is not found. +func (r *immutableRegistry) Get(_ context.Context, backendID string) *Backend { + if b, exists := r.backends[backendID]; exists { + // Return a copy to prevent external modifications + return &b + } + return nil +} + +// List returns all registered backends as a slice. +// The order is not guaranteed. The returned slice is a copy and safe to modify. +func (r *immutableRegistry) List(_ context.Context) []Backend { + backends := make([]Backend, 0, len(r.backends)) + for _, b := range r.backends { + backends = append(backends, b) + } + return backends +} + +// Count returns the number of registered backends. +func (r *immutableRegistry) Count() int { + return len(r.backends) +} + +// BackendToTarget converts a Backend to a BackendTarget for routing. +// This helper is used when populating routing tables during capability aggregation. +// +// The BackendTarget contains all information needed to forward requests to +// a specific backend workload, including authentication strategy and metadata. +func BackendToTarget(backend *Backend) *BackendTarget { + if backend == nil { + return nil + } + + return &BackendTarget{ + WorkloadID: backend.ID, + WorkloadName: backend.Name, + BaseURL: backend.BaseURL, + TransportType: backend.TransportType, + AuthStrategy: backend.AuthStrategy, + AuthMetadata: backend.AuthMetadata, + SessionAffinity: false, // TODO: Add session affinity support in future phases + HealthStatus: backend.HealthStatus, + Metadata: backend.Metadata, + } +} diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index a55c79cf3..118e2082a 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -191,6 +191,8 @@ type HealthChecker interface { // BackendClient abstracts MCP protocol communication with backend servers. // This interface handles the protocol-level details of calling backend MCP servers, // supporting multiple transport types (HTTP, SSE, stdio, streamable-http). +// +//go:generate mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker type BackendClient interface { // CallTool invokes a tool on the backend MCP server. // Returns the tool output or an error.