diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 881bae87f..129e6c936 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -57,6 +57,7 @@ var ( EnabledToolsets: enabledToolsets, DynamicToolsets: viper.GetBool("dynamic_toolsets"), ReadOnly: viper.GetBool("read-only"), + WritePrivateOnly: viper.GetBool("write-private-only"), ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), @@ -105,6 +106,7 @@ func init() { rootCmd.PersistentFlags().StringSlice("toolsets", github.DefaultTools, "An optional comma separated list of groups of tools to allow with optional modes (e.g., 'repos:rw,issues:ro,users'), defaults to enabling all") rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") + rootCmd.PersistentFlags().Bool("write-private-only", false, "Restrict all write operations to private repositories only") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file") rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file") @@ -121,6 +123,8 @@ func init() { _ = viper.BindEnv("toolsets", "GITHUB_TOOLSETS") _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) + _ = viper.BindPFlag("write-private-only", rootCmd.PersistentFlags().Lookup("write-private-only")) + _ = viper.BindEnv("write-private-only", "GITHUB_WRITE_PRIVATE_ONLY") _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) _ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging")) _ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations")) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a1bb1ff0f..87159a10f 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -216,6 +216,11 @@ type MCPServerConfig struct { // ReadOnly indicates if we should only offer read-only tools ReadOnly bool + // WritePrivateOnly restricts all write operations to private repositories only. + // When true, write tool handlers are wrapped with a visibility guard that blocks + // writes to public repositories. Has no effect when ReadOnly is also true. + WritePrivateOnly bool + // Installations maps organization names to GitHub App installation IDs Installations map[string]int64 @@ -288,6 +293,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { toolsets, err := github.InitToolsets( enabledToolsets, cfg.ReadOnly, + cfg.WritePrivateOnly, getClient, getGQLClient, cfg.Translator, @@ -296,6 +302,14 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { return nil, fmt.Errorf("failed to initialize toolsets: %w", err) } + if cfg.WritePrivateOnly { + if cfg.ReadOnly { + logrus.Warn("GITHUB_WRITE_PRIVATE_ONLY is set but has no effect because --read-only is also active. Write tools are not registered in read-only mode.") + } else { + logrus.Info("Write operations restricted to private repositories (GITHUB_WRITE_PRIVATE_ONLY=true)") + } + } + github.RegisterResources(ghServer, getClient, cfg.Translator) // Register the tools with the server @@ -330,6 +344,9 @@ type StdioServerConfig struct { // ReadOnly indicates if we should only register read-only tools ReadOnly bool + // WritePrivateOnly restricts all write operations to private repositories only. + WritePrivateOnly bool + // ExportTranslations indicates if we should export translations // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#i18n--overriding-descriptions ExportTranslations bool @@ -353,14 +370,15 @@ func RunStdioServer(cfg StdioServerConfig) error { t, dumpTranslations := translations.TranslationHelper() ghServer, err := NewMCPServer(MCPServerConfig{ - Version: cfg.Version, - Host: cfg.Host, - Token: cfg.Token, - EnabledToolsets: cfg.EnabledToolsets, - DynamicToolsets: cfg.DynamicToolsets, - ReadOnly: cfg.ReadOnly, - Installations: cfg.Installations, - Translator: t, + Version: cfg.Version, + Host: cfg.Host, + Token: cfg.Token, + EnabledToolsets: cfg.EnabledToolsets, + DynamicToolsets: cfg.DynamicToolsets, + ReadOnly: cfg.ReadOnly, + WritePrivateOnly: cfg.WritePrivateOnly, + Installations: cfg.Installations, + Translator: t, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 5cd5c7916..32b21efe7 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -7,6 +7,7 @@ import ( "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/shurcooL/githubv4" ) @@ -16,20 +17,41 @@ type GetGQLClientFn func(ctx context.Context, owner string) (*githubv4.Client, e var DefaultTools = []string{"all"} -func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { +func InitToolsets(passedToolsets []string, readOnly bool, writePrivateOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { // Parse toolset configurations from the passed toolsets configs, err := toolsets.ParseToolsetConfigFromSlice(passedToolsets) if err != nil { return nil, fmt.Errorf("failed to parse toolset configuration: %w", err) } - return InitToolsetsWithConfig(configs, readOnly, getClient, getGQLClient, t) + return InitToolsetsWithConfig(configs, readOnly, writePrivateOnly, getClient, getGQLClient, t) } -func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { +func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, writePrivateOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { // Create a new toolset group tsg := toolsets.NewToolsetGroup(readOnly) + // Helper functions to conditionally wrap write tool handlers with guards. + // When writePrivateOnly=false, these are no-ops that return the tool and handler unchanged. + guardWrite := func(tool mcp.Tool, handler server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + if writePrivateOnly { + return WritePrivateOnlyGuard(getClient, tool, handler) + } + return tool, handler + } + guardCreate := func(tool mcp.Tool, handler server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + if writePrivateOnly { + return CreateRepositoryPrivateOnlyGuard(tool, handler) + } + return tool, handler + } + guardFork := func(tool mcp.Tool, handler server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + if writePrivateOnly { + return ForkRepositoryPrivateOnlyGuard(tool, handler) + } + return tool, handler + } + // Define all available features with their default state (disabled) // Create toolsets repos := toolsets.NewToolset("repos", "GitHub Repository related tools"). @@ -44,12 +66,12 @@ func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, get toolsets.NewServerTool(GetTag(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(CreateOrUpdateFile(getClient, t)), - toolsets.NewServerTool(CreateRepository(getClient, t)), - toolsets.NewServerTool(ForkRepository(getClient, t)), - toolsets.NewServerTool(CreateBranch(getClient, t)), - toolsets.NewServerTool(PushFiles(getClient, t)), - toolsets.NewServerTool(DeleteFile(getClient, t)), + toolsets.NewServerTool(guardWrite(CreateOrUpdateFile(getClient, t))), + toolsets.NewServerTool(guardCreate(CreateRepository(getClient, t))), + toolsets.NewServerTool(guardFork(ForkRepository(getClient, t))), + toolsets.NewServerTool(guardWrite(CreateBranch(getClient, t))), + toolsets.NewServerTool(guardWrite(PushFiles(getClient, t))), + toolsets.NewServerTool(guardWrite(DeleteFile(getClient, t))), ) issues := toolsets.NewToolset("issues", "GitHub Issues related tools"). AddReadTools( @@ -59,9 +81,9 @@ func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, get toolsets.NewServerTool(GetIssueComments(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(CreateIssue(getClient, t)), - toolsets.NewServerTool(AddIssueComment(getClient, t)), - toolsets.NewServerTool(UpdateIssue(getClient, t)), + toolsets.NewServerTool(guardWrite(CreateIssue(getClient, t))), + toolsets.NewServerTool(guardWrite(AddIssueComment(getClient, t))), + toolsets.NewServerTool(guardWrite(UpdateIssue(getClient, t))), ) users := toolsets.NewToolset("users", "GitHub User related tools"). AddReadTools( @@ -78,18 +100,18 @@ func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, get toolsets.NewServerTool(GetPullRequestDiff(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(MergePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), - toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, t)), - toolsets.NewServerTool(RequestCopilotReview(getClient, t)), + toolsets.NewServerTool(guardWrite(MergePullRequest(getClient, t))), + toolsets.NewServerTool(guardWrite(UpdatePullRequestBranch(getClient, t))), + toolsets.NewServerTool(guardWrite(CreatePullRequest(getClient, t))), + toolsets.NewServerTool(guardWrite(UpdatePullRequest(getClient, t))), + toolsets.NewServerTool(guardWrite(RequestCopilotReview(getClient, t))), // Reviews - toolsets.NewServerTool(CreateAndSubmitPullRequestReview(getGQLClient, t)), - toolsets.NewServerTool(CreatePendingPullRequestReview(getGQLClient, t)), - toolsets.NewServerTool(AddPullRequestReviewCommentToPendingReview(getGQLClient, t)), - toolsets.NewServerTool(SubmitPendingPullRequestReview(getGQLClient, t)), - toolsets.NewServerTool(DeletePendingPullRequestReview(getGQLClient, t)), + toolsets.NewServerTool(guardWrite(CreateAndSubmitPullRequestReview(getGQLClient, t))), + toolsets.NewServerTool(guardWrite(CreatePendingPullRequestReview(getGQLClient, t))), + toolsets.NewServerTool(guardWrite(AddPullRequestReviewCommentToPendingReview(getGQLClient, t))), + toolsets.NewServerTool(guardWrite(SubmitPendingPullRequestReview(getGQLClient, t))), + toolsets.NewServerTool(guardWrite(DeletePendingPullRequestReview(getGQLClient, t))), ) codeSecurity := toolsets.NewToolset("code_security", "Code security related tools, such as GitHub Code Scanning"). AddReadTools( diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index b8e33001a..fd70198b4 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -73,7 +73,7 @@ func TestInitToolsetsWithConfig(t *testing.T) { return nil, nil } - tsg, err := InitToolsetsWithConfig(tt.configs, tt.readOnly, getClient, getGQLClient, mockTranslator) + tsg, err := InitToolsetsWithConfig(tt.configs, tt.readOnly, false, getClient, getGQLClient, mockTranslator) if tt.wantErr { if err == nil { @@ -165,7 +165,7 @@ func TestInitToolsets_BackwardCompatibility(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tsg, err := InitToolsets(tt.passedToolsets, tt.readOnly, getClient, getGQLClient, mockTranslator) + tsg, err := InitToolsets(tt.passedToolsets, tt.readOnly, false, getClient, getGQLClient, mockTranslator) if tt.wantErr { if err == nil { @@ -255,7 +255,7 @@ func TestToolsetModeFiltering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tsg, err := InitToolsetsWithConfig(tt.configs, false, getClient, getGQLClient, mockTranslator) + tsg, err := InitToolsetsWithConfig(tt.configs, false, false, getClient, getGQLClient, mockTranslator) if err != nil { t.Fatalf("InitToolsetsWithConfig() error: %v", err) } @@ -339,7 +339,7 @@ func TestContextToolsetIntegration(t *testing.T) { {Name: "repos", Mode: toolsets.ReadOnly}, } - tsg, err := InitToolsetsWithConfig(configs, false, getClient, getGQLClient, mockTranslator) + tsg, err := InitToolsetsWithConfig(configs, false, false, getClient, getGQLClient, mockTranslator) if err != nil { t.Fatalf("InitToolsetsWithConfig() error: %v", err) } @@ -379,3 +379,77 @@ func TestContextToolsetIntegration(t *testing.T) { t.Logf("Context toolset has %d active tools", len(activeTools)) } + +func TestWritePrivateOnlyGuardWiring(t *testing.T) { + // Verify that when writePrivateOnly=true, write tools are wrapped with guards + // and that when writePrivateOnly=false, they are not. + mockTranslator := func(key, fallback string) string { + return fallback + } + getClient := func(ctx context.Context, _ string) (*github.Client, error) { + return github.NewClient(nil), nil + } + getGQLClient := func(ctx context.Context, _ string) (*githubv4.Client, error) { + return nil, nil + } + + configs := []toolsets.ToolsetConfig{ + {Name: "all", Mode: toolsets.ReadWrite}, + } + + // Initialize with writePrivateOnly=false + tsgOff, err := InitToolsetsWithConfig(configs, false, false, getClient, getGQLClient, mockTranslator) + if err != nil { + t.Fatalf("InitToolsetsWithConfig(writePrivateOnly=false) error: %v", err) + } + + // Initialize with writePrivateOnly=true + tsgOn, err := InitToolsetsWithConfig(configs, false, true, getClient, getGQLClient, mockTranslator) + if err != nil { + t.Fatalf("InitToolsetsWithConfig(writePrivateOnly=true) error: %v", err) + } + + // Both should have the same number of toolsets + if len(tsgOff.Toolsets) != len(tsgOn.Toolsets) { + t.Errorf("Expected same number of toolsets, got %d vs %d", len(tsgOff.Toolsets), len(tsgOn.Toolsets)) + } + + // Both should have the same tools (write tools are still registered, just wrapped) + for name, tsOff := range tsgOff.Toolsets { + tsOn, exists := tsgOn.Toolsets[name] + if !exists { + t.Errorf("Toolset %s missing from writePrivateOnly=true", name) + continue + } + offTools := tsOff.GetActiveTools() + onTools := tsOn.GetActiveTools() + if len(offTools) != len(onTools) { + t.Errorf("Toolset %s: expected %d tools, got %d with writePrivateOnly=true", + name, len(offTools), len(onTools)) + } + } + + // Verify that fork_repository is blocked when writePrivateOnly=true + // by calling the handler directly + repoToolset := tsgOn.Toolsets["repos"] + if repoToolset == nil { + t.Fatal("repos toolset not found") + } + for _, tool := range repoToolset.GetActiveTools() { + if tool.Tool.Name == "fork_repository" { + result, err := tool.Handler(context.Background(), createMCPRequest(map[string]interface{}{ + "owner": "testowner", + "repo": "testrepo", + })) + if err != nil { + t.Fatalf("fork_repository handler returned error: %v", err) + } + // Should be blocked + if result == nil || !result.IsError { + t.Error("Expected fork_repository to be blocked when writePrivateOnly=true") + } + return + } + } + t.Error("fork_repository tool not found in repos toolset") +} diff --git a/pkg/github/write_guard.go b/pkg/github/write_guard.go new file mode 100644 index 000000000..e3422ebd0 --- /dev/null +++ b/pkg/github/write_guard.go @@ -0,0 +1,117 @@ +package github + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/sirupsen/logrus" +) + +// checkRepoVisibility calls Repositories.Get() to determine if a repository is private. +// Returns (true, nil) if private, (false, nil) if public, (false, err) on any API error. +// Callers must treat any error as fail-closed (block the write). +// +// Do NOT use determineRepoAccessType() or RepoAccessCache for this purpose. +// Those functions check authenticated access success, not actual repo visibility. +// A PAT with access to a public repo returns RepoAccessPrivate from that function, +// which would incorrectly allow writes to public repos. +func checkRepoVisibility(ctx context.Context, getClient GetClientFn, owner, repo string) (bool, error) { + client, err := getClient(ctx, owner) + if err != nil { + return false, fmt.Errorf("failed to get GitHub client: %w", err) + } + repoData, _, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + return false, fmt.Errorf("failed to get repository: %w", err) + } + return repoData.GetPrivate(), nil +} + +// WritePrivateOnlyGuard wraps a write tool handler to enforce the GITHUB_WRITE_PRIVATE_ONLY +// policy. It extracts owner and repo from the MCP request, checks repository visibility +// via Repositories.Get(), and blocks the write if the repository is public or if +// visibility cannot be confirmed. +// +// On any error from the visibility check, it fails closed (blocks the write). +// If the repo is confirmed private, it delegates to the next handler unchanged. +// +// This guard is for standard write tools with owner+repo params. +// Use CreateRepositoryPrivateOnlyGuard for create_repository. +// Use ForkRepositoryPrivateOnlyGuard for fork_repository. +func WritePrivateOnlyGuard(getClient GetClientFn, tool mcp.Tool, handler server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + return tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](req, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](req, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + isPrivate, err := checkRepoVisibility(ctx, getClient, owner, repo) + if err != nil { + logrus.Warnf("Write blocked: %s on %s/%s — visibility check failed: %v (GITHUB_WRITE_PRIVATE_ONLY)", + tool.Name, owner, repo, err) + return mcp.NewToolResultError( + "Write blocked: unable to verify repository visibility. " + + "Ensure your token has repo read access and try again.", + ), nil + } + if !isPrivate { + logrus.Warnf("Write blocked: %s on public repo %s/%s (GITHUB_WRITE_PRIVATE_ONLY)", + tool.Name, owner, repo) + return mcp.NewToolResultError(fmt.Sprintf( + "Write blocked: %s/%s is a public repository. "+ + "The server is configured with GITHUB_WRITE_PRIVATE_ONLY=true, which restricts "+ + "all write operations to private repositories only. "+ + "To proceed: use a private repository, or ask the administrator to unset GITHUB_WRITE_PRIVATE_ONLY.", + owner, repo, + )), nil + } + + return handler(ctx, req) + } +} + +// CreateRepositoryPrivateOnlyGuard wraps the create_repository handler to enforce +// GITHUB_WRITE_PRIVATE_ONLY. Performs a pre-flight parameter check (no API call) +// because there is no existing repository to query for visibility. +// +// If the private parameter is false or absent (defaults to false via OptionalParam[bool]), +// the call is blocked immediately. If private=true, the call proceeds to the wrapped handler. +// +// Never silently overrides private=false to private=true. +func CreateRepositoryPrivateOnlyGuard(tool mcp.Tool, handler server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + return tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + private, err := OptionalParam[bool](req, "private") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if !private { + return mcp.NewToolResultError( + "Write blocked: create_repository requires private=true when GITHUB_WRITE_PRIVATE_ONLY is set. " + + "Set private=true to create a private repository.", + ), nil + } + return handler(ctx, req) + } +} + +// ForkRepositoryPrivateOnlyGuard replaces the fork_repository handler entirely when +// GITHUB_WRITE_PRIVATE_ONLY is set. GitHub's CreateFork API has no visibility parameter — +// fork visibility is determined by source repo visibility and the user's GitHub plan. +// Since we cannot guarantee the fork will be private, we block entirely. +// +// The original handler is accepted as a parameter for signature consistency but is never called. +func ForkRepositoryPrivateOnlyGuard(tool mcp.Tool, _ server.ToolHandlerFunc) (mcp.Tool, server.ToolHandlerFunc) { + return tool, func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultError( + "Write blocked: fork_repository cannot guarantee the fork will be private when " + + "GITHUB_WRITE_PRIVATE_ONLY is set. GitHub's fork API does not expose a visibility parameter. " + + "To create a private fork, use the GitHub web UI or API directly.", + ), nil + } +} diff --git a/pkg/github/write_guard_test.go b/pkg/github/write_guard_test.go new file mode 100644 index 000000000..1e883e8b1 --- /dev/null +++ b/pkg/github/write_guard_test.go @@ -0,0 +1,310 @@ +package github + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// dummyWriteHandler is a no-op write handler used to verify the guard passes through correctly. +func dummyWriteHandler(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("write succeeded"), nil +} + +// dummyTool creates a minimal mcp.Tool for testing. +// Note: mcp.NewTool returns a single mcp.Tool value, NOT a tuple. +func dummyTool(name string) mcp.Tool { + return mcp.NewTool(name, mcp.WithDescription("test tool")) +} + +func Test_WritePrivateOnlyGuard(t *testing.T) { + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectBlocked bool + expectedErrMsg string + expectPassThru bool + }{ + { + name: "passes through when repo is private", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + &github.Repository{Private: github.Ptr(true)}, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "myorg", + "repo": "private-repo", + }, + expectPassThru: true, + }, + { + name: "blocks when repo is public", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + &github.Repository{Private: github.Ptr(false)}, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "myorg", + "repo": "public-repo", + }, + expectBlocked: true, + expectedErrMsg: "Write blocked: myorg/public-repo is a public repository", + }, + { + name: "blocks (fail-closed) when visibility check returns 404", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusNotFound, `{"message": "Not Found"}`), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "myorg", + "repo": "any-repo", + }, + expectBlocked: true, + expectedErrMsg: "Write blocked: unable to verify repository visibility", + }, + { + name: "blocks (fail-closed) when visibility check returns 403", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusForbidden, `{"message": "Forbidden"}`), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "myorg", + "repo": "any-repo", + }, + expectBlocked: true, + expectedErrMsg: "Write blocked: unable to verify repository visibility", + }, + { + name: "returns error when owner param is missing", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "repo": "some-repo", + // owner intentionally missing + }, + expectBlocked: true, + expectedErrMsg: "missing required parameter: owner", + }, + { + name: "returns error when repo param is missing", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]interface{}{ + "owner": "myorg", + // repo intentionally missing + }, + expectBlocked: true, + expectedErrMsg: "missing required parameter: repo", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + getClient := stubGetClientFn(client) + + tool := dummyTool("test_write_tool") + _, guardedHandler := WritePrivateOnlyGuard(getClient, tool, dummyWriteHandler) + + request := createMCPRequest(tc.requestArgs) + result, err := guardedHandler(context.Background(), request) + + require.NoError(t, err) // guard never returns a Go error; errors are tool results + require.NotNil(t, result) + + textContent := getTextResult(t, result) + + if tc.expectPassThru { + assert.Equal(t, "write succeeded", textContent.Text) + assert.False(t, result.IsError, "expected pass-through, got error result") + } else { + assert.True(t, result.IsError, "expected blocked result") + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } + }) + } +} + +func Test_WritePrivateOnlyGuard_GetClientFailure(t *testing.T) { + // Test that the guard fails closed when getClient returns an error + getClient := func(_ context.Context, _ string) (*github.Client, error) { + return nil, fmt.Errorf("auth failure: no valid credentials") + } + + tool := dummyTool("test_write_tool") + _, guardedHandler := WritePrivateOnlyGuard(getClient, tool, dummyWriteHandler) + + request := createMCPRequest(map[string]interface{}{ + "owner": "myorg", + "repo": "some-repo", + }) + result, err := guardedHandler(context.Background(), request) + require.NoError(t, err) + assert.True(t, result.IsError, "expected blocked result when getClient fails") + + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, "unable to verify repository visibility") +} + +func Test_CreateRepositoryPrivateOnlyGuard(t *testing.T) { + tests := []struct { + name string + requestArgs map[string]interface{} + expectBlocked bool + expectedErrMsg string + }{ + { + name: "passes through when private=true", + requestArgs: map[string]interface{}{ + "name": "my-repo", + "private": true, + }, + expectBlocked: false, + }, + { + name: "blocks when private=false", + requestArgs: map[string]interface{}{ + "name": "my-repo", + "private": false, + }, + expectBlocked: true, + expectedErrMsg: "Write blocked: create_repository requires private=true", + }, + { + name: "blocks when private param is absent (defaults to false)", + requestArgs: map[string]interface{}{ + "name": "my-repo", + // private intentionally absent + }, + expectBlocked: true, + expectedErrMsg: "Write blocked: create_repository requires private=true", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tool := dummyTool("create_repository") + _, guardedHandler := CreateRepositoryPrivateOnlyGuard(tool, dummyWriteHandler) + + request := createMCPRequest(tc.requestArgs) + result, err := guardedHandler(context.Background(), request) + + require.NoError(t, err) + require.NotNil(t, result) + + textContent := getTextResult(t, result) + + if tc.expectBlocked { + assert.True(t, result.IsError) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + } else { + assert.False(t, result.IsError) + assert.Equal(t, "write succeeded", textContent.Text) + } + }) + } +} + +func Test_ForkRepositoryPrivateOnlyGuard(t *testing.T) { + t.Run("always blocks regardless of params", func(t *testing.T) { + tool := dummyTool("fork_repository") + _, guardedHandler := ForkRepositoryPrivateOnlyGuard(tool, dummyWriteHandler) + + // Try with various param combinations — should always block + for _, args := range []map[string]interface{}{ + {"owner": "myorg", "repo": "some-repo"}, + {}, + {"owner": "myorg", "repo": "private-source"}, + } { + request := createMCPRequest(args) + result, err := guardedHandler(context.Background(), request) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, "Write blocked: fork_repository cannot guarantee") + } + }) +} + +func Test_checkRepoVisibility(t *testing.T) { + tests := []struct { + name string + mockedClient *http.Client + owner string + repo string + expectPrivate bool + expectError bool + }{ + { + name: "returns true for private repo", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + &github.Repository{Private: github.Ptr(true)}, + ), + ), + owner: "myorg", + repo: "private-repo", + expectPrivate: true, + }, + { + name: "returns false for public repo", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposByOwnerByRepo, + &github.Repository{Private: github.Ptr(false)}, + ), + ), + owner: "myorg", + repo: "public-repo", + expectPrivate: false, + }, + { + name: "returns error on API failure", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposByOwnerByRepo, + mockResponse(t, http.StatusInternalServerError, `{"message": "Internal Server Error"}`), + ), + ), + owner: "myorg", + repo: "any-repo", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + getClient := stubGetClientFn(client) + + isPrivate, err := checkRepoVisibility(context.Background(), getClient, tc.owner, tc.repo) + + if tc.expectError { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectPrivate, isPrivate) + } + }) + } +}