diff --git a/cmd/root.go b/cmd/root.go index 19008e4..dfc116f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -38,6 +38,7 @@ var ( streamFlag bool // Enable streaming output compactMode bool // Enable compact output mode scriptMCPConfig *config.Config // Used to override config in script mode + approveToolRun bool // Session management saveSessionPath string @@ -199,7 +200,6 @@ func InitConfig() { viper.Set("hooks", hooksConfig) } } - } // LoadConfigWithEnvSubstitution loads a config file with environment variable substitution @@ -284,6 +284,8 @@ func init() { BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling") rootCmd.PersistentFlags(). BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution") + rootCmd.PersistentFlags(). + BoolVar(&approveToolRun, "approve-tool-run", false, "enable requiring user approval for every tool call") // Session management flags rootCmd.PersistentFlags(). @@ -329,6 +331,7 @@ func init() { viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers")) viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu")) viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify")) + viper.BindPFlag("approve-tool-run", rootCmd.PersistentFlags().Lookup("approve-tool-run")) // Defaults are already set in flag definitions, no need to duplicate in viper @@ -427,7 +430,8 @@ func runNormalMode(ctx context.Context) error { debugLogger = bufferedLogger } - mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ModelConfig: modelConfig, + mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ + ModelConfig: modelConfig, MCPConfig: mcpConfig, SystemPrompt: systemPrompt, MaxSteps: viper.GetInt("max-steps"), @@ -725,7 +729,8 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet flag can only be used with --prompt/-p") } - return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor) + approveToolRun := viper.GetBool("approve-tool-run") + return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor, approveToolRun) } // AgenticLoopConfig configures the behavior of the unified agentic loop @@ -734,6 +739,7 @@ type AgenticLoopConfig struct { IsInteractive bool // true for interactive mode, false for non-interactive InitialPrompt string // initial prompt for non-interactive mode ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit) + ApproveToolRun bool // only used in interactive mode // UI configuration Quiet bool // suppress all output except final response @@ -1083,7 +1089,27 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes currentSpinner.Start() } }, - streamingCallback, // Add streaming callback as the last parameter + // Add streaming callback handler + streamingCallback, + // Tool call approval handler - called before tool execution to get user approval + func(toolName, toolArgs string) (bool, error) { + if !config.IsInteractive || !config.ApproveToolRun { + return true, nil + } + if currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil + } + allow, err := cli.GetToolApproval(toolName, toolArgs) + if err != nil { + return false, err + } + // Start spinner again for tool calls + currentSpinner = ui.NewSpinner("Thinking...") + currentSpinner.Start() + + return allow, nil + }, ) // Make sure spinner is stopped if still running @@ -1286,6 +1312,7 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C IsInteractive: false, InitialPrompt: prompt, ContinueAfterRun: noExit, + ApproveToolRun: false, Quiet: quiet, ServerNames: serverNames, ToolNames: toolNames, @@ -1298,12 +1325,13 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C } // runInteractiveMode handles the interactive mode execution -func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error { +func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor, approveToolRun bool) error { // Configure and run unified agentic loop config := AgenticLoopConfig{ IsInteractive: true, InitialPrompt: "", ContinueAfterRun: false, + ApproveToolRun: approveToolRun, Quiet: false, ServerNames: serverNames, ToolNames: toolNames, diff --git a/internal/agent/agent.go b/internal/agent/agent.go index e0d7672..8bf6c55 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -4,6 +4,9 @@ import ( "context" "encoding/json" "fmt" + "strings" + "time" + tea "github.com/charmbracelet/bubbletea" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" @@ -12,8 +15,6 @@ import ( "github.com/mark3labs/mcphost/internal/config" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/tools" - "strings" - "time" ) // AgentConfig is the config for agent. @@ -44,6 +45,9 @@ type StreamingResponseHandler func(content string) // ToolCallContentHandler is a function type for handling content that accompanies tool calls type ToolCallContentHandler func(content string) +// ToolApprovalHandler is a function type for handling user approval of tool calls +type ToolApprovalHandler func(toolName, toolArgs string) (bool, error) + // Agent is the agent with real-time tool call display. type Agent struct { toolManager *tools.MCPToolManager @@ -106,15 +110,15 @@ type GenerateWithLoopResult struct { // GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message, - onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) { - - return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil) + onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onToolApproval ToolApprovalHandler, +) (*GenerateWithLoopResult, error) { + return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil, onToolApproval) } // GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message, - onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) { - + onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler, onToolApproval ToolApprovalHandler, +) (*GenerateWithLoopResult, error) { // Create a copy of messages to avoid modifying the original workingMessages := make([]*schema.Message, len(messages)) copy(workingMessages, messages) @@ -176,6 +180,19 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc // Handle tool calls for _, toolCall := range response.ToolCalls { + if onToolApproval != nil { + approved, err := onToolApproval(toolCall.Function.Name, toolCall.Function.Arguments) + if err != nil { + return nil, err + } + if !approved { + rejectedMsg := fmt.Sprintf("The user did not allow tool call %s. Reason: User cancelled.", toolCall.Function.Name) + toolMessage := schema.ToolMessage(rejectedMsg, toolCall.ID) + workingMessages = append(workingMessages, toolMessage) + continue + } + } + // Notify about tool call if onToolCall != nil { onToolCall(toolCall.Function.Name, toolCall.Function.Arguments) diff --git a/internal/config/config.go b/internal/config/config.go index b5f5062..4eb2241 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -154,6 +154,7 @@ type Config struct { Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"` Theme any `json:"theme" yaml:"theme"` MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"` + ApproveToolRun bool `json:"approve-tool-run" yaml:"approve-tool-run"` // Model generation parameters MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"` diff --git a/internal/ui/cli.go b/internal/ui/cli.go index 4eacd6b..e6309cd 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -13,9 +13,7 @@ import ( "golang.org/x/term" ) -var ( - promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) -) +var promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // CLI handles the command line interface with improved message rendering type CLI struct { @@ -83,7 +81,6 @@ func (c *CLI) GetPrompt() (string, error) { // Run as a tea program p := tea.NewProgram(input) finalModel, err := p.Run() - if err != nil { return "", err } @@ -151,7 +148,6 @@ func (c *CLI) DisplayAssistantMessageWithModel(message, modelName string) error // DisplayToolCallMessage displays a tool call in progress func (c *CLI) DisplayToolCallMessage(toolName, toolArgs string) { - c.messageContainer.messages = nil // clear previous messages (they should have been printed already) c.lastStreamHeight = 0 // Reset last stream height for new prompt @@ -331,6 +327,20 @@ func (c *CLI) IsSlashCommand(input string) bool { return strings.HasPrefix(input, "/") } +func (c *CLI) GetToolApproval(toolName, toolArgs string) (bool, error) { + input := NewToolApprovalInput(toolName, toolArgs, c.width) + p := tea.NewProgram(input) + finalModel, err := p.Run() + if err != nil { + return false, err + } + + if finalInput, ok := finalModel.(*ToolApprovalInput); ok { + return finalInput.approved, nil + } + return false, fmt.Errorf("GetToolApproval: unexpected error type") +} + // SlashCommandResult represents the result of handling a slash command type SlashCommandResult struct { Handled bool @@ -377,7 +387,6 @@ func (c *CLI) ClearMessages() { // displayContainer renders and displays the message container func (c *CLI) displayContainer() { - // Add left padding to the entire container content := c.messageContainer.Render() diff --git a/internal/ui/tool_approval_input.go b/internal/ui/tool_approval_input.go new file mode 100644 index 0000000..01970cc --- /dev/null +++ b/internal/ui/tool_approval_input.go @@ -0,0 +1,135 @@ +package ui + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textarea" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type ToolApprovalInput struct { + textarea textarea.Model + toolName string + toolArgs string + width int + selected bool // true when "yes" is highlighted and false when "no" is + approved bool + done bool +} + +func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput { + ta := textarea.New() + ta.Placeholder = "" + ta.ShowLineNumbers = false + ta.CharLimit = 1000 + ta.SetWidth(width - 8) // Account for container padding, border and internal padding + ta.SetHeight(4) // Default to 3 lines like huh + ta.Focus() + + // Style the textarea to match huh theme + ta.FocusedStyle.Base = lipgloss.NewStyle() + ta.FocusedStyle.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + ta.FocusedStyle.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252")) + ta.FocusedStyle.Prompt = lipgloss.NewStyle() + ta.FocusedStyle.CursorLine = lipgloss.NewStyle() + ta.Cursor.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("39")) + + return &ToolApprovalInput{ + textarea: ta, + toolName: toolName, + toolArgs: toolArgs, + width: width, + selected: true, + } +} + +func (t *ToolApprovalInput) Init() tea.Cmd { + return textarea.Blink +} + +func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "y", "Y": + t.approved = true + t.done = true + return t, tea.Quit + case "n", "N": + t.approved = false + t.done = true + return t, tea.Quit + case "left": + t.selected = true + return t, nil + case "right": + t.selected = false + return t, nil + case "enter": + t.approved = t.selected + t.done = true + return t, tea.Quit + case "esc", "ctrl+c": + t.approved = false + t.done = true + return t, tea.Quit + } + } + return t, nil +} + +func (t *ToolApprovalInput) View() string { + if t.done { + return "we are done" + } + // Add left padding to entire component (2 spaces like other UI elements) + containerStyle := lipgloss.NewStyle().PaddingLeft(2) + + // Title + titleStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("252")). + MarginBottom(1) + + // Input box with huh-like styling + inputBoxStyle := lipgloss.NewStyle(). + Border(lipgloss.ThickBorder()). + BorderLeft(true). + BorderRight(false). + BorderTop(false). + BorderBottom(false). + BorderForeground(lipgloss.Color("39")). + PaddingLeft(1). + Width(t.width - 2) // Account for container padding + + // Style for the currently selected/highlighted option + selectedStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("42")). // Bright green + Bold(true). + Underline(true) + + // Style for the unselected/unhighlighted option + unselectedStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("240")) // Dark gray + + // Build the view + var view strings.Builder + view.WriteString(titleStyle.Render("Allow tool execution")) + view.WriteString("\n") + details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs) + view.WriteString(details) + view.WriteString("Allow tool execution: ") + + var yesText, noText string + if t.selected { + yesText = selectedStyle.Render("[y]es") + noText = unselectedStyle.Render("[n]o") + } else { + yesText = unselectedStyle.Render("[y]es") + noText = selectedStyle.Render("[n]o") + } + view.WriteString(yesText + "/" + noText + "\n") + + return containerStyle.Render(inputBoxStyle.Render(view.String())) +} diff --git a/sdk/mcphost.go b/sdk/mcphost.go index a984555..4344537 100644 --- a/sdk/mcphost.go +++ b/sdk/mcphost.go @@ -134,6 +134,7 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) { nil, // onToolResult nil, // onResponse nil, // onToolCallContent + nil, // onToolApproval ) if err != nil { return "", err @@ -171,6 +172,7 @@ func (m *MCPHost) PromptWithCallbacks( nil, // onResponse nil, // onToolCallContent onStreaming, + nil, // onToolApproval ) if err != nil { return "", err