diff --git a/.infer/config.yaml b/.infer/config.yaml index 76ba3611..719c9dd6 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -40,25 +40,22 @@ tools: commands: - ls - pwd - - echo + - tree - wc - sort - uniq + - head + - tail - task + - make + - find patterns: - - ^git branch( --show-current)?$ - - ^git checkout -b [a-zA-Z0-9/_-]+( [a-zA-Z0-9/_-]+)?$ - - ^git checkout [a-zA-Z0-9/_-]+ - - ^git add [a-zA-Z0-9/_.-]+ - - ^git diff+ - - ^git remote -v$ - ^git status$ - - ^git log --oneline -n [0-9]+$ - - ^git commit - - ^git push( --set-upstream)?( origin)? (feature|fix|bugfix|hotfix|chore|docs|test|refactor|build|ci|perf|style)/[a-zA-Z0-9/_.-]+$ - - ^git push( --set-upstream)?( origin)? develop$ - - ^git push( --set-upstream)?( origin)? staging$ - - ^git push( --set-upstream)?( origin)? release/[a-zA-Z0-9._-]+$ + - ^git branch( --show-current)?( -[alrvd])?$ + - ^git log + - ^git diff + - ^git remote( -v)?$ + - ^git show read: enabled: true require_approval: false @@ -240,7 +237,6 @@ chat: theme: tokyo-night a2a: enabled: true - agents: [] cache: enabled: true ttl: 300 diff --git a/AGENTS.md b/AGENTS.md index e7e1cd27..914ea87d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,183 +1,99 @@ + # AGENTS.md ## Project Overview -The Inference Gateway CLI is a powerful command-line interface for managing and interacting with the Inference Gateway. -It provides tools for configuration, monitoring, and management of inference services. The project is built in Go -and features an interactive chat interface, autonomous agent capabilities, and extensive tool integration for -AI-assisted development. +The Inference Gateway CLI is a command-line interface for interacting with the Inference Gateway. It is built in Go +and provides tools for managing and configuring inference services, monitoring agent status, and facilitating +development workflows. Key features include an interactive chat interface, extensible shortcut system, and support +for AI agents. ## Architecture & Structure -**Key Directories:** - -- `cmd/`: CLI command implementations -- `internal/`: Core application logic - - `app/`: Application layer - - `domain/`: Domain models and interfaces - - `handlers/`: Command and event handlers - - `services/`: Business logic services - - `shortcuts/`: Extensible shortcut system - - `ui/`: Terminal UI components -- `config/`: Configuration management -- `docs/`: Documentation -- `examples/`: Usage examples - -**Architectural Patterns:** - -- Clean Architecture with domain-driven design -- Command pattern for CLI operations -- Repository pattern for data access -- Service layer for business logic -- Dependency injection via container +- **`cmd/`**: Contains the implementation of CLI commands, such as `agent`, `chat`, `config`, and `status`. +- **`internal/`**: Houses the core application logic, including: + - **`app/`**: Application layer components. + - **`domain/`**: Core domain models, interfaces, and business logic. + - **`handlers/`**: Command and event handlers for processing user input and system events. + - **`services/`**: Business logic services that orchestrate operations. + - **`shortcuts/`**: Implements the extensible shortcut system for custom commands. + - **`ui/`**: Contains components for the terminal user interface. +- **`config/`**: Manages configuration loading and manipulation for agents and the gateway. +- **`docs/`**: Project documentation, including guides on agent configuration, A2A connections, and conversation management. +- **`examples/`**: Provides example usage of the CLI, including A2A agent setups and basic configurations. + +**Architectural Patterns**: + +- **Clean Architecture**: Emphasizes separation of concerns with layers like domain, application, and infrastructure. +- **Domain-Driven Design (DDD)**: Core business logic is modeled around the domain. +- **Command Pattern**: Used for structuring CLI commands. +- **Repository Pattern**: Abstracting data access for domain entities. +- **Dependency Injection**: Utilized for managing dependencies, likely via a container. ## Development Environment -**Setup Instructions:** - -- Go 1.25.0+ required -- Install dependencies: `task setup` -- Install pre-commit hooks: `task precommit:install` - -**Required Tools:** - -- Go 1.25.0+ -- golangci-lint -- Task (taskfile.dev) -- pre-commit -- Docker (for container builds) - -**Environment Variables:** - -- `GITHUB_TOKEN`: For GitHub API access -- `GOOGLE_SEARCH_API_KEY`: Optional Google search API -- `GOOGLE_SEARCH_ENGINE_ID`: Optional Google search engine ID -- `DUCKDUCKGO_SEARCH_API_KEY`: Optional DuckDuckGo API +- **Setup Instructions**: + - **Go Version**: Requires Go 1.25.0 or later. + - **Dependencies**: Install project dependencies using `task setup`. + - **Pre-commit Hooks**: Install pre-commit hooks with `task precommit:install`. +- **Required Tools**: + - Go (1.25.0+) + - `golangci-lint` (for linting) + - `Task` (taskfile.dev) + - `pre-commit` + - Docker (for building container images) ## Development Workflow -**Build Commands:** - -- `task build`: Build binary with version info -- `task install`: Install to GOPATH/bin -- `task release:build`: Build multi-platform release binaries - -**Testing Procedures:** - -- `task test`: Run all tests -- `task test:verbose`: Run tests with verbose output -- `task test:coverage`: Run tests with coverage -- `task vet`: Run go vet - -**Code Quality Tools:** - -- `task fmt`: Format Go code -- `task lint`: Run golangci-lint and markdownlint -- `task check`: Run all quality checks (fmt, vet, test) - -**Git Workflow:** - -- Main branch development -- Pre-commit hooks for code quality -- Conventional commits recommended -- GitHub Actions for CI/CD +- **Build Process**: Typically managed via `Task` tasks defined in `Taskfile.yml` (or similar). +- **Testing**: Unit tests are present (e.g., `*_test.go` files). Running tests is likely done via `task test` or + `go test ./...`. +- **Code Quality**: `golangci-lint` is used for static analysis and linting. Pre-commit hooks ensure code quality + before commits. +- **Git Workflow**: Standard Git workflows are expected. Pre-commit hooks likely enforce commit message formats + and code style. ## Key Commands -**Build:** `task build` -**Test:** `task test` -**Lint:** `task lint` -**Run:** `task run -- ` -**Format:** `task fmt` -**Clean:** `task clean` +- **Setup**: `task setup` +- **Install Pre-commit**: `task precommit:install` +- **Run Tests**: `task test` or `go test ./...` +- **Lint Code**: `task lint` or `golangci-lint run` +- **Build**: `task build` (specific target may vary) +- **Run CLI**: `go run ./cmd/cli` or executable from `bin/` after building. ## Testing Instructions -**How to Run Tests:** - -- `go test ./...`: Run all tests -- `go test -v ./...`: Verbose output -- `go test -cover ./...`: With coverage - -**Test Organization:** - -- Tests co-located with source files (`*_test.go`) -- Mock generation using counterfeiter -- Integration tests in separate packages - -**Coverage Requirements:** - -- No specific coverage threshold enforced -- Tests required for all new features -- Integration tests for critical paths +- **Running Tests**: Use `task test` or `go test ./...` to execute all unit tests. +- **Organization**: Tests are co-located with their respective source files (`*_test.go`). +- **Coverage**: Specific coverage requirements are not detailed here but are typically checked during CI. Use + `go test -cover ./...` for local coverage checks. ## Deployment & Release -**Deployment Processes:** - -- Multi-platform binary builds -- Docker container images -- GitHub Releases with signed artifacts - -**Release Procedures:** - -- Automated via GitHub Actions -- Version tagging with semantic versioning -- Cosign signatures for security - -**CI/CD Pipeline:** - -- GitHub Actions workflows for CI -- Automated testing and linting -- Multi-platform build validation +- **Builds**: Dockerfiles are present (e.g., in `examples/a2a/demo-site/`), indicating containerized deployments. +- **CI/CD**: Likely configured via GitHub Actions or similar, triggered by commits to the main branch and tag + creation for releases. +- **Release Procedures**: Involves building binaries, creating Docker images, and tagging releases. Specific commands + are likely defined in the `Taskfile.yml`. ## Project Conventions -**Coding Standards:** - -- Go standard formatting (`gofmt`) -- golangci-lint configuration -- Maximum cyclomatic complexity: 25 -- Maximum function length: 150 lines - -**Naming Conventions:** - -- Go idiomatic naming (camelCase for variables, PascalCase for exports) -- Clear, descriptive names -- Interface names end with "er" (e.g., `ToolService`) - -**File Organization:** - -- Domain-driven structure -- One type per file (with exceptions for small related types) -- Test files co-located with source - -**Commit Message Formats:** - -- Conventional commits preferred -- Descriptive commit messages -- Reference issues when applicable +- **Coding Standards**: Follows Go best practices. `golangci-lint` enforces specific rules. +- **Naming Conventions**: Standard Go naming conventions (CamelCase for exported, snake_case for unexported where appropriate). +- **File Organization**: Structured into `cmd`, `internal`, `config`, `docs`, and `examples` directories, with + subdirectories reflecting logical components. +- **Commit Messages**: Likely follow a convention enforced by pre-commit hooks (e.g., Conventional Commits). ## Important Files & Configurations -**Key Configuration Files:** - -- `go.mod`: Go module dependencies -- `Taskfile.yml`: Build and development tasks -- `.golangci.yml`: Linter configuration -- `.pre-commit-config.yaml`: Pre-commit hooks -- `.github/workflows/ci.yml`: CI pipeline - -**Critical Source Files:** - -- `cmd/root.go`: Main CLI entry point -- `internal/domain/interfaces.go`: Core domain interfaces -- `internal/services/agent.go`: Autonomous agent logic -- `internal/handlers/chat_handler.go`: Chat interface handling - -**Security Considerations:** - -- Path exclusions: `.infer/`, `.git/`, `*.env` -- Tool execution requires approval by default -- Sandboxed directory access -- Command whitelisting for Bash tool +- **`Taskfile.yml`**: Defines automation tasks for building, testing, linting, and setup. +- **`go.mod` / `go.sum`**: Go module dependency management files. +- **`.pre-commit-config.yaml`**: Configuration for pre-commit hooks, specifying linters and formatters. +- **`golangci.yml` (or similar)**: Configuration for `golangci-lint`. +- **`*.go` files**: Source code files. Key entry points are in `cmd/`. +- **`docs/`**: Markdown files containing project documentation. +- **`config/`**: Go files related to application configuration. +- **`internal/domain`**: Core business logic and interfaces. +- **`internal/shortcuts`**: Implementation of the shortcut system. +- **`Dockerfile` (in examples/)**: Example Dockerfiles for building container images. diff --git a/CLAUDE.md b/CLAUDE.md index c8082733..6d5c7d87 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -110,6 +110,7 @@ The chat interface uses an **event-driven pattern** with: - `ChatEvent`: Chat operations (ChatStart, ChatChunk, ChatComplete, ChatError, etc.) - `UIEvent`: UI updates (UpdateHistory, SetStatus, ShowError, etc.) - Tool execution events (ToolCallPreview, ToolExecutionStarted, ToolExecutionCompleted) + - Tool approval events (ToolApprovalRequestedEvent, ShowToolApprovalEvent, ToolApprovalResponseEvent) - A2A events (A2ATaskSubmitted, A2ATaskStatusUpdate, A2ATaskCompleted) ### Tool System Architecture @@ -132,6 +133,39 @@ Tools are implemented using the **Factory Pattern** and **Strategy Pattern**: - `github.go`: GitHub API integration - `a2a_*.go`: Agent-to-agent communication tools +### Tool Approval System + +The CLI implements a **user approval workflow** for sensitive tool operations to ensure safety: + +- **Configuration-Driven**: Each tool has a `require_approval` flag in its config + - Dangerous tools (Write, Edit, Delete, MultiEdit) require approval by default + - Safe tools (Read, Grep) do not require approval by default + - Can be overridden per-tool in `config.yaml` + +- **Approval Components**: + - `ApprovalComponent` (`internal/ui/components/approval_component.go`): Renders approval modal + - `DiffRenderer` (`internal/ui/components/diff_renderer.go`): Visualizes code changes for Edit/Write tools + - `ToolFormatterService`: Formats tool arguments for human-readable display + +- **Approval Events** (defined in `internal/domain/`): + - `ToolApprovalRequestedEvent`: Triggered when a tool requiring approval is called + - `ShowToolApprovalEvent`: UI event to display approval modal + - `ToolApprovalResponseEvent`: Captures user's approval/rejection decision + - `ToolApprovedEvent`/`ToolRejectedEvent`: Final approval outcome + +- **Approval Flow**: + 1. LLM requests a tool execution that requires approval + 2. System emits `ToolApprovalRequestedEvent` + 3. UI displays modal with tool details and diff visualization (for code changes) + 4. User approves (Enter/y) or rejects (Esc/n) the operation + 5. System proceeds with execution or cancels based on user decision + +- **UI Controls**: + - Navigation: ←/→ to select Approve/Reject + - Approve: Enter or 'y' key + - Reject: Esc or 'n' key + - Real-time diff preview for file modification tools + ### State Management The `StateManager` (`internal/services/state_manager.go`) centralizes application state using concurrent-safe patterns: diff --git a/CONFIG.md b/CONFIG.md index 0bf48704..471119ac 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -110,7 +110,7 @@ agent: verbose_tools: false max_turns: 50 max_tokens: 4096 - max_concurrent_tools: 5 + max_concurrent_tools: 5 optimization: enabled: false max_history: 10 @@ -150,12 +150,12 @@ tools: a2a: enabled: false # Enable/disable A2A functionality agents: [] # List of A2A agent endpoints - + # Agent card caching settings cache: enabled: true ttl: 300 # Cache TTL in seconds - + # Task monitoring configuration task: status_poll_seconds: 5 @@ -164,21 +164,21 @@ a2a: max_poll_interval_sec: 60 backoff_multiplier: 2.0 background_monitoring: true - + # Individual A2A tool settings tools: query_agent: enabled: true require_approval: false - + query_task: enabled: true require_approval: false - + submit_task: enabled: true require_approval: false - + download_artifacts: enabled: true download_dir: "/tmp/downloads" @@ -213,6 +213,104 @@ storage: db: 0 ``` +## Tool Approval System + +The CLI implements a **user approval workflow** for sensitive tool operations to enhance security and user control. + +### How It Works + +When the LLM attempts to execute a tool that requires approval: + +1. **Approval Request**: The system pauses execution and displays an approval modal +2. **Visual Preview**: For file modification tools (Write, Edit, MultiEdit), a diff + visualization shows exactly what will change +3. **User Decision**: You can approve (✓) or reject (✗) the operation +4. **Execution**: Only approved operations proceed; rejected operations are canceled + +### Configuration + +Each tool has a `require_approval` flag that can be configured: + +```yaml +tools: + # Dangerous operations require approval by default + write: + enabled: true + require_approval: true # User must approve before writing files + + edit: + enabled: true + require_approval: true # User must approve before editing files + + delete: + enabled: true + require_approval: true # User must approve before deleting files + + # Safe operations don't require approval by default + read: + enabled: true + require_approval: false # No approval needed to read files + + grep: + enabled: true + require_approval: false # No approval needed for searches + + bash: + enabled: true + require_approval: true # Approval required for command execution +``` + +### Default Approval Requirements + +**Tools requiring approval by default:** + +- `write` - Writing new files or overwriting existing ones +- `edit` - Modifying file contents +- `multiedit` - Making multiple file edits +- `delete` - Deleting files or directories +- `bash` - Executing shell commands + +**Tools NOT requiring approval by default:** + +- `read` - Reading file contents +- `grep` - Searching code +- `websearch` - Web searches +- `webfetch` - Fetching web content +- `github` - GitHub API operations +- `tree` - Displaying directory structure +- `todowrite` - Managing task lists + +### UI Controls + +When an approval modal is displayed: + +- **Navigate**: Use ←/→ arrow keys to select Approve or Reject +- **Approve**: Press Enter or 'y' to approve the operation +- **Reject**: Press Esc or 'n' to reject the operation +- **View Details**: The modal shows: + - Tool name being executed + - Tool arguments (formatted for readability) + - Diff preview (for file modification tools) + +### Environment Variable Override + +You can override approval requirements using environment variables: + +```bash +# Disable approval for Write tool +export INFER_TOOLS_WRITE_REQUIRE_APPROVAL=false + +# Enable approval for Read tool +export INFER_TOOLS_READ_REQUIRE_APPROVAL=true +``` + +### Security Best Practices + +1. **Keep approval enabled** for destructive tools (Write, Edit, Delete) in production +2. **Review diffs carefully** before approving file modifications +3. **Use project configs** to enforce approval requirements across team +4. **Disable approval only** in trusted, sandboxed environments + ## Command Usage ### Configuration Management diff --git a/README.md b/README.md index c7998557..63c49fcd 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ and management of inference services. - **MultiEdit**: Make multiple edits to files in atomic operations - **TodoWrite**: Create and manage structured task lists - **A2A Tools**: Agent-to-agent communication for task delegation and coordination +- **Tool Approval System**: User approval workflow for sensitive operations with real-time diff visualization for file modifications ## Installation diff --git a/cmd/agent.go b/cmd/agent.go index bab210e6..f6700ede 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -70,6 +70,19 @@ func RunAgentCommand(cfg *config.Config, modelFlag, taskDescription string) erro _ = services.Shutdown(ctx) }() + gatewayManager := services.GetGatewayManager() + if gatewayManager != nil && !gatewayManager.IsRunning() { + return fmt.Errorf(`inference gateway is not running. Please ensure the gateway is started. + +Possible solutions: +1. Ensure Docker is running: docker ps +2. Manually start the gateway container: docker run -d --name inference-gateway -p 8080:8080 ghcr.io/inference-gateway/inference-gateway:latest +3. Or disable auto-run in config and start the gateway manually +4. Check gateway logs: docker logs inference-gateway + +For more information, visit: https://github.com/inference-gateway/inference-gateway`) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.Gateway.Timeout)*time.Second) defer cancel() diff --git a/cmd/init.go b/cmd/init.go index 7b3d2f10..2f22cbcf 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -6,8 +6,10 @@ import ( "encoding/json" "fmt" "os" + "os/signal" "path/filepath" "sync" + "syscall" "time" uuid "github.com/google/uuid" @@ -30,6 +32,8 @@ This creates the .infer directory with configuration file and additional setup f Use --model / to enable AI project analysis and generate a comprehensive AGENTS.md file tailored to your specific project. +For larger projects or complex analysis, use --timeout to increase the analysis time limit. + This is the recommended command to start working with Inference Gateway CLI in a new project.`, RunE: func(cmd *cobra.Command, args []string) error { return initializeProject(cmd) @@ -40,6 +44,7 @@ func init() { initCmd.Flags().Bool("overwrite", false, "Overwrite existing files if they already exist") initCmd.Flags().Bool("userspace", false, "Initialize configuration in user home directory (~/.infer/)") initCmd.Flags().String("model", "", "LLM model to use for AI project analysis and AGENTS.md generation (recommended)") + initCmd.Flags().Int("timeout", 60, "Timeout in seconds for project analysis (default: 60)") rootCmd.AddCommand(initCmd) } @@ -47,6 +52,7 @@ func initializeProject(cmd *cobra.Command) error { overwrite, _ := cmd.Flags().GetBool("overwrite") userspace, _ := cmd.Flags().GetBool("userspace") model, _ := cmd.Flags().GetString("model") + timeout, _ := cmd.Flags().GetInt("timeout") var configPath, gitignorePath, agentsMDPath string @@ -87,7 +93,7 @@ bin/ } if model != "" { - if err := generateAgentsMD(agentsMDPath, userspace, model); err != nil { + if err := generateAgentsMD(agentsMDPath, userspace, model, timeout, overwrite); err != nil { return fmt.Errorf("failed to create AGENTS.md file: %w", err) } } @@ -146,13 +152,13 @@ func writeConfigAsYAMLWithIndent(filename string, indent int) error { } // generateAgentsMD creates an AGENTS.md file based on project analysis -func generateAgentsMD(agentsMDPath string, userspace bool, model string) error { +func generateAgentsMD(agentsMDPath string, userspace bool, model string, timeout int, overwrite bool) error { wd, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get working directory: %w", err) } - err = analyzeProjectForAgents(wd, userspace, model, agentsMDPath) + err = analyzeProjectForAgents(wd, userspace, model, agentsMDPath, timeout, overwrite) if err != nil { return fmt.Errorf("failed to analyze project: %w", err) } @@ -220,13 +226,13 @@ Essential commands developers use regularly: Key files that agents should be aware of and their purposes. RESEARCH APPROACH: -1. Start by reading package.json, go.mod, Cargo.toml, or similar dependency files -2. Look for README files, documentation, and setup guides -3. Examine build scripts, Makefiles, or task runners (package.json scripts, Taskfile.yml, etc.) +1. **MANDATORY FIRST STEP**: You MUST run the Tree tool as your very first action to understand the project structure. This is required before any other tool calls. +2. After running Tree, look for README files, documentation, and setup guides +3. Examine build scripts - Makefile, or Taskfile.yml 4. Check for configuration files (.gitignore, .env examples, config files) 5. Identify testing frameworks and CI/CD configurations 6. Look for code quality tools configurations -7. Examine directory structure and common patterns +7. Use the information from Tree to guide your exploration strategy IMPORTANT GUIDELINES: - Be concise but comprehensive - agents need actionable information @@ -238,22 +244,25 @@ IMPORTANT GUIDELINES: TOOL USAGE: - Use available tools to explore the project (Tree, Read, Grep, etc.) -- PARALLEL EXECUTION: You can call multiple tools simultaneously in a single response to improve efficiency. The system supports up to 3 concurrent tool executions using a semaphore-based approach. Use this to reduce back-and-forth communication by batching related operations. +- READ IN CHUNKS: Always read files in chunks of 50 lines using the Read tool's limit and offset parameters (e.g., Read(file_path="README.md", limit=50, offset=0) for first chunk, then Read(file_path="README.md", limit=50, offset=50) for next chunk). This significantly reduces token usage. +- PARALLEL EXECUTION: You can call multiple tools simultaneously in a single response to improve efficiency. The system supports up to 10 concurrent tool executions using a semaphore-based approach. Use this to reduce back-and-forth communication by batching related operations. - When you have gathered enough information, use the Write tool to create the AGENTS.md file - Write the file content directly without code fences or API calls - The Write tool expects: Write(file_path="/path/to/file", content="file content here") EFFICIENCY TIPS: -- Batch related file reads (e.g., read package.json, Taskfile.yml, and README.md in parallel) +- ALWAYS read files in 50-line chunks rather than reading entire files at once to minimize token usage +- Batch related file reads (e.g., read first 50 lines of build configuration, task files, and documentation in parallel) - Execute multiple grep searches simultaneously for different patterns - Combine directory exploration with file reading in the same response - Use parallel execution to gather comprehensive project information quickly +- Only read additional chunks if the first 50 lines don't provide enough context Your analysis should help other agents quickly understand how to work with this project effectively.` } // analyzeProjectForAgents analyzes the current project and generates AGENTS.md content -func analyzeProjectForAgents(projectDir string, userspace bool, model string, agentsMDPath string) error { +func analyzeProjectForAgents(projectDir string, userspace bool, model string, agentsMDPath string, timeoutSeconds int, overwrite bool) error { if userspace || model == "" { return fmt.Errorf("model is required for AGENTS.md generation") } @@ -276,6 +285,51 @@ func analyzeProjectForAgents(projectDir string, userspace bool, model string, ag services := container.NewServiceContainer(&cfgCopy, V) + gatewayManager := services.GetGatewayManager() + if gatewayManager != nil && !gatewayManager.IsRunning() { + return fmt.Errorf(`inference gateway is not running. Please ensure the gateway is started before running init with --model. + +Possible solutions: +1. Ensure Docker is running: docker ps +2. Manually start the gateway container: docker run -d --name inference-gateway -p 8080:8080 ghcr.io/inference-gateway/inference-gateway:latest +3. Or disable auto-run in config and start the gateway manually +4. Check gateway logs: docker logs inference-gateway + +For more information, visit: https://github.com/inference-gateway/inference-gateway`) + } + + cleanupDone := make(chan struct{}) + cleanup := func() { + select { + case <-cleanupDone: + return + default: + close(cleanupDone) + fmt.Println("Shutting down containers...") + logger.Info("Shutting down containers...") + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if err := services.Shutdown(ctx); err != nil { + fmt.Fprintf(os.Stderr, "Failed to shutdown containers: %v\n", err) + logger.Error("Failed to shutdown containers", "error", err) + } else { + fmt.Println("Containers stopped successfully") + logger.Info("Containers stopped successfully") + } + } + } + defer cleanup() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigChan + fmt.Println("\nReceived interrupt signal, cleaning up...") + cleanup() + fmt.Println("Cleanup complete, exiting...") + os.Exit(130) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfgCopy.Gateway.Timeout)*time.Second) defer cancel() @@ -306,15 +360,22 @@ func analyzeProjectForAgents(projectDir string, userspace bool, model string, ag conversation: []InitConversationMessage{}, maxTurns: cfgCopy.Agent.MaxTurns, startTime: time.Now(), - timeoutSeconds: 30, + timeoutSeconds: timeoutSeconds, agentsMDPath: agentsMDPath, - toolSemaphore: make(chan struct{}, 3), + toolSemaphore: make(chan struct{}, 10), + overwrite: overwrite, } _, _ = fmt.Fprintf(os.Stdout, "%s\n", `{"content":"Initializing project analysis session...","timestamp":"`+time.Now().Format("15:04:05")+`","elapsed":"0.0s","tokens":{"input":0,"output":0,"total":0}}`) _ = os.Stdout.Sync() - err = session.analyze(fmt.Sprintf("Please analyze the project in directory '%s' and generate a comprehensive AGENTS.md file. Use your available tools to examine the project structure, configuration files, documentation, build systems, and development workflow. Focus on creating actionable documentation that will help other AI agents understand how to work effectively with this project. Write the AGENTS.md file to: %s", projectDir, agentsMDPath)) + initialPrompt := fmt.Sprintf("Please analyze the project in directory '%s' and generate a comprehensive AGENTS.md file. Use your available tools to examine the project structure, configuration files, documentation, build systems, and development workflow. Focus on creating actionable documentation that will help other AI agents understand how to work effectively with this project. Write the AGENTS.md file to: %s", projectDir, agentsMDPath) + + if overwrite { + initialPrompt += fmt.Sprintf("\n\nIMPORTANT: The user has explicitly requested to overwrite the existing AGENTS.md file using the --overwrite flag. First, use the Read tool with limit=50 to read only the first 50 lines of the existing file at '%s' to get a quick overview of what's already documented. Then, use the Write tool to replace it with your updated and improved analysis. You MUST write the file even though it already exists - this is intentional.", agentsMDPath) + } + + err = session.analyze(initialPrompt) if err != nil { _, _ = fmt.Fprintf(os.Stdout, "%s\n", `{"role":"error","content":"Analysis failed","timestamp":"`+time.Now().Format("15:04:05")+`","elapsed":"0.0s","tokens":{"input":0,"output":0,"total":0}}`) _ = os.Stdout.Sync() @@ -376,6 +437,7 @@ type ProjectAnalysisSession struct { agentsMDPath string toolSemaphore chan struct{} conversationMutex sync.Mutex + overwrite bool } func (s *ProjectAnalysisSession) analyze(taskDescription string) error { diff --git a/cmd/init_test.go b/cmd/init_test.go index b72997d4..85bcf48a 100644 --- a/cmd/init_test.go +++ b/cmd/init_test.go @@ -159,7 +159,7 @@ func TestGenerateAgentsMD(t *testing.T) { agentsMDPath := filepath.Join(tmpDir, "AGENTS.md") - err = generateAgentsMD(agentsMDPath, tt.userspace, tt.model) + err = generateAgentsMD(agentsMDPath, tt.userspace, tt.model, 60, false) if (err != nil) != tt.wantErr { t.Errorf("generateAgentsMD() error = %v, wantErr %v", err, tt.wantErr) diff --git a/config/config.go b/config/config.go index 183893a1..ed1a5af1 100644 --- a/config/config.go +++ b/config/config.go @@ -254,7 +254,7 @@ type GitConfig struct { // A2AConfig contains A2A agent configuration type A2AConfig struct { Enabled bool `yaml:"enabled" mapstructure:"enabled"` - Agents []string `yaml:"agents" mapstructure:"agents"` + Agents []string `yaml:"agents,omitempty" mapstructure:"agents"` Cache A2ACacheConfig `yaml:"cache" mapstructure:"cache"` Task A2ATaskConfig `yaml:"task" mapstructure:"task"` Tools A2AToolsConfig `yaml:"tools" mapstructure:"tools"` @@ -417,24 +417,17 @@ func DefaultConfig() *Config { //nolint:funlen Enabled: true, Whitelist: ToolWhitelistConfig{ Commands: []string{ - "ls", "pwd", "echo", - "wc", "sort", "uniq", - "task", + "ls", "pwd", "tree", + "wc", "sort", "uniq", "head", "tail", + "task", "make", "find", }, Patterns: []string{ - "^git branch( --show-current)?$", - "^git checkout -b [a-zA-Z0-9/_-]+( [a-zA-Z0-9/_-]+)?$", - "^git checkout [a-zA-Z0-9/_-]+", - "^git add [a-zA-Z0-9/_.-]+", - "^git diff+", - "^git remote -v$", "^git status$", - "^git log --oneline -n [0-9]+$", - "^git commit", - "^git push( --set-upstream)?( origin)? (feature|fix|bugfix|hotfix|chore|docs|test|refactor|build|ci|perf|style)/[a-zA-Z0-9/_.-]+$", - "^git push( --set-upstream)?( origin)? develop$", - "^git push( --set-upstream)?( origin)? staging$", - "^git push( --set-upstream)?( origin)? release/[a-zA-Z0-9._-]+$", + "^git branch( --show-current)?( -[alrvd])?$", + "^git log", + "^git diff", + "^git remote( -v)?$", + "^git show", }, }, }, @@ -796,6 +789,26 @@ func (c *Config) GetTheme() string { return c.Chat.Theme } +// IsBashCommandWhitelisted checks if a specific bash command is whitelisted +func (c *Config) IsBashCommandWhitelisted(command string) bool { + command = strings.TrimSpace(command) + + for _, allowed := range c.Tools.Bash.Whitelist.Commands { + if command == allowed || strings.HasPrefix(command, allowed+" ") { + return true + } + } + + for _, pattern := range c.Tools.Bash.Whitelist.Patterns { + matched, err := regexp.MatchString(pattern, command) + if err == nil && matched { + return true + } + } + + return false +} + // ValidatePathInSandbox checks if a path is within the configured sandbox directories func (c *Config) ValidatePathInSandbox(path string) error { if err := c.checkProtectedPaths(path); err != nil { diff --git a/internal/app/chat.go b/internal/app/chat.go index a1c95f1b..d65b034b 100644 --- a/internal/app/chat.go +++ b/internal/app/chat.go @@ -19,6 +19,7 @@ import ( components "github.com/inference-gateway/cli/internal/ui/components" keybinding "github.com/inference-gateway/cli/internal/ui/keybinding" shared "github.com/inference-gateway/cli/internal/ui/shared" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // ChatApplication represents the main application model using state management @@ -54,6 +55,7 @@ type ChatApplication struct { a2aServersView *components.A2AServersView taskManager *components.TaskManagerImpl toolCallRenderer *components.ToolCallRenderer + approvalComponent *components.ApprovalComponent // Presentation layer applicationViewRenderer *components.ApplicationViewRenderer @@ -117,9 +119,15 @@ func NewChatApplication( logger.Error("Failed to transition to initial view", "error", err) } - app.toolCallRenderer = components.NewToolCallRenderer() + styleProvider := styles.NewProvider(app.themeService) + + app.toolCallRenderer = components.NewToolCallRenderer(styleProvider) + app.approvalComponent = components.NewApprovalComponent(styleProvider) app.conversationView = ui.CreateConversationView(app.themeService) toolFormatterService := services.NewToolFormatterService(app.toolRegistry) + + app.approvalComponent.SetToolFormatter(toolFormatterService) + if cv, ok := app.conversationView.(*components.ConversationView); ok { cv.SetToolFormatter(toolFormatterService) cv.SetConfigPath(configPath) @@ -137,23 +145,23 @@ func NewChatApplication( } app.statusView = ui.CreateStatusView(app.themeService) app.helpBar = ui.CreateHelpBar(app.themeService) - app.queueBoxView = components.NewQueueBoxView(app.themeService) + app.queueBoxView = components.NewQueueBoxView(styleProvider) - app.fileSelectionView = components.NewFileSelectionView(app.themeService) - app.textSelectionView = components.NewTextSelectionView() + app.fileSelectionView = components.NewFileSelectionView(styleProvider) + app.textSelectionView = components.NewTextSelectionView(styleProvider) - app.applicationViewRenderer = components.NewApplicationViewRenderer(app.themeService) - app.fileSelectionHandler = components.NewFileSelectionHandler(app.themeService) + app.applicationViewRenderer = components.NewApplicationViewRenderer(styleProvider) + app.fileSelectionHandler = components.NewFileSelectionHandler(styleProvider) app.keyBindingManager = keybinding.NewKeyBindingManager(app) app.updateHelpBarShortcuts() - app.modelSelector = components.NewModelSelector(models, app.modelService, app.themeService) - app.themeSelector = components.NewThemeSelector(app.themeService) + app.modelSelector = components.NewModelSelector(models, app.modelService, styleProvider) + app.themeSelector = components.NewThemeSelector(app.themeService, styleProvider) if persistentRepo, ok := app.conversationRepo.(*services.PersistentConversationRepository); ok { adapter := adapters.NewPersistentConversationAdapter(persistentRepo) - app.conversationSelector = components.NewConversationSelector(adapter, app.themeService) + app.conversationSelector = components.NewConversationSelector(adapter, styleProvider) } else { app.conversationSelector = nil } @@ -290,6 +298,8 @@ func (app *ChatApplication) handleViewSpecificMessages(msg tea.Msg) []tea.Cmd { return app.handleA2AServersView(msg) case domain.ViewStateA2ATaskManagement: return app.handleA2ATaskManagementView(msg) + case domain.ViewStateToolApproval: + return app.handleToolApprovalView(msg) default: return nil } @@ -341,6 +351,31 @@ func (app *ChatApplication) handleChatViewKeyPress(keyMsg tea.KeyMsg) []tea.Cmd return cmds } +func (app *ChatApplication) handleToolApprovalView(msg tea.Msg) []tea.Cmd { + var cmds []tea.Cmd + + if keyMsg, ok := msg.(tea.KeyMsg); ok { + if cmd := app.keyBindingManager.ProcessKey(keyMsg); cmd != nil { + cmds = append(cmds, cmd) + } + } + + if approvalEvent, ok := msg.(domain.ToolApprovalResponseEvent); ok { + approvalState := app.stateManager.GetApprovalUIState() + if approvalState != nil && approvalState.ResponseChan != nil { + approvalState.ResponseChan <- approvalEvent.Action + + if err := app.stateManager.TransitionToView(domain.ViewStateChat); err != nil { + logger.Error("Failed to transition back to chat view", "error", err) + } + + app.stateManager.ClearApprovalUIState() + } + } + + return cmds +} + func (app *ChatApplication) handleFileSelectionView(msg tea.Msg) []tea.Cmd { var cmds []tea.Cmd @@ -403,6 +438,8 @@ func (app *ChatApplication) View() string { return app.renderA2AServers() case domain.ViewStateA2ATaskManagement: return app.renderA2ATaskManagement() + case domain.ViewStateToolApproval: + return app.renderToolApproval() default: return fmt.Sprintf("Unknown view state: %v", currentView) } @@ -579,7 +616,8 @@ func (app *ChatApplication) handleA2AServersView(msg tea.Msg) []tea.Cmd { a2aAgentService = a2a.GetA2AAgentService() } } - app.a2aServersView = components.NewA2AServersView(app.configService, a2aAgentService, app.themeService) + styleProvider := styles.NewProvider(app.themeService) + app.a2aServersView = components.NewA2AServersView(app.configService, a2aAgentService, styleProvider) ctx := context.Background() if cmd := app.a2aServersView.LoadServers(ctx); cmd != nil { @@ -660,7 +698,8 @@ func (app *ChatApplication) updateAllComponentsWithNewTheme() { inputView.SetThemeService(app.themeService) } - app.modelSelector = components.NewModelSelector(app.availableModels, app.modelService, app.themeService) + styleProvider := styles.NewProvider(app.themeService) + app.modelSelector = components.NewModelSelector(app.availableModels, app.modelService, styleProvider) } func (app *ChatApplication) renderThemeSelection() string { @@ -678,7 +717,8 @@ func (app *ChatApplication) renderA2AServers() string { a2aAgentService = a2a.GetA2AAgentService() } } - app.a2aServersView = components.NewA2AServersView(app.configService, a2aAgentService, app.themeService) + styleProvider := styles.NewProvider(app.themeService) + app.a2aServersView = components.NewA2AServersView(app.configService, a2aAgentService, styleProvider) } width, height := app.stateManager.GetDimensions() @@ -709,6 +749,19 @@ func (app *ChatApplication) renderA2ATaskManagement() string { return app.taskManager.View() } +func (app *ChatApplication) renderToolApproval() string { + approvalState := app.stateManager.GetApprovalUIState() + if approvalState == nil { + return "No pending tool approval" + } + + width, height := app.stateManager.GetDimensions() + app.approvalComponent.SetDimensions(width, height) + + theme := app.themeService.GetCurrentTheme() + return app.approvalComponent.Render(approvalState, theme) +} + func (app *ChatApplication) renderChatInterface() string { app.inputView.SetTextSelectionMode(false) diff --git a/internal/container/container.go b/internal/container/container.go index d7019ad3..ba9e19f3 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -95,9 +95,13 @@ func (c *ServiceContainer) initializeGatewayManager() { c.gatewayManager = services.NewGatewayManager(c.config) if c.config.Gateway.Run { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + if err := c.gatewayManager.Start(ctx); err != nil { - fmt.Printf("Failed to start gateway: %v\n", err) + fmt.Printf("\n⚠️ Failed to start gateway automatically: %v\n", err) + fmt.Printf(" Continuing without local gateway.\n") + fmt.Printf(" Make sure the inference gateway is running at: %s\n\n", c.config.Gateway.URL) logger.Error("Failed to start gateway", "error", err) logger.Warn("Continuing without local gateway - make sure gateway is running manually") } @@ -462,6 +466,11 @@ func (c *ServiceContainer) GetStorage() storage.ConversationStorage { return c.storage } +// GetGatewayManager returns the gateway manager +func (c *ServiceContainer) GetGatewayManager() domain.GatewayManager { + return c.gatewayManager +} + // Shutdown gracefully shuts down the service container and its resources func (c *ServiceContainer) Shutdown(ctx context.Context) error { if c.agentManager != nil && c.agentManager.IsRunning() { diff --git a/internal/domain/agent.go b/internal/domain/agent.go index 04a20938..2e9dfd66 100644 --- a/internal/domain/agent.go +++ b/internal/domain/agent.go @@ -19,9 +19,10 @@ type SDKClient interface { // AgentRequest represents a request to the agent service type AgentRequest struct { - RequestID string `json:"request_id"` - Model string `json:"model"` - Messages []sdk.Message `json:"messages"` + RequestID string `json:"request_id"` + Model string `json:"model"` + Messages []sdk.Message `json:"messages"` + IsChatMode bool `json:"is_chat_mode"` } // AgentService handles agent operations with both sync and streaming modes diff --git a/internal/domain/config_service.go b/internal/domain/config_service.go index 77312bf0..1c0f4d76 100644 --- a/internal/domain/config_service.go +++ b/internal/domain/config_service.go @@ -6,6 +6,7 @@ import "github.com/inference-gateway/cli/config" type ConfigService interface { // Tool approval configuration IsApprovalRequired(toolName string) bool + IsBashCommandWhitelisted(command string) bool // Debug and output configuration GetOutputDirectory() string diff --git a/internal/domain/context.go b/internal/domain/context.go new file mode 100644 index 00000000..094dcf9d --- /dev/null +++ b/internal/domain/context.go @@ -0,0 +1,9 @@ +package domain + +// ContextKey is the type used for context keys in the application +type ContextKey string + +// ToolApprovedKey is the context key for user-approved tool executions +// When this key is set to true in the context, it indicates that the tool +// execution was explicitly approved by the user and should bypass whitelist validation +const ToolApprovedKey ContextKey = "tool_approved" diff --git a/internal/domain/events.go b/internal/domain/events.go index 20593659..282aac5b 100644 --- a/internal/domain/events.go +++ b/internal/domain/events.go @@ -204,3 +204,34 @@ type MessageQueuedEvent struct { func (e MessageQueuedEvent) GetRequestID() string { return e.RequestID } func (e MessageQueuedEvent) GetTimestamp() time.Time { return e.Timestamp } + +// ToolApprovalRequestedEvent indicates a tool requires user approval before execution +type ToolApprovalRequestedEvent struct { + RequestID string + Timestamp time.Time + ToolCall sdk.ChatCompletionMessageToolCall + ResponseChan chan ApprovalAction +} + +func (e ToolApprovalRequestedEvent) GetRequestID() string { return e.RequestID } +func (e ToolApprovalRequestedEvent) GetTimestamp() time.Time { return e.Timestamp } + +// ToolApprovedEvent indicates the user approved the tool execution +type ToolApprovedEvent struct { + RequestID string + Timestamp time.Time + ToolCall sdk.ChatCompletionMessageToolCall +} + +func (e ToolApprovedEvent) GetRequestID() string { return e.RequestID } +func (e ToolApprovedEvent) GetTimestamp() time.Time { return e.Timestamp } + +// ToolRejectedEvent indicates the user rejected the tool execution +type ToolRejectedEvent struct { + RequestID string + Timestamp time.Time + ToolCall sdk.ChatCompletionMessageToolCall +} + +func (e ToolRejectedEvent) GetRequestID() string { return e.RequestID } +func (e ToolRejectedEvent) GetTimestamp() time.Time { return e.Timestamp } diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 92dcfb63..98f28ba8 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -155,6 +155,12 @@ type StateManager interface { SetFileSelectedIndex(index int) ClearFileSelectionState() + // Approval management + SetupApprovalUIState(toolCall *sdk.ChatCompletionMessageToolCall, responseChan chan ApprovalAction) + GetApprovalUIState() *ApprovalUIState + SetApprovalSelectedIndex(index int) + ClearApprovalUIState() + // Message queue management (DEPRECATED - use MessageQueue service instead) AddQueuedMessage(message Message, requestID string) PopQueuedMessage() *QueuedMessage @@ -350,6 +356,7 @@ type Theme interface { GetUserColor() string GetAssistantColor() string GetErrorColor() string + GetSuccessColor() string GetStatusColor() string GetAccentColor() string GetDimColor() string diff --git a/internal/domain/state.go b/internal/domain/state.go index 63e12c6d..63855e9d 100644 --- a/internal/domain/state.go +++ b/internal/domain/state.go @@ -28,6 +28,7 @@ type ApplicationState struct { // UI State fileSelectionState *FileSelectionState + approvalUIState *ApprovalUIState // Debugging debugMode bool @@ -45,6 +46,7 @@ const ( ViewStateThemeSelection ViewStateA2AServers ViewStateA2ATaskManagement + ViewStateToolApproval ) func (v ViewState) String() string { @@ -65,6 +67,8 @@ func (v ViewState) String() string { return "A2AServers" case ViewStateA2ATaskManagement: return "A2ATaskManagement" + case ViewStateToolApproval: + return "ToolApproval" default: return "Unknown" } @@ -155,6 +159,7 @@ type ToolCallStatus int const ( ToolCallStatusPending ToolCallStatus = iota + ToolCallStatusWaitingApproval ToolCallStatusExecuting ToolCallStatusCompleted ToolCallStatusFailed @@ -166,6 +171,8 @@ func (t ToolCallStatus) String() string { switch t { case ToolCallStatusPending: return "Pending" + case ToolCallStatusWaitingApproval: + return "WaitingApproval" case ToolCallStatusExecuting: return "Executing" case ToolCallStatusCompleted: @@ -209,6 +216,25 @@ func (t ToolExecutionStatus) String() string { } } +// ApprovalAction represents the user's choice for tool approval +type ApprovalAction int + +const ( + ApprovalApprove ApprovalAction = iota + ApprovalReject +) + +func (a ApprovalAction) String() string { + switch a { + case ApprovalApprove: + return "Approve" + case ApprovalReject: + return "Reject" + default: + return "Unknown" + } +} + // FileSelectionState represents the state of file selection UI type FileSelectionState struct { Files []string `json:"files"` @@ -216,6 +242,13 @@ type FileSelectionState struct { SelectedIndex int `json:"selected_index"` } +// ApprovalUIState represents the state of approval UI +type ApprovalUIState struct { + SelectedIndex int `json:"selected_index"` + PendingToolCall *sdk.ChatCompletionMessageToolCall `json:"pending_tool_call"` + ResponseChan chan ApprovalAction `json:"-"` +} + // NewApplicationState creates a new application state func NewApplicationState() *ApplicationState { return &ApplicationState{ @@ -261,6 +294,7 @@ func (s *ApplicationState) isValidTransition(from, to ViewState) bool { ViewStateThemeSelection, ViewStateA2AServers, ViewStateA2ATaskManagement, + ViewStateToolApproval, }, ViewStateFileSelection: {ViewStateChat}, ViewStateTextSelection: {ViewStateChat}, @@ -268,6 +302,7 @@ func (s *ApplicationState) isValidTransition(from, to ViewState) bool { ViewStateThemeSelection: {ViewStateChat}, ViewStateA2AServers: {ViewStateChat}, ViewStateA2ATaskManagement: {ViewStateChat}, + ViewStateToolApproval: {ViewStateChat}, } allowed, exists := validTransitions[from] @@ -582,6 +617,37 @@ func (s *ApplicationState) ClearFileSelectionState() { s.fileSelectionState = nil } +// Approval State Management + +// SetupApprovalUIState initializes approval UI state with the pending tool call +func (s *ApplicationState) SetupApprovalUIState(toolCall *sdk.ChatCompletionMessageToolCall, responseChan chan ApprovalAction) { + s.approvalUIState = &ApprovalUIState{ + SelectedIndex: int(ApprovalApprove), // Default to approve + PendingToolCall: toolCall, + ResponseChan: responseChan, + } +} + +// GetApprovalUIState returns the current approval UI state +func (s *ApplicationState) GetApprovalUIState() *ApprovalUIState { + return s.approvalUIState +} + +// SetApprovalSelectedIndex sets the approval selection index +func (s *ApplicationState) SetApprovalSelectedIndex(index int) { + if s.approvalUIState != nil { + s.approvalUIState.SelectedIndex = index + } +} + +// ClearApprovalUIState clears the approval UI state +func (s *ApplicationState) ClearApprovalUIState() { + if s.approvalUIState != nil && s.approvalUIState.ResponseChan != nil { + close(s.approvalUIState.ResponseChan) + } + s.approvalUIState = nil +} + // StateSnapshot represents a point-in-time snapshot of application state type StateSnapshot struct { CurrentView string `json:"current_view"` diff --git a/internal/domain/theme_provider.go b/internal/domain/theme_provider.go index 0a4061d3..d751e1fd 100644 --- a/internal/domain/theme_provider.go +++ b/internal/domain/theme_provider.go @@ -85,15 +85,16 @@ func NewTokyoNightTheme() *TokyoNightTheme { return &TokyoNightTheme{} } -func (t *TokyoNightTheme) GetUserColor() string { return colors.UserColor.ANSI } -func (t *TokyoNightTheme) GetAssistantColor() string { return colors.AssistantColor.ANSI } -func (t *TokyoNightTheme) GetErrorColor() string { return colors.ErrorColor.ANSI } -func (t *TokyoNightTheme) GetStatusColor() string { return colors.StatusColor.ANSI } -func (t *TokyoNightTheme) GetAccentColor() string { return colors.AccentColor.ANSI } -func (t *TokyoNightTheme) GetDimColor() string { return colors.DimColor.ANSI } -func (t *TokyoNightTheme) GetBorderColor() string { return colors.BorderColor.ANSI } -func (t *TokyoNightTheme) GetDiffAddColor() string { return colors.DiffAddColor.ANSI } -func (t *TokyoNightTheme) GetDiffRemoveColor() string { return colors.DiffRemoveColor.ANSI } +func (t *TokyoNightTheme) GetUserColor() string { return colors.UserColor.Lipgloss } +func (t *TokyoNightTheme) GetAssistantColor() string { return colors.AssistantColor.Lipgloss } +func (t *TokyoNightTheme) GetErrorColor() string { return colors.ErrorColor.Lipgloss } +func (t *TokyoNightTheme) GetSuccessColor() string { return colors.SuccessColor.Lipgloss } +func (t *TokyoNightTheme) GetStatusColor() string { return colors.StatusColor.Lipgloss } +func (t *TokyoNightTheme) GetAccentColor() string { return colors.AccentColor.Lipgloss } +func (t *TokyoNightTheme) GetDimColor() string { return colors.DimColor.Lipgloss } +func (t *TokyoNightTheme) GetBorderColor() string { return colors.BorderColor.Lipgloss } +func (t *TokyoNightTheme) GetDiffAddColor() string { return colors.DiffAddColor.Lipgloss } +func (t *TokyoNightTheme) GetDiffRemoveColor() string { return colors.DiffRemoveColor.Lipgloss } // GithubLightTheme provides a light theme similar to GitHub's interface type GithubLightTheme struct{} @@ -102,15 +103,16 @@ func NewGithubLightTheme() *GithubLightTheme { return &GithubLightTheme{} } -func (t *GithubLightTheme) GetUserColor() string { return "\033[38;2;3;102;214m" } // GitHub blue -func (t *GithubLightTheme) GetAssistantColor() string { return "\033[38;2;36;41;46m" } // Dark gray -func (t *GithubLightTheme) GetErrorColor() string { return "\033[38;2;207;34;46m" } // GitHub red -func (t *GithubLightTheme) GetStatusColor() string { return "\033[38;2;130;80;223m" } // GitHub purple -func (t *GithubLightTheme) GetAccentColor() string { return "\033[38;2;3;102;214m" } // GitHub blue -func (t *GithubLightTheme) GetDimColor() string { return "\033[38;2;101;109;118m" } // GitHub gray -func (t *GithubLightTheme) GetBorderColor() string { return "\033[38;2;208;215;222m" } // Light gray border -func (t *GithubLightTheme) GetDiffAddColor() string { return "\033[38;2;40;167;69m" } // GitHub green -func (t *GithubLightTheme) GetDiffRemoveColor() string { return "\033[38;2;207;34;46m" } // GitHub red +func (t *GithubLightTheme) GetUserColor() string { return colors.GithubUserColor.Lipgloss } +func (t *GithubLightTheme) GetAssistantColor() string { return colors.GithubAssistantColor.Lipgloss } +func (t *GithubLightTheme) GetErrorColor() string { return colors.GithubErrorColor.Lipgloss } +func (t *GithubLightTheme) GetSuccessColor() string { return colors.GithubSuccessColor.Lipgloss } +func (t *GithubLightTheme) GetStatusColor() string { return colors.GithubStatusColor.Lipgloss } +func (t *GithubLightTheme) GetAccentColor() string { return colors.GithubAccentColor.Lipgloss } +func (t *GithubLightTheme) GetDimColor() string { return colors.GithubDimColor.Lipgloss } +func (t *GithubLightTheme) GetBorderColor() string { return colors.GithubBorderColor.Lipgloss } +func (t *GithubLightTheme) GetDiffAddColor() string { return colors.GithubDiffAddColor.Lipgloss } +func (t *GithubLightTheme) GetDiffRemoveColor() string { return colors.GithubDiffRemoveColor.Lipgloss } // DraculaTheme provides the popular Dracula color scheme type DraculaTheme struct{} @@ -119,12 +121,13 @@ func NewDraculaTheme() *DraculaTheme { return &DraculaTheme{} } -func (t *DraculaTheme) GetUserColor() string { return "\033[38;2;139;233;253m" } // Cyan -func (t *DraculaTheme) GetAssistantColor() string { return "\033[38;2;248;248;242m" } // Foreground -func (t *DraculaTheme) GetErrorColor() string { return "\033[38;2;255;85;85m" } // Red -func (t *DraculaTheme) GetStatusColor() string { return "\033[38;2;189;147;249m" } // Purple -func (t *DraculaTheme) GetAccentColor() string { return "\033[38;2;255;121;198m" } // Pink -func (t *DraculaTheme) GetDimColor() string { return "\033[38;2;98;114;164m" } // Comment -func (t *DraculaTheme) GetBorderColor() string { return "\033[38;2;68;71;90m" } // Selection -func (t *DraculaTheme) GetDiffAddColor() string { return "\033[38;2;80;250;123m" } // Green -func (t *DraculaTheme) GetDiffRemoveColor() string { return "\033[38;2;255;85;85m" } // Red +func (t *DraculaTheme) GetUserColor() string { return colors.DraculaUserColor.Lipgloss } +func (t *DraculaTheme) GetAssistantColor() string { return colors.DraculaAssistantColor.Lipgloss } +func (t *DraculaTheme) GetErrorColor() string { return colors.DraculaErrorColor.Lipgloss } +func (t *DraculaTheme) GetSuccessColor() string { return colors.DraculaSuccessColor.Lipgloss } +func (t *DraculaTheme) GetStatusColor() string { return colors.DraculaStatusColor.Lipgloss } +func (t *DraculaTheme) GetAccentColor() string { return colors.DraculaAccentColor.Lipgloss } +func (t *DraculaTheme) GetDimColor() string { return colors.DraculaDimColor.Lipgloss } +func (t *DraculaTheme) GetBorderColor() string { return colors.DraculaBorderColor.Lipgloss } +func (t *DraculaTheme) GetDiffAddColor() string { return colors.DraculaDiffAddColor.Lipgloss } +func (t *DraculaTheme) GetDiffRemoveColor() string { return colors.DraculaDiffRemoveColor.Lipgloss } diff --git a/internal/domain/ui_events.go b/internal/domain/ui_events.go index c5764c90..d517e0ac 100644 --- a/internal/domain/ui_events.go +++ b/internal/domain/ui_events.go @@ -1,5 +1,9 @@ package domain +import ( + sdk "github.com/inference-gateway/sdk" +) + // UI Events for application state management // UpdateHistoryEvent updates the conversation history display @@ -158,3 +162,17 @@ type ToolExecutionCompletedEvent struct { FailureCount int Results []*ToolExecutionResult } + +// Approval Events + +// ShowToolApprovalEvent displays the tool approval modal +type ShowToolApprovalEvent struct { + ToolCall sdk.ChatCompletionMessageToolCall + ResponseChan chan ApprovalAction +} + +// ToolApprovalResponseEvent captures the user's approval decision +type ToolApprovalResponseEvent struct { + Action ApprovalAction + ToolCall sdk.ChatCompletionMessageToolCall +} diff --git a/internal/handlers/chat_command_handler.go b/internal/handlers/chat_command_handler.go index 11872ad3..319c99db 100644 --- a/internal/handlers/chat_command_handler.go +++ b/internal/handlers/chat_command_handler.go @@ -78,6 +78,18 @@ func (c *ChatCommandHandler) handleBashCommand( } } + isWhitelisted := c.handler.configService.IsBashCommandWhitelisted(command) + requiresApproval := !isWhitelisted + + if requiresApproval { + return c.handleBashCommandWithApproval(commandText, command) + } + + return c.executeBashCommand(commandText, command) +} + +// executeBashCommand executes a bash command without approval +func (c *ChatCommandHandler) executeBashCommand(commandText, command string) tea.Cmd { return func() tea.Msg { toolCall := sdk.ChatCompletionMessageToolCallFunction{ Name: "Bash", @@ -119,6 +131,29 @@ func (c *ChatCommandHandler) handleBashCommand( } } +// handleBashCommandWithApproval requests approval before executing a bash command +func (c *ChatCommandHandler) handleBashCommandWithApproval(commandText, command string) tea.Cmd { + return func() tea.Msg { + toolCall := sdk.ChatCompletionMessageToolCall{ + Id: fmt.Sprintf("manual-%d", time.Now().UnixNano()), + Type: "function", + Function: sdk.ChatCompletionMessageToolCallFunction{ + Name: "Bash", + Arguments: fmt.Sprintf(`{"command": "%s"}`, strings.ReplaceAll(command, `"`, `\"`)), + }, + } + + responseChan := make(chan domain.ApprovalAction, 1) + + return domain.ToolApprovalRequestedEvent{ + RequestID: fmt.Sprintf("manual-bash-%d", time.Now().UnixNano()), + Timestamp: time.Now(), + ToolCall: toolCall, + ResponseChan: responseChan, + } + } +} + // handleToolCommand processes tool commands starting with !! func (c *ChatCommandHandler) handleToolCommand( commandText string, @@ -153,18 +188,31 @@ func (c *ChatCommandHandler) handleToolCommand( } } - return func() tea.Msg { - argsJSON, err := json.Marshal(args) - if err != nil { + requiresApproval := c.handler.configService.IsApprovalRequired(toolName) + + argsJSON, err := json.Marshal(args) + if err != nil { + return func() tea.Msg { return domain.ShowErrorEvent{ Error: fmt.Sprintf("Failed to marshal arguments: %v", err), Sticky: false, } } + } + if requiresApproval { + return c.handleToolCommandWithApproval(toolName, string(argsJSON)) + } + + return c.executeToolCommand(toolName, string(argsJSON)) +} + +// executeToolCommand executes a tool command without approval +func (c *ChatCommandHandler) executeToolCommand(toolName, argsJSON string) tea.Cmd { + return func() tea.Msg { toolCall := sdk.ChatCompletionMessageToolCallFunction{ Name: toolName, - Arguments: string(argsJSON), + Arguments: argsJSON, } result, err := c.handler.toolService.ExecuteTool(context.Background(), toolCall) @@ -175,6 +223,8 @@ func (c *ChatCommandHandler) handleToolCommand( } } + commandText := "!!" + toolName + "(...)" + userEntry := domain.ConversationEntry{ Message: sdk.Message{ Role: sdk.User, @@ -201,6 +251,29 @@ func (c *ChatCommandHandler) handleToolCommand( } } +// handleToolCommandWithApproval requests approval before executing a tool command +func (c *ChatCommandHandler) handleToolCommandWithApproval(toolName, argsJSON string) tea.Cmd { + return func() tea.Msg { + toolCall := sdk.ChatCompletionMessageToolCall{ + Id: fmt.Sprintf("manual-%d", time.Now().UnixNano()), + Type: "function", + Function: sdk.ChatCompletionMessageToolCallFunction{ + Name: toolName, + Arguments: argsJSON, + }, + } + + responseChan := make(chan domain.ApprovalAction, 1) + + return domain.ToolApprovalRequestedEvent{ + RequestID: fmt.Sprintf("manual-tool-%d", time.Now().UnixNano()), + Timestamp: time.Now(), + ToolCall: toolCall, + ResponseChan: responseChan, + } + } +} + // ParseToolCall parses a tool call in the format ToolName(arg="value", arg2="value2") (exposed for testing) func (c *ChatCommandHandler) ParseToolCall(input string) (string, map[string]any, error) { parenIndex := strings.Index(input, "(") diff --git a/internal/handlers/chat_command_handler_test.go b/internal/handlers/chat_command_handler_test.go index 13f9ee9f..0fe83f6b 100644 --- a/internal/handlers/chat_command_handler_test.go +++ b/internal/handlers/chat_command_handler_test.go @@ -105,10 +105,14 @@ func TestChatCommandHandler_handleBashCommand(t *testing.T) { mockTool := &mocks.FakeToolService{} mockTool.IsToolEnabledReturns(tt.toolEnabled) + mockConfig := &mocks.FakeConfigService{} + mockConfig.IsBashCommandWhitelistedReturns(true) + conversationRepo := services.NewInMemoryConversationRepository(nil) handler := &ChatHandler{ toolService: mockTool, + configService: mockConfig, conversationRepo: conversationRepo, } @@ -169,10 +173,14 @@ func TestChatCommandHandler_handleToolCommand(t *testing.T) { mockTool := &mocks.FakeToolService{} mockTool.IsToolEnabledReturns(tt.toolEnabled) + mockConfig := &mocks.FakeConfigService{} + mockConfig.IsApprovalRequiredReturns(false) + conversationRepo := services.NewInMemoryConversationRepository(nil) handler := &ChatHandler{ toolService: mockTool, + configService: mockConfig, conversationRepo: conversationRepo, } diff --git a/internal/handlers/chat_event_handler.go b/internal/handlers/chat_event_handler.go index 7d4a4f6f..0d837a60 100644 --- a/internal/handlers/chat_event_handler.go +++ b/internal/handlers/chat_event_handler.go @@ -9,6 +9,7 @@ import ( domain "github.com/inference-gateway/cli/internal/domain" tools "github.com/inference-gateway/cli/internal/services/tools" components "github.com/inference-gateway/cli/internal/ui/components" + styles "github.com/inference-gateway/cli/internal/ui/styles" sdk "github.com/inference-gateway/sdk" ) @@ -18,9 +19,13 @@ type ChatEventHandler struct { } func NewChatEventHandler(handler *ChatHandler) *ChatEventHandler { + // Create style provider with default theme for tool call rendering + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + return &ChatEventHandler{ handler: handler, - toolCallRenderer: components.NewToolCallRenderer(), + toolCallRenderer: components.NewToolCallRenderer(styleProvider), } } @@ -370,6 +375,29 @@ func (e *ChatEventHandler) handleToolCallReady( return tea.Batch(cmds...) } +func (e *ChatEventHandler) handleToolApprovalRequested( + msg domain.ToolApprovalRequestedEvent, +) tea.Cmd { + _ = e.handler.stateManager.TransitionToView(domain.ViewStateToolApproval) + + e.handler.stateManager.SetupApprovalUIState(&msg.ToolCall, msg.ResponseChan) + + var cmds []tea.Cmd + + cmds = append(cmds, func() tea.Msg { + return domain.ShowToolApprovalEvent{ + ToolCall: msg.ToolCall, + ResponseChan: msg.ResponseChan, + } + }) + + if chatSession := e.handler.stateManager.GetChatSession(); chatSession != nil && chatSession.EventChannel != nil { + cmds = append(cmds, e.handler.listenForChatEvents(chatSession.EventChannel)) + } + + return tea.Batch(cmds...) +} + func (e *ChatEventHandler) handleToolExecutionStarted( msg domain.ToolExecutionStartedEvent, ) tea.Cmd { diff --git a/internal/handlers/chat_handler.go b/internal/handlers/chat_handler.go index dc173126..5b3b02df 100644 --- a/internal/handlers/chat_handler.go +++ b/internal/handlers/chat_handler.go @@ -2,7 +2,9 @@ package handlers import ( "context" + "encoding/json" "fmt" + "strings" "time" spinner "github.com/charmbracelet/bubbles/spinner" @@ -119,6 +121,10 @@ func (h *ChatHandler) Handle(msg tea.Msg) tea.Cmd { // nolint:cyclop,gocyclo return h.HandleA2ATaskInputRequiredEvent(m) case domain.MessageQueuedEvent: return h.HandleMessageQueuedEvent(m) + case domain.ToolApprovalRequestedEvent: + return h.HandleToolApprovalRequestedEvent(m) + case domain.ToolApprovalResponseEvent: + return h.HandleToolApprovalResponseEvent(m) default: if isUIOnlyEvent(msg) { return nil @@ -153,9 +159,10 @@ func (h *ChatHandler) startChatCompletion() tea.Cmd { requestID := generateRequestID() req := &domain.AgentRequest{ - RequestID: requestID, - Model: currentModel, - Messages: messages, + RequestID: requestID, + Model: currentModel, + Messages: messages, + IsChatMode: true, } eventChan, err := h.agentService.RunWithStream(ctx, req) @@ -361,6 +368,12 @@ func (h *ChatHandler) HandleToolCallReadyEvent( return h.eventHandler.handleToolCallReady(msg) } +func (h *ChatHandler) HandleToolApprovalRequestedEvent( + msg domain.ToolApprovalRequestedEvent, +) tea.Cmd { + return h.eventHandler.handleToolApprovalRequested(msg) +} + func (h *ChatHandler) HandleToolExecutionStartedEvent( msg domain.ToolExecutionStartedEvent, ) tea.Cmd { @@ -444,6 +457,94 @@ func (h *ChatHandler) HandleMessageQueuedEvent( return cmd } +func (h *ChatHandler) HandleToolApprovalResponseEvent( + msg domain.ToolApprovalResponseEvent, +) tea.Cmd { + return h.handleToolApprovalResponse(msg) +} + +// handleToolApprovalResponse processes the user's approval decision +func (h *ChatHandler) handleToolApprovalResponse( + msg domain.ToolApprovalResponseEvent, +) tea.Cmd { + approvalState := h.stateManager.GetApprovalUIState() + if approvalState != nil && approvalState.ResponseChan != nil { + select { + case approvalState.ResponseChan <- msg.Action: + default: + } + } + + h.stateManager.ClearApprovalUIState() + _ = h.stateManager.TransitionToView(domain.ViewStateChat) + + isManualExecution := strings.HasPrefix(msg.ToolCall.Id, "manual-") + + if !isManualExecution { + return nil + } + + if msg.Action == domain.ApprovalReject { + return func() tea.Msg { + return domain.ShowErrorEvent{ + Error: fmt.Sprintf("Tool execution rejected: %s", msg.ToolCall.Function.Name), + Sticky: false, + } + } + } + + return func() tea.Msg { + ctx := context.WithValue(context.Background(), domain.ToolApprovedKey, true) + result, err := h.toolService.ExecuteTool(ctx, msg.ToolCall.Function) + if err != nil { + return domain.ShowErrorEvent{ + Error: fmt.Sprintf("Failed to execute tool: %v", err), + Sticky: false, + } + } + + commandText := h.reconstructCommandText(msg.ToolCall) + + userEntry := domain.ConversationEntry{ + Message: sdk.Message{ + Role: sdk.User, + Content: commandText, + }, + Time: time.Now(), + } + _ = h.conversationRepo.AddMessage(userEntry) + + responseContent := h.conversationRepo.FormatToolResultForLLM(result) + assistantEntry := domain.ConversationEntry{ + Message: sdk.Message{ + Role: sdk.Assistant, + Content: responseContent, + }, + Time: time.Now(), + } + _ = h.conversationRepo.AddMessage(assistantEntry) + + return domain.UpdateHistoryEvent{ + History: h.conversationRepo.GetMessages(), + } + } +} + +// reconstructCommandText reconstructs the original command text from a tool call +func (h *ChatHandler) reconstructCommandText(toolCall sdk.ChatCompletionMessageToolCall) string { + if toolCall.Function.Name == "Bash" { + var args map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err == nil { + if command, ok := args["command"].(string); ok { + return "!" + command + } + } + return "!" + } + + return "!!" + toolCall.Function.Name + "(...)" +} + // isUIOnlyEvent checks if the event is a UI-only event that doesn't require business logic handling func isUIOnlyEvent(msg tea.Msg) bool { switch msg.(type) { @@ -465,6 +566,7 @@ func isUIOnlyEvent(msg tea.Msg) bool { domain.ExitSelectionModeEvent, domain.ModelSelectedEvent, domain.ThemeSelectedEvent, + domain.ShowToolApprovalEvent, tea.KeyMsg, tea.WindowSizeMsg, spinner.TickMsg: diff --git a/internal/services/agent.go b/internal/services/agent.go index 44f65026..7dc35513 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -506,7 +506,7 @@ func (s *AgentServiceImpl) RunWithStream(ctx context.Context, req *domain.AgentR toolCallsSlice = append(toolCallsSlice, tc) } - toolResults := s.executeToolCallsParallel(ctx, toolCallsSlice, eventPublisher) + toolResults := s.executeToolCallsParallel(ctx, toolCallsSlice, eventPublisher, req.IsChatMode) for _, entry := range toolResults { toolResult := sdk.Message{ @@ -642,6 +642,7 @@ func (s *AgentServiceImpl) executeToolCallsParallel( ctx context.Context, toolCalls []*sdk.ChatCompletionMessageToolCall, eventPublisher *eventPublisher, + isChatMode bool, ) []domain.ConversationEntry { if len(toolCalls) == 0 { @@ -661,58 +662,104 @@ func (s *AgentServiceImpl) executeToolCallsParallel( time.Sleep(constants.AgentToolExecutionDelay) results := make([]domain.ConversationEntry, len(toolCalls)) - resultsChan := make(chan IndexedToolResult, len(toolCalls)) - semaphore := make(chan struct{}, s.config.GetAgentConfig().MaxConcurrentTools) + var approvalTools []struct { + index int + tool *sdk.ChatCompletionMessageToolCall + } + var parallelTools []struct { + index int + tool *sdk.ChatCompletionMessageToolCall + } - var wg sync.WaitGroup for i, tc := range toolCalls { - wg.Add(1) - go func(index int, toolCall *sdk.ChatCompletionMessageToolCall) { - defer func() { - wg.Done() - }() + requiresApproval := s.shouldRequireApproval(tc, isChatMode) + if requiresApproval { + approvalTools = append(approvalTools, struct { + index int + tool *sdk.ChatCompletionMessageToolCall + }{i, tc}) + } else { + parallelTools = append(parallelTools, struct { + index int + tool *sdk.ChatCompletionMessageToolCall + }{i, tc}) + } + } - semaphore <- struct{}{} - defer func() { - <-semaphore - }() + for _, at := range approvalTools { + eventPublisher.publishToolStatusChange( + at.tool.Id, + "starting", + fmt.Sprintf("Initializing %s...", at.tool.Function.Name), + ) - eventPublisher.publishToolStatusChange( - toolCall.Id, - "starting", - fmt.Sprintf("Initializing %s...", toolCall.Function.Name), - ) + time.Sleep(constants.AgentToolExecutionDelay) - time.Sleep(constants.AgentToolExecutionDelay) + result := s.executeToolWithFlashingUI(ctx, *at.tool, eventPublisher, isChatMode) - result := s.executeToolWithFlashingUI(ctx, *toolCall, eventPublisher) + status := "complete" + message := "Completed successfully" + if result.ToolExecution != nil && !result.ToolExecution.Success { + status = "failed" + message = "Execution failed" + } - status := "complete" - message := "Completed successfully" - if result.ToolExecution != nil && !result.ToolExecution.Success { - status = "failed" - message = "Execution failed" - } + eventPublisher.publishToolStatusChange(at.tool.Id, status, message) + results[at.index] = result + } - eventPublisher.publishToolStatusChange(toolCall.Id, status, message) + if len(parallelTools) > 0 { + resultsChan := make(chan IndexedToolResult, len(parallelTools)) + semaphore := make(chan struct{}, s.config.GetAgentConfig().MaxConcurrentTools) - resultsChan <- IndexedToolResult{ - Index: index, - Result: result, - } - }(i, tc) - } + var wg sync.WaitGroup + for _, pt := range parallelTools { + wg.Add(1) + go func(index int, toolCall *sdk.ChatCompletionMessageToolCall) { + defer func() { + wg.Done() + }() - go func() { - wg.Wait() - close(resultsChan) - }() + semaphore <- struct{}{} + defer func() { + <-semaphore + }() + + eventPublisher.publishToolStatusChange( + toolCall.Id, + "starting", + fmt.Sprintf("Initializing %s...", toolCall.Function.Name), + ) + + time.Sleep(constants.AgentToolExecutionDelay) + + result := s.executeToolWithFlashingUI(ctx, *toolCall, eventPublisher, isChatMode) + + status := "complete" + message := "Completed successfully" + if result.ToolExecution != nil && !result.ToolExecution.Success { + status = "failed" + message = "Execution failed" + } + + eventPublisher.publishToolStatusChange(toolCall.Id, status, message) + + resultsChan <- IndexedToolResult{ + Index: index, + Result: result, + } + }(pt.index, pt.tool) + } + + go func() { + wg.Wait() + close(resultsChan) + }() - resultCount := 0 - for res := range resultsChan { - resultCount++ - results[res.Index] = res.Result + for res := range resultsChan { + results[res.Index] = res.Result + } } duration := time.Since(startTime) @@ -740,10 +787,31 @@ func (s *AgentServiceImpl) executeToolWithFlashingUI( ctx context.Context, tc sdk.ChatCompletionMessageToolCall, eventPublisher *eventPublisher, + isChatMode bool, ) domain.ConversationEntry { startTime := time.Now() + requiresApproval := s.shouldRequireApproval(&tc, isChatMode) + logger.Debug("tool approval check", "tool", tc.Function.Name, "is_chat_mode", isChatMode, "requires_approval", requiresApproval) + + wasApproved := false + + if requiresApproval { + logger.Info("requesting approval for tool", "tool", tc.Function.Name) + approved, err := s.requestToolApproval(ctx, tc, eventPublisher) + if err != nil { + logger.Error("failed to request tool approval", "tool", tc.Function.Name, "error", err) + return s.createErrorEntry(tc, err, startTime) + } + if !approved { + logger.Info("tool execution rejected by user", "tool", tc.Function.Name) + rejectionErr := fmt.Errorf("tool execution rejected by user") + return s.createErrorEntry(tc, rejectionErr, startTime) + } + wasApproved = true + } + eventPublisher.publishToolStatusChange(tc.Id, "running", "Executing...") time.Sleep(constants.AgentToolExecutionDelay) @@ -754,9 +822,16 @@ func (s *AgentServiceImpl) executeToolWithFlashingUI( return s.createErrorEntry(tc, err, startTime) } - if err := s.toolService.ValidateTool(tc.Function.Name, args); err != nil { - logger.Error("tool validation failed", "tool", tc.Function.Name, "error", err) - return s.createErrorEntry(tc, err, startTime) + if !wasApproved { + if err := s.toolService.ValidateTool(tc.Function.Name, args); err != nil { + logger.Error("tool validation failed", "tool", tc.Function.Name, "error", err) + return s.createErrorEntry(tc, err, startTime) + } + } + + execCtx := ctx + if wasApproved { + execCtx = context.WithValue(ctx, domain.ToolApprovedKey, true) } resultChan := make(chan struct { @@ -765,7 +840,7 @@ func (s *AgentServiceImpl) executeToolWithFlashingUI( }, 1) go func() { - result, err := s.toolService.ExecuteTool(ctx, tc.Function) + result, err := s.toolService.ExecuteTool(execCtx, tc.Function) resultChan <- struct { result *domain.ToolExecutionResult err error @@ -829,6 +904,59 @@ done: return entry } +// requestToolApproval requests user approval for a tool and waits for response +func (s *AgentServiceImpl) requestToolApproval( + ctx context.Context, + tc sdk.ChatCompletionMessageToolCall, + eventPublisher *eventPublisher, +) (bool, error) { + // Create response channel + responseChan := make(chan domain.ApprovalAction, 1) + + // Publish approval request event + eventPublisher.chatEvents <- domain.ToolApprovalRequestedEvent{ + RequestID: eventPublisher.requestID, + Timestamp: time.Now(), + ToolCall: tc, + ResponseChan: responseChan, + } + + // Wait for user response or context cancellation + select { + case response := <-responseChan: + return response == domain.ApprovalApprove, nil + case <-ctx.Done(): + return false, fmt.Errorf("approval request cancelled: %w", ctx.Err()) + case <-time.After(5 * time.Minute): // Timeout after 5 minutes + return false, fmt.Errorf("approval request timed out") + } +} + +// shouldRequireApproval determines if a tool execution requires user approval +// For Bash tool specifically, it checks if the command is whitelisted +func (s *AgentServiceImpl) shouldRequireApproval(tc *sdk.ChatCompletionMessageToolCall, isChatMode bool) bool { + if !isChatMode { + return false + } + + if tc.Function.Name == "Bash" { + var args map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + return true + } + + command, ok := args["command"].(string) + if !ok { + return true + } + + isWhitelisted := s.config.IsBashCommandWhitelisted(command) + return !isWhitelisted + } + + return s.config.IsApprovalRequired(tc.Function.Name) +} + func (s *AgentServiceImpl) createErrorEntry(tc sdk.ChatCompletionMessageToolCall, err error, startTime time.Time) domain.ConversationEntry { return domain.ConversationEntry{ Message: domain.Message{ diff --git a/internal/services/agent_manager.go b/internal/services/agent_manager.go index 692ac035..e402edfc 100644 --- a/internal/services/agent_manager.go +++ b/internal/services/agent_manager.go @@ -12,6 +12,11 @@ import ( logger "github.com/inference-gateway/cli/internal/logger" ) +const ( + // AgentContainerPrefix is the naming prefix for agent containers + AgentContainerPrefix = "inference-agent-" +) + // AgentManager manages the lifecycle of A2A agent containers type AgentManager struct { config *config.Config @@ -92,11 +97,14 @@ func (am *AgentManager) StartAgent(ctx context.Context, agent config.AgentEntry) // StopAgents stops all running agent containers func (am *AgentManager) StopAgents(ctx context.Context) error { + logger.Info("Stopping agents", "trackedCount", len(am.containers)) + for agentName := range am.containers { if err := am.StopAgent(ctx, agentName); err != nil { logger.Warn("Failed to stop agent", "name", agentName, "error", err) } } + am.isRunning = false return nil } @@ -154,7 +162,7 @@ func (am *AgentManager) startContainer(ctx context.Context, agent config.AgentEn args := []string{ "run", "-d", - "--name", fmt.Sprintf("infer-agent-%s", agent.Name), + "--name", fmt.Sprintf("%s%s", AgentContainerPrefix, agent.Name), "-p", fmt.Sprintf("%s:8080", port), "--rm", } @@ -182,7 +190,7 @@ func (am *AgentManager) startContainer(ctx context.Context, agent config.AgentEn // isAgentRunning checks if an agent container is already running func (am *AgentManager) isAgentRunning(agentName string) bool { - containerName := fmt.Sprintf("infer-agent-%s", agentName) + containerName := fmt.Sprintf("%s%s", AgentContainerPrefix, agentName) cmd := exec.Command("docker", "ps", "--filter", fmt.Sprintf("name=%s", containerName), "--format", "{{.ID}}") output, err := cmd.CombinedOutput() if err != nil { diff --git a/internal/services/state_manager.go b/internal/services/state_manager.go index 67e72d76..c129a2fd 100644 --- a/internal/services/state_manager.go +++ b/internal/services/state_manager.go @@ -425,6 +425,40 @@ func (sm *StateManager) ClearFileSelectionState() { sm.state.ClearFileSelectionState() } +// Approval state methods + +// SetupApprovalUIState initializes approval UI state +func (sm *StateManager) SetupApprovalUIState(toolCall *sdk.ChatCompletionMessageToolCall, responseChan chan domain.ApprovalAction) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + sm.state.SetupApprovalUIState(toolCall, responseChan) +} + +// GetApprovalUIState returns the current approval UI state +func (sm *StateManager) GetApprovalUIState() *domain.ApprovalUIState { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + return sm.state.GetApprovalUIState() +} + +// SetApprovalSelectedIndex sets the approval selection index +func (sm *StateManager) SetApprovalSelectedIndex(index int) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + sm.state.SetApprovalSelectedIndex(index) +} + +// ClearApprovalUIState clears the approval UI state +func (sm *StateManager) ClearApprovalUIState() { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + sm.state.ClearApprovalUIState() +} + // AddQueuedMessage adds a message to the input queue func (sm *StateManager) AddQueuedMessage(message sdk.Message, requestID string) { sm.mutex.Lock() diff --git a/internal/services/tools/bash.go b/internal/services/tools/bash.go index 7fd80445..3e42b721 100644 --- a/internal/services/tools/bash.go +++ b/internal/services/tools/bash.go @@ -160,7 +160,9 @@ func (t *BashTool) executeBash(ctx context.Context, command string) (*BashResult Command: command, } - if !t.isCommandAllowed(command) { + wasApproved, _ := ctx.Value(domain.ToolApprovedKey).(bool) + + if !wasApproved && !t.isCommandAllowed(command) { result.ExitCode = -1 result.Duration = time.Since(start).String() result.Error = fmt.Sprintf("command not whitelisted: %s", command) diff --git a/internal/services/tools/edit.go b/internal/services/tools/edit.go index ca021b90..7ed9cd5d 100644 --- a/internal/services/tools/edit.go +++ b/internal/services/tools/edit.go @@ -10,6 +10,7 @@ import ( config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" components "github.com/inference-gateway/cli/internal/ui/components" + styles "github.com/inference-gateway/cli/internal/ui/styles" sdk "github.com/inference-gateway/sdk" ) @@ -663,7 +664,9 @@ func (t *EditTool) GetDiffInfo(args map[string]any) *components.DiffInfo { // FormatArgumentsForApproval formats arguments for approval display with diff preview func (t *EditTool) FormatArgumentsForApproval(args map[string]any) string { - // Use colored diff renderer - diffRenderer := components.NewToolDiffRenderer() + // Use colored diff renderer with default theme + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + diffRenderer := components.NewToolDiffRenderer(styleProvider) return diffRenderer.RenderEditToolArguments(args) } diff --git a/internal/services/tools/grep.go b/internal/services/tools/grep.go index 496b8e2f..7f75020b 100644 --- a/internal/services/tools/grep.go +++ b/internal/services/tools/grep.go @@ -11,6 +11,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" config "github.com/inference-gateway/cli/config" @@ -25,6 +26,7 @@ type GrepTool struct { enabled bool gitignore *ignore.GitIgnore gitignoreCache map[string]*ignore.GitIgnore + cacheMutex sync.RWMutex ripgrepPath string useRipgrep bool formatter domain.BaseFormatter @@ -973,6 +975,16 @@ func (t *GrepTool) loadGitignore() { // getOrLoadDirGitignore loads and caches .gitignore for a specific directory func (t *GrepTool) getOrLoadDirGitignore(dirPath string) *ignore.GitIgnore { + t.cacheMutex.RLock() + if cached, exists := t.gitignoreCache[dirPath]; exists { + t.cacheMutex.RUnlock() + return cached + } + t.cacheMutex.RUnlock() + + t.cacheMutex.Lock() + defer t.cacheMutex.Unlock() + if cached, exists := t.gitignoreCache[dirPath]; exists { return cached } diff --git a/internal/services/tools/multiedit.go b/internal/services/tools/multiedit.go index 652cd2cc..02c1557d 100644 --- a/internal/services/tools/multiedit.go +++ b/internal/services/tools/multiedit.go @@ -10,6 +10,7 @@ import ( config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" components "github.com/inference-gateway/cli/internal/ui/components" + styles "github.com/inference-gateway/cli/internal/ui/styles" sdk "github.com/inference-gateway/sdk" ) @@ -895,6 +896,9 @@ func (t *MultiEditTool) GetDiffInfo(args map[string]any) *components.DiffInfo { // FormatArgumentsForApproval formats arguments for approval display with diff preview func (t *MultiEditTool) FormatArgumentsForApproval(args map[string]any) string { - diffRenderer := components.NewToolDiffRenderer() + // Use colored diff renderer with default theme + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + diffRenderer := components.NewToolDiffRenderer(styleProvider) return diffRenderer.RenderMultiEditToolArguments(args) } diff --git a/internal/ui/autocomplete.go b/internal/ui/autocomplete.go index b5f1973f..8988199f 100644 --- a/internal/ui/autocomplete.go +++ b/internal/ui/autocomplete.go @@ -325,41 +325,55 @@ func (a *AutocompleteImpl) Render() string { } } + // Calculate max width across ALL filtered items to prevent jumping + maxShortcutWidth := 0 + for _, cmd := range a.filtered { + width := 0 + if cmd.Usage != "" && cmd.Usage != cmd.Shortcut { + width = len(cmd.Usage) + } else { + width = len(cmd.Shortcut) + } + if width > maxShortcutWidth { + maxShortcutWidth = width + } + } + + if maxShortcutWidth < 30 { + maxShortcutWidth = 30 + } + for i := start; i < end; i++ { cmd := a.filtered[i] var prefix string + var marker string if i == a.selected { - marker := "▶" - prefix = fmt.Sprintf("%s%s%s ", a.theme.GetAccentColor(), marker, colors.Reset) + marker = "▶ " + prefix = fmt.Sprintf("%s%s%s", a.theme.GetAccentColor(), marker, colors.Reset) } else { - prefix = " " + marker = " " + prefix = marker } - var line string + var shortcutText string if cmd.Usage != "" && cmd.Usage != cmd.Shortcut { - parts := strings.SplitN(cmd.Usage, " ", 2) - shortcutName := parts[0] - usageArgs := "" - if len(parts) > 1 { - usageArgs = parts[1] - } - - line = fmt.Sprintf("%s%-20s %s%-50s%s", - prefix, - shortcutName+" "+usageArgs, - a.theme.GetDimColor(), - cmd.Description, - colors.Reset) + shortcutText = cmd.Usage } else { - line = fmt.Sprintf("%s%-20s %s%s%s", - prefix, - cmd.Shortcut, - a.theme.GetDimColor(), - cmd.Description, - colors.Reset) + shortcutText = cmd.Shortcut } + paddedShortcut := shortcutText + strings.Repeat(" ", maxShortcutWidth-len(shortcutText)) + separator := " │ " + + line := fmt.Sprintf("%s%s%s%s%s%s", + prefix, + paddedShortcut, + a.theme.GetDimColor(), + separator, + cmd.Description, + colors.Reset) + b.WriteString(line) if i < end-1 { b.WriteString("\n") diff --git a/internal/ui/autocomplete_test.go b/internal/ui/autocomplete_test.go index f2678102..0205eb52 100644 --- a/internal/ui/autocomplete_test.go +++ b/internal/ui/autocomplete_test.go @@ -53,6 +53,7 @@ type MockTheme struct{} func (m MockTheme) GetUserColor() string { return "#00FF00" } func (m MockTheme) GetAssistantColor() string { return "#0000FF" } func (m MockTheme) GetErrorColor() string { return "#FF0000" } +func (m MockTheme) GetSuccessColor() string { return "#00FF00" } func (m MockTheme) GetStatusColor() string { return "#FFFF00" } func (m MockTheme) GetAccentColor() string { return "#FF00FF" } func (m MockTheme) GetDimColor() string { return "#808080" } diff --git a/internal/ui/components.go b/internal/ui/components.go index 90e8c218..10c108ef 100644 --- a/internal/ui/components.go +++ b/internal/ui/components.go @@ -4,11 +4,13 @@ import ( domain "github.com/inference-gateway/cli/internal/domain" shortcuts "github.com/inference-gateway/cli/internal/shortcuts" components "github.com/inference-gateway/cli/internal/ui/components" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // CreateConversationView creates a new conversation view component func CreateConversationView(themeService domain.ThemeService) ConversationRenderer { - return components.NewConversationView(themeService) + styleProvider := styles.NewProvider(themeService) + return components.NewConversationView(styleProvider) } // CreateInputView creates a new input view component @@ -42,12 +44,14 @@ func CreateInputViewWithToolServiceAndConfigDir(modelService domain.ModelService // CreateStatusView creates a new status view component func CreateStatusView(themeService domain.ThemeService) StatusComponent { - return components.NewStatusView(themeService) + styleProvider := styles.NewProvider(themeService) + return components.NewStatusView(styleProvider) } // CreateHelpBar creates a new help bar component func CreateHelpBar(themeService domain.ThemeService) HelpBarComponent { - return components.NewHelpBar(themeService) + styleProvider := styles.NewProvider(themeService) + return components.NewHelpBar(styleProvider) } // Layout calculations - simplified without interfaces diff --git a/internal/ui/components/a2a_servers_view.go b/internal/ui/components/a2a_servers_view.go index 0296c101..a462b93c 100644 --- a/internal/ui/components/a2a_servers_view.go +++ b/internal/ui/components/a2a_servers_view.go @@ -6,10 +6,9 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) @@ -22,7 +21,7 @@ type A2AServersView struct { height int isLoading bool error string - themeService domain.ThemeService + styleProvider *styles.Provider } // A2AServerInfo represents information about an A2A server @@ -39,14 +38,14 @@ type A2AServerInfo struct { } // NewA2AServersView creates a new A2A servers view -func NewA2AServersView(cfg *config.Config, a2aAgentService domain.A2AAgentService, themeService domain.ThemeService) *A2AServersView { +func NewA2AServersView(cfg *config.Config, a2aAgentService domain.A2AAgentService, styleProvider *styles.Provider) *A2AServersView { return &A2AServersView{ config: cfg, a2aAgentService: a2aAgentService, servers: []A2AServerInfo{}, width: 80, height: 20, - themeService: themeService, + styleProvider: styleProvider, } } @@ -120,78 +119,51 @@ func (v *A2AServersView) Render() string { } func (v *A2AServersView) renderLoading() string { - headerColor := v.getHeaderColor() - content := headerColor + "Loading A2A servers..." + colors.Reset - - style := lipgloss.NewStyle(). - Width(v.width). - Height(v.height). - Align(lipgloss.Center, lipgloss.Center). - Border(lipgloss.RoundedBorder(), true). - BorderForeground(lipgloss.Color(v.getAccentColor())). - Padding(2, 4) - - return style.Render(content) + accentColor := v.styleProvider.GetThemeColor("accent") + content := v.styleProvider.RenderWithColor("Loading A2A servers...", accentColor) + return v.styleProvider.RenderCenteredBorderedBox(content, accentColor, v.width, v.height, 2, 4) } func (v *A2AServersView) renderError() string { - errorColor := v.getErrorColor() - dimColor := v.getDimColor() + errorColor := v.styleProvider.GetThemeColor("error") var content strings.Builder errorIcon := icons.StyledCrossMark() - content.WriteString(fmt.Sprintf("%s %sError loading A2A servers%s\n\n", errorIcon, errorColor, colors.Reset)) - content.WriteString(fmt.Sprintf("%s%s%s\n\n", dimColor, v.error, colors.Reset)) + content.WriteString(fmt.Sprintf("%s %s\n\n", errorIcon, v.styleProvider.RenderWithColor("Error loading A2A servers", errorColor))) + content.WriteString(fmt.Sprintf("%s\n\n", v.styleProvider.RenderDimText(v.error))) - content.WriteString(fmt.Sprintf("%sMake sure:%s\n", dimColor, colors.Reset)) + content.WriteString(fmt.Sprintf("%s\n", v.styleProvider.RenderDimText("Make sure:"))) content.WriteString("• The Gateway is running and accessible\n") content.WriteString("• A2A middleware is exposed (EXPOSE_A2A=true)\n") content.WriteString("• Your API key is valid") - style := lipgloss.NewStyle(). - Width(v.width). - Height(v.height). - Align(lipgloss.Left, lipgloss.Center). - Border(lipgloss.RoundedBorder(), true). - BorderForeground(lipgloss.Color(v.getErrorColor())). - Padding(2, 4) - - return style.Render(content.String()) + return v.styleProvider.RenderLeftAlignedBorderedBox(content.String(), errorColor, v.width, v.height, 2, 4) } func (v *A2AServersView) renderEmpty() string { - warningColor := v.getWarningColor() - dimColor := v.getDimColor() + warningColor := v.styleProvider.GetThemeColor("warning") var content strings.Builder - content.WriteString(fmt.Sprintf("%sNo A2A agents available%s\n\n", warningColor, colors.Reset)) + content.WriteString(fmt.Sprintf("%s\n\n", v.styleProvider.RenderWithColor("No A2A agents available", warningColor))) - content.WriteString(fmt.Sprintf("%sAgents are starting or not configured in .infer/agents.yaml%s\n\n", dimColor, colors.Reset)) + content.WriteString(fmt.Sprintf("%s\n\n", v.styleProvider.RenderDimText("Agents are starting or not configured in .infer/agents.yaml"))) content.WriteString("Available A2A tools:\n") content.WriteString("• SubmitTask: Submit tasks to A2A agents\n") content.WriteString("• QueryTask: Query task status and results\n") content.WriteString("• QueryAgent: Query agent capabilities and information\n") content.WriteString("• DownloadArtifacts: Download artifacts from completed tasks") - style := lipgloss.NewStyle(). - Width(v.width). - Height(v.height). - Align(lipgloss.Left, lipgloss.Center). - Border(lipgloss.RoundedBorder(), true). - BorderForeground(lipgloss.Color(v.getWarningColor())). - Padding(2, 4) - - return style.Render(content.String()) + return v.styleProvider.RenderLeftAlignedBorderedBox(content.String(), warningColor, v.width, v.height, 2, 4) } func (v *A2AServersView) renderServers() string { - headerColor := v.getHeaderColor() - successColor := v.getSuccessColor() + accentColor := v.styleProvider.GetThemeColor("accent") + successColor := v.styleProvider.GetThemeColor("success") var content strings.Builder - content.WriteString(fmt.Sprintf("%sA2A Agent Servers%s\n\n", headerColor, colors.Reset)) - content.WriteString(fmt.Sprintf("%sFound %d agent card(s)%s\n\n", successColor, len(v.servers), colors.Reset)) + content.WriteString(fmt.Sprintf("%s\n\n", v.styleProvider.RenderBold("Agent Servers"))) + content.WriteString(fmt.Sprintf("%s\n\n", v.styleProvider.RenderWithColor(fmt.Sprintf("Found %d agent card(s)", len(v.servers)), successColor))) for i, server := range v.servers { content.WriteString(v.renderSingleServer(server)) @@ -202,63 +174,52 @@ func (v *A2AServersView) renderServers() string { content.WriteString(v.renderConnectionInfo()) - style := lipgloss.NewStyle(). - Width(v.width). - Height(v.height). - Align(lipgloss.Left, lipgloss.Top). - Border(lipgloss.RoundedBorder(), true). - BorderForeground(lipgloss.Color(v.getAccentColor())). - Padding(1, 2) - - return style.Render(content.String()) + return v.styleProvider.RenderTopAlignedBorderedBox(content.String(), accentColor, v.width, v.height, 1, 2) } func (v *A2AServersView) renderSingleServer(server A2AServerInfo) string { successIcon := icons.StyledCheckMark() - dimColor := v.getDimColor() - headerColor := v.getHeaderColor() + accentColor := v.styleProvider.GetThemeColor("accent") var content strings.Builder - content.WriteString(fmt.Sprintf("%s%s%s %s (%s%s%s)\n", - colors.Bold, server.Name, colors.Reset, successIcon, dimColor, server.ID, colors.Reset)) + content.WriteString(fmt.Sprintf("%s %s\n", + v.styleProvider.RenderBold(server.Name), successIcon)) if server.Description != "" { content.WriteString(fmt.Sprintf(" %s\n", server.Description)) } if server.DocumentsURL != nil && *server.DocumentsURL != "" { - content.WriteString(fmt.Sprintf(" %sDocs:%s %s\n", headerColor, colors.Reset, *server.DocumentsURL)) + content.WriteString(fmt.Sprintf(" %s %s\n", v.styleProvider.RenderWithColor("Docs:", accentColor), *server.DocumentsURL)) } if len(server.InputModes) > 0 { - content.WriteString(fmt.Sprintf(" %sInput:%s %s\n", - headerColor, colors.Reset, strings.Join(server.InputModes, ", "))) + content.WriteString(fmt.Sprintf(" %s %s\n", + v.styleProvider.RenderWithColor("Input:", accentColor), strings.Join(server.InputModes, ", "))) } if len(server.OutputModes) > 0 { - content.WriteString(fmt.Sprintf(" %sOutput:%s %s\n", - headerColor, colors.Reset, strings.Join(server.OutputModes, ", "))) + content.WriteString(fmt.Sprintf(" %s %s\n", + v.styleProvider.RenderWithColor("Output:", accentColor), strings.Join(server.OutputModes, ", "))) } if server.URL != "" { - content.WriteString(fmt.Sprintf(" %sURL:%s %s%s%s\n", - headerColor, colors.Reset, dimColor, server.URL, colors.Reset)) + content.WriteString(fmt.Sprintf(" %s %s\n", + v.styleProvider.RenderWithColor("URL:", accentColor), v.styleProvider.RenderDimText(server.URL))) } return content.String() } func (v *A2AServersView) renderConnectionInfo() string { - dimColor := v.getDimColor() - var content strings.Builder content.WriteString("\n") - content.WriteString(fmt.Sprintf("%s━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━%s\n", dimColor, colors.Reset)) + content.WriteString(v.styleProvider.RenderDimText("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + "\n") - content.WriteString(fmt.Sprintf("A2A Connection Mode%s\n", colors.Reset)) + content.WriteString("A2A Connection Mode\n") content.WriteString("\n") - content.WriteString(fmt.Sprintf("%sPress ESC to return to chat%s", dimColor, colors.Reset)) + content.WriteString(v.styleProvider.RenderDimText("Press ESC to return to chat")) return content.String() } @@ -290,46 +251,3 @@ type A2AServersLoadedMsg struct { servers []A2AServerInfo error string } - -// Helper methods to get theme colors with fallbacks -func (v *A2AServersView) getHeaderColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.HeaderColor.ANSI -} - -func (v *A2AServersView) getSuccessColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetStatusColor() - } - return colors.SuccessColor.ANSI -} - -func (v *A2AServersView) getErrorColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetErrorColor() - } - return colors.ErrorColor.ANSI -} - -func (v *A2AServersView) getWarningColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetErrorColor() - } - return colors.WarningColor.ANSI -} - -func (v *A2AServersView) getAccentColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.AccentColor.ANSI -} - -func (v *A2AServersView) getDimColor() string { - if v.themeService != nil { - return v.themeService.GetCurrentTheme().GetDimColor() - } - return colors.DimColor.ANSI -} diff --git a/internal/ui/components/application_view.go b/internal/ui/components/application_view.go index 73c6338b..54842824 100644 --- a/internal/ui/components/application_view.go +++ b/internal/ui/components/application_view.go @@ -4,21 +4,20 @@ import ( "fmt" "strings" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // ApplicationViewRenderer handles rendering of different application views type ApplicationViewRenderer struct { - themeService domain.ThemeService + styleProvider *styles.Provider } // NewApplicationViewRenderer creates a new application view renderer -func NewApplicationViewRenderer(themeService domain.ThemeService) *ApplicationViewRenderer { +func NewApplicationViewRenderer(styleProvider *styles.Provider) *ApplicationViewRenderer { return &ApplicationViewRenderer{ - themeService: themeService, + styleProvider: styleProvider, } } @@ -79,29 +78,17 @@ func (r *ApplicationViewRenderer) RenderChatInterface( queueBoxView.SetWidth(width) } - headerStyle := lipgloss.NewStyle(). - Width(width). - Align(lipgloss.Center). - Foreground(lipgloss.Color(r.themeService.GetCurrentTheme().GetAccentColor())). - Bold(true). - Padding(0, 1) - headerText := "" if len(data.BackgroundTasks) > 0 { headerText = fmt.Sprintf("(%d)", len(data.BackgroundTasks)) } - header := headerStyle.Render(headerText) + accentColor := r.styleProvider.GetThemeColor("accent") + header := r.styleProvider.RenderCenteredBoldWithColor(headerText, accentColor, width) headerBorder := "" - conversationStyle := lipgloss.NewStyle(). - Height(conversationHeight) - - inputStyle := lipgloss.NewStyle(). - Width(width) - - conversationArea := conversationStyle.Render(conversationView.Render()) - separator := colors.CreateSeparator(width, "─") - inputArea := inputStyle.Render(inputView.Render()) + conversationArea := conversationView.Render() + separator := strings.Repeat("─", width) + inputArea := inputView.Render() components := []string{header, headerBorder, conversationArea, separator} @@ -124,16 +111,12 @@ func (r *ApplicationViewRenderer) RenderChatInterface( helpBar.SetWidth(width) helpBarContent := helpBar.Render() if helpBarContent != "" { - separator := colors.CreateSeparator(width, "─") + separator := strings.Repeat("─", width) components = append(components, separator) - - helpBarStyle := lipgloss.NewStyle(). - Width(width). - Padding(1, 1) - components = append(components, helpBarStyle.Render(helpBarContent)) + components = append(components, helpBarContent) } - return lipgloss.JoinVertical(lipgloss.Left, components...) + return strings.Join(components, "\n") } // FileSelectionData holds the data needed to render the file selection view diff --git a/internal/ui/components/approval_component.go b/internal/ui/components/approval_component.go new file mode 100644 index 00000000..4e417e5a --- /dev/null +++ b/internal/ui/components/approval_component.go @@ -0,0 +1,194 @@ +package components + +import ( + "encoding/json" + "fmt" + "strings" + + domain "github.com/inference-gateway/cli/internal/domain" + styles "github.com/inference-gateway/cli/internal/ui/styles" +) + +// ToolFormatterService interface for formatting tool arguments +type ToolFormatterService interface { + FormatToolArgumentsForApproval(toolName string, args map[string]any) string +} + +// ApprovalComponent renders the tool approval modal +type ApprovalComponent struct { + width int + height int + toolFormatter ToolFormatterService + styleProvider *styles.Provider +} + +// NewApprovalComponent creates a new approval component +func NewApprovalComponent(styleProvider *styles.Provider) *ApprovalComponent { + return &ApprovalComponent{ + styleProvider: styleProvider, + } +} + +// SetDimensions updates the component dimensions +func (c *ApprovalComponent) SetDimensions(width, height int) { + c.width = width + c.height = height +} + +// SetToolFormatter sets the tool formatter service +func (c *ApprovalComponent) SetToolFormatter(formatter ToolFormatterService) { + c.toolFormatter = formatter +} + +// Render renders the approval modal +func (c *ApprovalComponent) Render(approvalState *domain.ApprovalUIState, theme domain.Theme) string { + if approvalState == nil || approvalState.PendingToolCall == nil { + return "" + } + + toolCall := approvalState.PendingToolCall + + modalWidth := min(c.width-4, 80) + + title := c.styleProvider.RenderStyledText("🔒 Tool Approval Required", styles.StyleOptions{ + Foreground: c.styleProvider.GetThemeColor("accent"), + Bold: true, + MarginBottom: 1, + }) + + toolNameStyled := c.styleProvider.RenderWithColorAndBold(toolCall.Function.Name, c.styleProvider.GetThemeColor("accent")) + toolName := fmt.Sprintf("Tool: %s", toolNameStyled) + + args := c.formatArguments(toolCall.Function.Name, toolCall.Function.Arguments, modalWidth-4) + + options := c.renderOptions(approvalState.SelectedIndex) + + helpText := c.styleProvider.RenderStyledText( + "←/→: Navigate • Enter/y: Approve • n/Esc: Reject", + styles.StyleOptions{ + Foreground: c.styleProvider.GetThemeColor("dim"), + Italic: true, + MarginTop: 1, + }, + ) + + content := c.styleProvider.JoinVertical( + title, + toolName, + args, + "", + c.styleProvider.PlaceCenterTop(modalWidth, c.styleProvider.GetHeight(options), options), + helpText, + ) + + modal := c.styleProvider.RenderModal(content, modalWidth) + return c.styleProvider.PlaceCenter(c.width, c.height, modal) +} + +// formatArguments formats tool arguments for display +func (c *ApprovalComponent) formatArguments(toolName, argsJSON string, maxWidth int) string { + if argsJSON == "" { + return "" + } + + var args map[string]any + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return "" + } + + if len(args) == 0 { + return "" + } + + if c.toolFormatter != nil { + formatted := c.toolFormatter.FormatToolArgumentsForApproval(toolName, args) + if formatted != "" { + return "\n" + formatted + } + } + + var argLines []string + argLines = append(argLines, "\nArguments:") + + keys := make([]string, 0, len(args)) + for key := range args { + keys = append(keys, key) + } + + for i := 0; i < len(keys); i++ { + for j := i + 1; j < len(keys); j++ { + if keys[i] > keys[j] { + keys[i], keys[j] = keys[j], keys[i] + } + } + } + + for _, key := range keys { + valueStr := c.formatValue(args[key], maxWidth-len(key)-4) + line := fmt.Sprintf(" %s: %s", key, valueStr) + argLines = append(argLines, line) + } + + argsText := strings.Join(argLines, "\n") + return c.styleProvider.RenderStyledText(argsText, styles.StyleOptions{ + Foreground: c.styleProvider.GetThemeColor("dim"), + Width: maxWidth, + }) +} + +// formatValue formats a single argument value, truncating if necessary +func (c *ApprovalComponent) formatValue(value any, maxLen int) string { + var str string + + switch v := value.(type) { + case string: + str = v + case map[string]any, []any: + jsonBytes, _ := json.Marshal(v) + str = string(jsonBytes) + default: + str = fmt.Sprintf("%v", v) + } + + if len(str) > maxLen { + if maxLen > 3 { + str = str[:maxLen-3] + "..." + } else { + str = str[:maxLen] + } + } + + str = strings.ReplaceAll(str, "\n", " ") + str = strings.ReplaceAll(str, "\r", "") + + return str +} + +// renderOptions renders the Approve/Reject options +func (c *ApprovalComponent) renderOptions(selectedIndex int) string { + isApproveSelected := selectedIndex == int(domain.ApprovalApprove) + isRejectSelected := selectedIndex == int(domain.ApprovalReject) + + approveText := " Approve" + if isApproveSelected { + approveText = "✓ Approve" + } + + rejectText := " Reject" + if isRejectSelected { + rejectText = "✗ Reject" + } + + approveButton := c.styleProvider.RenderApprovalButton(approveText, isApproveSelected, true) + rejectButton := c.styleProvider.RenderApprovalButton(rejectText, isRejectSelected, false) + + return c.styleProvider.JoinHorizontal(approveButton, " ", rejectButton) +} + +// min returns the minimum of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/ui/components/conversation_selection_view.go b/internal/ui/components/conversation_selection_view.go index 89e244a7..0018f878 100644 --- a/internal/ui/components/conversation_selection_view.go +++ b/internal/ui/components/conversation_selection_view.go @@ -11,7 +11,7 @@ import ( domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" shortcuts "github.com/inference-gateway/cli/internal/shortcuts" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // ConversationSelectorImpl implements conversation selection UI @@ -21,7 +21,7 @@ type ConversationSelectorImpl struct { selected int width int height int - themeService domain.ThemeService + styleProvider *styles.Provider done bool cancelled bool repo shortcuts.PersistentConversationRepository @@ -34,14 +34,14 @@ type ConversationSelectorImpl struct { } // NewConversationSelector creates a new conversation selector -func NewConversationSelector(repo shortcuts.PersistentConversationRepository, themeService domain.ThemeService) *ConversationSelectorImpl { +func NewConversationSelector(repo shortcuts.PersistentConversationRepository, styleProvider *styles.Provider) *ConversationSelectorImpl { c := &ConversationSelectorImpl{ conversations: make([]shortcuts.ConversationSummary, 0), filteredConversations: make([]shortcuts.ConversationSummary, 0), selected: 0, width: 80, height: 24, - themeService: themeService, + styleProvider: styleProvider, repo: repo, searchQuery: "", searchMode: false, @@ -377,43 +377,42 @@ func (c *ConversationSelectorImpl) Reset() { // writeHeader writes the header section of the view func (c *ConversationSelectorImpl) writeHeader(b *strings.Builder) { - fmt.Fprintf(b, "%sSelect a Conversation%s\n\n", - c.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset) + fmt.Fprintf(b, "%s\n\n", c.styleProvider.RenderWithColor("Select a Conversation", c.styleProvider.GetThemeColor("accent"))) } // writeLoadingView writes the loading view and returns the complete string func (c *ConversationSelectorImpl) writeLoadingView(b *strings.Builder) string { - fmt.Fprintf(b, "%sLoading conversations...%s\n", - c.themeService.GetCurrentTheme().GetStatusColor(), colors.Reset) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderWithColor("Loading conversations...", c.styleProvider.GetThemeColor("status"))) return b.String() } // writeErrorView writes the error view and returns the complete string func (c *ConversationSelectorImpl) writeErrorView(b *strings.Builder) string { - fmt.Fprintf(b, "%sError loading conversations: %v%s\n", - c.themeService.GetCurrentTheme().GetErrorColor(), c.loadError, colors.Reset) + errorMsg := fmt.Sprintf("Error loading conversations: %v", c.loadError) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderWithColor(errorMsg, c.styleProvider.GetThemeColor("error"))) return b.String() } // writeSearchInfo writes the search information section func (c *ConversationSelectorImpl) writeSearchInfo(b *strings.Builder) { if c.searchMode { - fmt.Fprintf(b, "%sSearch: %s%s│%s\n\n", - c.themeService.GetCurrentTheme().GetStatusColor(), c.searchQuery, c.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset) + fmt.Fprintf(b, "%s%s\n\n", + c.styleProvider.RenderWithColor("Search: "+c.searchQuery, c.styleProvider.GetThemeColor("status")), + c.styleProvider.RenderWithColor("│", c.styleProvider.GetThemeColor("accent"))) } else { - fmt.Fprintf(b, "%sPress / to search • %d conversations available%s\n\n", - c.themeService.GetCurrentTheme().GetDimColor(), len(c.conversations), colors.Reset) + helpText := fmt.Sprintf("Press / to search • %d conversations available", len(c.conversations)) + fmt.Fprintf(b, "%s\n\n", c.styleProvider.RenderDimText(helpText)) } } // writeEmptyView writes the empty view and returns the complete string func (c *ConversationSelectorImpl) writeEmptyView(b *strings.Builder) string { if c.searchQuery != "" { - fmt.Fprintf(b, "%sNo conversations match '%s'%s\n", - c.themeService.GetCurrentTheme().GetErrorColor(), c.searchQuery, colors.Reset) + msg := fmt.Sprintf("No conversations match '%s'", c.searchQuery) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderWithColor(msg, c.styleProvider.GetThemeColor("error"))) } else if len(c.conversations) == 0 { - fmt.Fprintf(b, "%sNo saved conversations found. Start chatting to create your first conversation!%s\n", - c.themeService.GetCurrentTheme().GetErrorColor(), colors.Reset) + msg := "No saved conversations found. Start chatting to create your first conversation!" + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderWithColor(msg, c.styleProvider.GetThemeColor("error"))) } return b.String() } @@ -430,18 +429,20 @@ func (c *ConversationSelectorImpl) writeConversationList(b *strings.Builder) { } if len(c.filteredConversations) > pagination.maxVisible { - fmt.Fprintf(b, "%sShowing %d-%d of %d conversations%s\n", - c.themeService.GetCurrentTheme().GetDimColor(), pagination.start+1, pagination.start+pagination.maxVisible, - len(c.filteredConversations), colors.Reset) + paginationText := fmt.Sprintf("Showing %d-%d of %d conversations", + pagination.start+1, pagination.start+pagination.maxVisible, len(c.filteredConversations)) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderDimText(paginationText)) } } // writeTableHeader writes the table header func (c *ConversationSelectorImpl) writeTableHeader(b *strings.Builder) { - fmt.Fprintf(b, "%s%-38s │ %-25s │ %-20s │ %-10s │ %-12s%s\n", - c.themeService.GetCurrentTheme().GetDimColor(), "ID", "Summary", "Updated", "Messages", "Input Tokens", colors.Reset) - fmt.Fprintf(b, "%s%s%s\n", - c.themeService.GetCurrentTheme().GetDimColor(), strings.Repeat("─", c.width-4), colors.Reset) + headerLine := fmt.Sprintf("%-38s │ %-25s │ %-20s │ %-10s │ %-12s", + "ID", "Summary", "Updated", "Messages", "Input Tokens") + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderDimText(headerLine)) + + separator := strings.Repeat("─", c.width-4) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderDimText(separator)) } // paginationInfo holds pagination calculation results @@ -483,8 +484,10 @@ func (c *ConversationSelectorImpl) writeConversationRow(b *strings.Builder, conv inputTokens := fmt.Sprintf("%d", conv.TokenStats.TotalInputTokens) if index == c.selected { - fmt.Fprintf(b, "%s▶ %-36s │ %-25s │ %-20s │ %-10s │ %-12s%s\n", - c.themeService.GetCurrentTheme().GetAccentColor(), fullID, summary, updatedAt, msgCount, inputTokens, colors.Reset) + accentColor := c.styleProvider.GetThemeColor("accent") + rowText := fmt.Sprintf("▶ %-36s │ %-25s │ %-20s │ %-10s │ %-12s", + fullID, summary, updatedAt, msgCount, inputTokens) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderWithColor(rowText, accentColor)) } else { fmt.Fprintf(b, " %-36s │ %-25s │ %-20s │ %-10s │ %-12s\n", fullID, summary, updatedAt, msgCount, inputTokens) @@ -544,15 +547,15 @@ func (c *ConversationSelectorImpl) formatDateTimeParts(updatedAt string) string // writeFooter writes the footer section func (c *ConversationSelectorImpl) writeFooter(b *strings.Builder) { b.WriteString("\n") - b.WriteString(colors.CreateSeparator(c.width, "─")) + b.WriteString(strings.Repeat("─", c.width)) b.WriteString("\n") if c.searchMode { - fmt.Fprintf(b, "%sType to search, ↑↓ to navigate, Enter to select, Esc to clear search%s", - c.themeService.GetCurrentTheme().GetDimColor(), colors.Reset) + helpText := "Type to search, ↑↓ to navigate, Enter to select, Esc to clear search" + fmt.Fprintf(b, "%s", c.styleProvider.RenderDimText(helpText)) } else { - fmt.Fprintf(b, "%sUse ↑↓ arrows to navigate, Enter to select, d to delete, / to search, Esc/Ctrl+C to cancel%s", - c.themeService.GetCurrentTheme().GetDimColor(), colors.Reset) + helpText := "Use ↑↓ arrows to navigate, Enter to select, d to delete, / to search, Esc/Ctrl+C to cancel" + fmt.Fprintf(b, "%s", c.styleProvider.RenderDimText(helpText)) } } @@ -568,24 +571,24 @@ func (c *ConversationSelectorImpl) writeDeleteConfirmation(b *strings.Builder) s c.writeConversationList(b) b.WriteString("\n") - b.WriteString(colors.CreateSeparator(c.width, "─")) + b.WriteString(strings.Repeat("─", c.width)) b.WriteString("\n\n") - fmt.Fprintf(b, "%s⚠ Delete Confirmation%s\n\n", - c.themeService.GetCurrentTheme().GetErrorColor(), colors.Reset) + errorColor := c.styleProvider.GetThemeColor("error") + accentColor := c.styleProvider.GetThemeColor("accent") + fmt.Fprintf(b, "%s\n\n", c.styleProvider.RenderWithColor("⚠ Delete Confirmation", errorColor)) fmt.Fprintf(b, "Are you sure you want to delete this conversation?\n\n") - fmt.Fprintf(b, "%sID: %s%s\n", c.themeService.GetCurrentTheme().GetDimColor(), conv.ID, colors.Reset) - fmt.Fprintf(b, "%sTitle: %s%s\n\n", c.themeService.GetCurrentTheme().GetDimColor(), conv.Title, colors.Reset) - - fmt.Fprintf(b, "%sPress Y to confirm, N or Esc to cancel%s", - c.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset) + fmt.Fprintf(b, "%s\n", c.styleProvider.RenderDimText("ID: "+conv.ID)) + fmt.Fprintf(b, "%s\n\n", c.styleProvider.RenderDimText("Title: "+conv.Title)) + fmt.Fprintf(b, "%s", c.styleProvider.RenderWithColor("Press Y to confirm, N or Esc to cancel", accentColor)) return b.String() } // writeDeleteError writes the delete error message func (c *ConversationSelectorImpl) writeDeleteError(b *strings.Builder) { - fmt.Fprintf(b, "%sError deleting conversation: %v%s\n\n", - c.themeService.GetCurrentTheme().GetErrorColor(), c.deleteError, colors.Reset) + errorColor := c.styleProvider.GetThemeColor("error") + errorMsg := fmt.Sprintf("Error deleting conversation: %v", c.deleteError) + fmt.Fprintf(b, "%s\n\n", c.styleProvider.RenderWithColor(errorMsg, errorColor)) } diff --git a/internal/ui/components/conversation_selection_view_test.go b/internal/ui/components/conversation_selection_view_test.go index 366a02df..8ae0e9ce 100644 --- a/internal/ui/components/conversation_selection_view_test.go +++ b/internal/ui/components/conversation_selection_view_test.go @@ -6,13 +6,15 @@ import ( domain "github.com/inference-gateway/cli/internal/domain" shortcuts "github.com/inference-gateway/cli/internal/shortcuts" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) func TestConversationSelectorImpl_Reset(t *testing.T) { mockRepo := &mockPersistentConversationRepository{} themeService := &mockThemeService{} + styleProvider := styles.NewProvider(themeService) - selector := NewConversationSelector(mockRepo, themeService) + selector := NewConversationSelector(mockRepo, styleProvider) selector.done = true selector.cancelled = true @@ -60,8 +62,9 @@ func TestConversationSelectorImpl_Reset(t *testing.T) { func TestConversationSelectorImpl_ResetAllowsReuse(t *testing.T) { mockRepo := &mockPersistentConversationRepository{} themeService := &mockThemeService{} + styleProvider := styles.NewProvider(themeService) - selector := NewConversationSelector(mockRepo, themeService) + selector := NewConversationSelector(mockRepo, styleProvider) selector.cancelled = true selector.done = true @@ -131,6 +134,7 @@ type mockTheme struct{} func (t *mockTheme) GetUserColor() string { return "" } func (t *mockTheme) GetAssistantColor() string { return "" } func (t *mockTheme) GetErrorColor() string { return "" } +func (t *mockTheme) GetSuccessColor() string { return "" } func (t *mockTheme) GetStatusColor() string { return "" } func (t *mockTheme) GetAccentColor() string { return "" } func (t *mockTheme) GetDimColor() string { return "" } diff --git a/internal/ui/components/conversation_view.go b/internal/ui/components/conversation_view.go index ec0f9d3e..d7b39db4 100644 --- a/internal/ui/components/conversation_view.go +++ b/internal/ui/components/conversation_view.go @@ -10,11 +10,9 @@ import ( viewport "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" styles "github.com/inference-gateway/cli/internal/ui/styles" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" sdk "github.com/inference-gateway/sdk" ) @@ -30,12 +28,12 @@ type ConversationView struct { lineFormatter *shared.ConversationLineFormatter plainTextLines []string configPath string - themeService domain.ThemeService + styleProvider *styles.Provider isStreaming bool toolCallRenderer *ToolCallRenderer } -func NewConversationView(themeService domain.ThemeService) *ConversationView { +func NewConversationView(styleProvider *styles.Provider) *ConversationView { vp := viewport.New(80, 20) vp.SetContent("") return &ConversationView{ @@ -47,7 +45,7 @@ func NewConversationView(themeService domain.ThemeService) *ConversationView { allToolsExpanded: false, lineFormatter: shared.NewConversationLineFormatter(80, nil), plainTextLines: []string{}, - themeService: themeService, + styleProvider: styleProvider, } } @@ -189,25 +187,22 @@ func (cv *ConversationView) renderWelcome() string { wd = "unknown" } - statusColor := cv.getStatusColor() - successColor := cv.getSuccessColor() - dimColor := cv.getDimColor() + statusColor := cv.styleProvider.GetThemeColor("status") + successColor := cv.styleProvider.GetThemeColor("success") + dimColor := cv.styleProvider.GetThemeColor("dim") headerColor := cv.getHeaderColor() - headerLine := statusColor + "✨ Inference Gateway CLI" + colors.Reset - readyLine := successColor + "🚀 Ready to chat!" + colors.Reset - workingLine := dimColor + "📂 Working in: " + colors.Reset + headerColor + wd + colors.Reset + headerLine := cv.styleProvider.RenderWithColor("✨ Inference Gateway CLI", statusColor) + readyLine := cv.styleProvider.RenderWithColor("🚀 Ready to chat!", successColor) + workingLinePrefix := cv.styleProvider.RenderWithColor("📂 Working in: ", dimColor) + workingLinePath := cv.styleProvider.RenderWithColor(wd, headerColor) + workingLine := workingLinePrefix + workingLinePath configLine := cv.buildConfigLine() content := headerLine + "\n\n" + readyLine + "\n\n" + workingLine + "\n\n" + configLine - style := styles.NewCommonStyles().Border. - Border(styles.RoundedBorder(), true). - BorderForeground(lipgloss.Color(cv.getAccentColorLipgloss())). - Padding(1, 1) - - return style.Render(content) + return cv.styleProvider.RenderBorderedBox(content, cv.styleProvider.GetThemeColor("accent"), 1, 1) } func (cv *ConversationView) renderEntryWithIndex(entry domain.ConversationEntry, index int) string { @@ -229,22 +224,22 @@ func (cv *ConversationView) renderEntryWithIndex(entry domain.ConversationEntry, return cv.renderAssistantWithToolCalls(entry, index, color, role) } case "system": - color = cv.getDimColor() + color = cv.styleProvider.GetThemeColor("dim") role = "⚙️ System" case "tool": if entry.ToolExecution != nil && !entry.ToolExecution.Success { - color = cv.getErrorColor() + color = cv.styleProvider.GetThemeColor("error") role = "🔧 Tool" } else if entry.ToolExecution != nil && entry.ToolExecution.Success { - color = cv.getSuccessColor() + color = cv.styleProvider.GetThemeColor("success") role = "🔧 Tool" } else { - color = cv.getAccentColor() + color = cv.styleProvider.GetThemeColor("accent") role = "🔧 Tool" } return cv.renderToolEntry(entry, index, color, role) default: - color = cv.getDimColor() + color = cv.styleProvider.GetThemeColor("dim") role = string(entry.Message.Role) } @@ -254,7 +249,9 @@ func (cv *ConversationView) renderEntryWithIndex(entry domain.ConversationEntry, content := entry.Message.Content wrappedContent := shared.FormatResponsiveMessage(content, cv.width) - message := fmt.Sprintf("%s%s:%s %s", color, role, colors.Reset, wrappedContent) + + roleStyled := cv.styleProvider.RenderWithColor(role+":", color) + message := roleStyled + " " + wrappedContent return message + "\n" } @@ -262,15 +259,17 @@ func (cv *ConversationView) renderEntryWithIndex(entry domain.ConversationEntry, func (cv *ConversationView) renderAssistantWithToolCalls(entry domain.ConversationEntry, _ int, color, role string) string { var result strings.Builder + roleStyled := cv.styleProvider.RenderWithColor(role+":", color) + if entry.Message.Content != "" { wrappedContent := shared.FormatResponsiveMessage(entry.Message.Content, cv.width) - result.WriteString(fmt.Sprintf("%s%s:%s %s\n", color, role, colors.Reset, wrappedContent)) + result.WriteString(roleStyled + " " + wrappedContent + "\n") } else { - result.WriteString(fmt.Sprintf("%s%s:%s\n", color, role, colors.Reset)) + result.WriteString(roleStyled + "\n") } if entry.Message.ToolCalls != nil && len(*entry.Message.ToolCalls) > 0 { // nolint:nestif - toolCallsColor := cv.getAccentColor() + toolCallsColor := cv.styleProvider.GetThemeColor("accent") for _, toolCall := range *entry.Message.ToolCalls { toolName := toolCall.Function.Name @@ -283,11 +282,11 @@ func (cv *ConversationView) renderAssistantWithToolCalls(entry domain.Conversati } else { argsDisplay = toolArgs } - result.WriteString(fmt.Sprintf(" • %s%s%s: %s\n", - toolCallsColor, toolName, colors.Reset, argsDisplay)) + toolNameStyled := cv.styleProvider.RenderWithColor(toolName, toolCallsColor) + result.WriteString(fmt.Sprintf(" • %s: %s\n", toolNameStyled, argsDisplay)) } else { - result.WriteString(fmt.Sprintf(" • %s%s%s\n", - toolCallsColor, toolName, colors.Reset)) + toolNameStyled := cv.styleProvider.RenderWithColor(toolName, toolCallsColor) + result.WriteString(fmt.Sprintf(" • %s\n", toolNameStyled)) } } } @@ -308,7 +307,9 @@ func (cv *ConversationView) renderToolEntry(entry domain.ConversationEntry, inde } content := cv.formatEntryContent(entry, isExpanded) - message := fmt.Sprintf("%s%s:%s %s", color, role, colors.Reset, content) + + roleStyled := cv.styleProvider.RenderWithColor(role+":", color) + message := roleStyled + " " + content return message + "\n" } @@ -419,10 +420,14 @@ func (cv *ConversationView) buildConfigLine() string { configType := cv.getConfigType() displayPath := cv.shortenPath(cv.configPath) - dimColor := cv.getDimColor() - accentColor := cv.getAccentColor() + dimColor := cv.styleProvider.GetThemeColor("dim") + accentColor := cv.styleProvider.GetThemeColor("accent") - return dimColor + "⚙ Config: " + colors.Reset + accentColor + displayPath + colors.Reset + dimColor + " (" + configType + ")" + colors.Reset + configPrefix := cv.styleProvider.RenderWithColor("⚙ Config: ", dimColor) + pathStyled := cv.styleProvider.RenderWithColor(displayPath, accentColor) + configTypeStyled := cv.styleProvider.RenderWithColor(" ("+configType+")", dimColor) + + return configPrefix + pathStyled + configTypeStyled } // getConfigType determines if the config is project-level or userspace @@ -531,63 +536,15 @@ func (cv *ConversationView) handleScrollRequest(msg domain.ScrollRequestEvent) ( // Helper methods to get theme colors with fallbacks func (cv *ConversationView) getUserColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetUserColor() - } - return colors.UserColor.ANSI + return cv.styleProvider.GetThemeColor("user") } func (cv *ConversationView) getAssistantColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetAssistantColor() - } - return colors.AssistantColor.ANSI -} - -func (cv *ConversationView) getErrorColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetErrorColor() - } - return colors.ErrorColor.ANSI -} - -func (cv *ConversationView) getStatusColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetStatusColor() - } - return colors.StatusColor.ANSI -} - -func (cv *ConversationView) getSuccessColor() string { - return colors.SuccessColor.ANSI -} - -func (cv *ConversationView) getAccentColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.AccentColor.ANSI -} - -func (cv *ConversationView) getDimColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetDimColor() - } - return colors.DimColor.ANSI + return cv.styleProvider.GetThemeColor("assistant") } func (cv *ConversationView) getHeaderColor() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.HeaderColor.ANSI -} - -func (cv *ConversationView) getAccentColorLipgloss() string { - if cv.themeService != nil { - return cv.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.AccentColor.Lipgloss + return cv.styleProvider.GetThemeColor("accent") } // appendStreamingContent appends streaming content to the last assistant message diff --git a/internal/ui/components/conversation_view_test.go b/internal/ui/components/conversation_view_test.go index c91d7c18..10b3b9f6 100644 --- a/internal/ui/components/conversation_view_test.go +++ b/internal/ui/components/conversation_view_test.go @@ -5,11 +5,17 @@ import ( "time" domain "github.com/inference-gateway/cli/internal/domain" + styles "github.com/inference-gateway/cli/internal/ui/styles" sdk "github.com/inference-gateway/sdk" ) +// createMockStyleProvider creates a mock styles provider for testing +func createMockStyleProvider() *styles.Provider { + return styles.NewProvider(&mockThemeService{}) +} + func TestNewConversationView(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) if cv.width != 80 { t.Errorf("Expected default width 80, got %d", cv.width) @@ -33,7 +39,7 @@ func TestNewConversationView(t *testing.T) { } func TestConversationView_SetConversation(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) conversation := []domain.ConversationEntry{ { @@ -68,7 +74,7 @@ func TestConversationView_SetConversation(t *testing.T) { } func TestConversationView_GetScrollOffset(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) offset := cv.GetScrollOffset() @@ -78,7 +84,7 @@ func TestConversationView_GetScrollOffset(t *testing.T) { } func TestConversationView_CanScrollUp(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) if cv.CanScrollUp() { t.Error("Expected CanScrollUp to be false when at top") @@ -86,7 +92,7 @@ func TestConversationView_CanScrollUp(t *testing.T) { } func TestConversationView_CanScrollDown(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) if cv.CanScrollDown() { t.Error("Expected CanScrollDown to be false with no content") @@ -94,7 +100,7 @@ func TestConversationView_CanScrollDown(t *testing.T) { } func TestConversationView_ToggleToolResultExpansion(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) conversation := []domain.ConversationEntry{ { @@ -121,7 +127,7 @@ func TestConversationView_ToggleToolResultExpansion(t *testing.T) { } func TestConversationView_ToggleAllToolResultsExpansion(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) conversation := []domain.ConversationEntry{ { @@ -170,7 +176,7 @@ func TestConversationView_ToggleAllToolResultsExpansion(t *testing.T) { } func TestConversationView_IsToolResultExpanded(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) if cv.IsToolResultExpanded(0) { t.Error("Expected tool result 0 to not be expanded initially") @@ -182,7 +188,7 @@ func TestConversationView_IsToolResultExpanded(t *testing.T) { } func TestConversationView_SetWidth(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) cv.SetWidth(120) @@ -196,7 +202,7 @@ func TestConversationView_SetWidth(t *testing.T) { } func TestConversationView_SetHeight(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) cv.SetHeight(30) @@ -210,7 +216,7 @@ func TestConversationView_SetHeight(t *testing.T) { } func TestConversationView_Render(t *testing.T) { - cv := NewConversationView(nil) + cv := NewConversationView(createMockStyleProvider()) output := cv.Render() diff --git a/internal/ui/components/diff_renderer.go b/internal/ui/components/diff_renderer.go index 075784ec..f8e20ab2 100644 --- a/internal/ui/components/diff_renderer.go +++ b/internal/ui/components/diff_renderer.go @@ -2,36 +2,21 @@ package components import ( "fmt" + "os" "strings" - lipgloss "github.com/charmbracelet/lipgloss" - domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // DiffRenderer provides high-performance diff rendering with colors type DiffRenderer struct { - themeService domain.ThemeService - additionStyle lipgloss.Style - deletionStyle lipgloss.Style - headerStyle lipgloss.Style - fileStyle lipgloss.Style - contextStyle lipgloss.Style - lineNumStyle lipgloss.Style - chunkStyle lipgloss.Style + styleProvider *styles.Provider } // NewDiffRenderer creates a new diff renderer with colored output -func NewDiffRenderer(themeService domain.ThemeService) *DiffRenderer { +func NewDiffRenderer(styleProvider *styles.Provider) *DiffRenderer { return &DiffRenderer{ - themeService: themeService, - additionStyle: lipgloss.NewStyle().Foreground(colors.DiffAddColor.GetLipglossColor()), - deletionStyle: lipgloss.NewStyle().Foreground(colors.DiffRemoveColor.GetLipglossColor()), - headerStyle: lipgloss.NewStyle().Foreground(colors.HeaderColor.GetLipglossColor()), - fileStyle: lipgloss.NewStyle().Foreground(colors.AccentColor.GetLipglossColor()).Bold(true), - contextStyle: lipgloss.NewStyle().Foreground(colors.DimColor.GetLipglossColor()), - lineNumStyle: lipgloss.NewStyle().Foreground(colors.DimColor.GetLipglossColor()), - chunkStyle: lipgloss.NewStyle().Foreground(colors.StatusColor.GetLipglossColor()).Bold(true), + styleProvider: styleProvider, } } @@ -44,24 +29,143 @@ func (d *DiffRenderer) RenderEditToolArguments(args map[string]any) string { var result strings.Builder - result.WriteString(d.fileStyle.Render(filePath)) + result.WriteString(d.styleProvider.RenderWithColorAndBold(filePath, d.styleProvider.GetThemeColor("accent"))) result.WriteString("\n") if replaceAll { - result.WriteString(d.contextStyle.Render("Mode: Replace all occurrences")) + result.WriteString(d.styleProvider.RenderDimText("Mode: Replace all occurrences")) result.WriteString("\n") } result.WriteString("\n") - result.WriteString(d.headerStyle.Render(fmt.Sprintf("--- a/%s", filePath))) + result.WriteString(d.styleProvider.RenderWithColor(fmt.Sprintf("--- a/%s", filePath), d.styleProvider.GetThemeColor("accent"))) result.WriteString("\n") - result.WriteString(d.headerStyle.Render(fmt.Sprintf("+++ b/%s", filePath))) + result.WriteString(d.styleProvider.RenderWithColor(fmt.Sprintf("+++ b/%s", filePath), d.styleProvider.GetThemeColor("accent"))) result.WriteString("\n") - result.WriteString(d.renderUnifiedDiff(oldString, newString, 1)) + cleanedOldString := d.cleanString(oldString) + cleanedNewString := d.cleanString(newString) + + startLine := d.findLineNumber(filePath, oldString) + result.WriteString(d.renderUnifiedDiff(cleanedOldString, cleanedNewString, startLine)) return result.String() } +// findLineNumber finds the line number where the old string starts in the file +func (d *DiffRenderer) findLineNumber(filePath, oldString string) int { + content, err := os.ReadFile(filePath) + if err != nil { + return 1 + } + + fileContent := string(content) + cleanedOldString := d.cleanString(oldString) + + index := strings.Index(fileContent, cleanedOldString) + if index != -1 { + lineNum := 1 + for i := 0; i < index; i++ { + if fileContent[i] == '\n' { + lineNum++ + } + } + return lineNum + } + + fileLines := strings.Split(fileContent, "\n") + oldLines := strings.Split(cleanedOldString, "\n") + + if len(oldLines) == 0 { + return 1 + } + + firstOldLine := strings.TrimSpace(oldLines[0]) + if firstOldLine == "" && len(oldLines) > 1 { + firstOldLine = strings.TrimSpace(oldLines[1]) + } + + for i, fileLine := range fileLines { + if strings.TrimSpace(fileLine) == firstOldLine { + if len(oldLines) == 1 { + return i + 1 + } + + match := true + for j := 1; j < len(oldLines) && (i+j) < len(fileLines); j++ { + if strings.TrimSpace(oldLines[j]) != strings.TrimSpace(fileLines[i+j]) { + match = false + break + } + } + + if match { + return i + 1 + } + } + } + + return 1 +} + +// cleanString removes line number prefixes from Read tool output +func (d *DiffRenderer) cleanString(s string) string { + lines := strings.Split(s, "\n") + var cleanedLines []string + + for _, line := range lines { + if d.isLineNumberPrefix(line) { + if cleanedLine, shouldSkip := d.extractContentAfterLineNumber(line); shouldSkip { + cleanedLines = append(cleanedLines, cleanedLine) + continue + } + } + cleanedLines = append(cleanedLines, line) + } + + return strings.Join(cleanedLines, "\n") +} + +// isLineNumberPrefix checks if a line starts with a line number prefix pattern +func (d *DiffRenderer) isLineNumberPrefix(line string) bool { + return len(line) > 0 && (line[0] == ' ' || (line[0] >= '0' && line[0] <= '9')) +} + +// extractContentAfterLineNumber extracts content after line number prefix if present +func (d *DiffRenderer) extractContentAfterLineNumber(line string) (string, bool) { + tabIndex := strings.Index(line, "\t") + if tabIndex > 0 { + prefix := line[:tabIndex] + if d.isValidLineNumberPrefix(prefix) { + return line[tabIndex+1:], true + } + } + + arrowIndex := strings.Index(line, "→") + if arrowIndex > 0 { + prefix := line[:arrowIndex] + if d.isValidLineNumberPrefix(prefix) { + return line[arrowIndex+len("→"):], true + } + } + + return "", false +} + +// isValidLineNumberPrefix validates if a prefix contains only spaces and digits +func (d *DiffRenderer) isValidLineNumberPrefix(prefix string) bool { + hasDigit := false + + for _, r := range prefix { + if r >= '0' && r <= '9' { + hasDigit = true + } else if r != ' ' && r != '→' { + return false + } + } + + return hasDigit +} + // RenderMultiEditToolArguments renders MultiEdit tool arguments func (d *DiffRenderer) RenderMultiEditToolArguments(args map[string]any) string { filePath, _ := args["file_path"].(string) @@ -69,7 +173,7 @@ func (d *DiffRenderer) RenderMultiEditToolArguments(args map[string]any) string var result strings.Builder - result.WriteString(d.fileStyle.Render(filePath)) + result.WriteString(d.styleProvider.RenderWithColorAndBold(filePath, d.styleProvider.GetThemeColor("accent"))) result.WriteString("\n\n") editsArray, ok := editsInterface.([]any) @@ -78,7 +182,7 @@ func (d *DiffRenderer) RenderMultiEditToolArguments(args map[string]any) string return result.String() } - result.WriteString(d.contextStyle.Render(fmt.Sprintf("Operations: %d edits", len(editsArray)))) + result.WriteString(d.styleProvider.RenderDimText(fmt.Sprintf("Operations: %d edits", len(editsArray)))) result.WriteString("\n\n") for i, editInterface := range editsArray { @@ -91,13 +195,19 @@ func (d *DiffRenderer) RenderMultiEditToolArguments(args map[string]any) string newString, _ := editMap["new_string"].(string) replaceAll, _ := editMap["replace_all"].(bool) - result.WriteString(d.headerStyle.Render(fmt.Sprintf("Edit %d:", i+1))) + result.WriteString(d.styleProvider.RenderWithColor(fmt.Sprintf("Edit %d:", i+1), d.styleProvider.GetThemeColor("accent"))) result.WriteString("\n") if replaceAll { - result.WriteString(d.contextStyle.Render("Replace all occurrences")) + result.WriteString(d.styleProvider.RenderDimText("Replace all occurrences")) result.WriteString("\n") } - result.WriteString(d.renderUnifiedDiff(oldString, newString, 1)) + + cleanedOldString := d.cleanString(oldString) + cleanedNewString := d.cleanString(newString) + + startLine := d.findLineNumber(filePath, oldString) + result.WriteString(d.renderUnifiedDiff(cleanedOldString, cleanedNewString, startLine)) + if i < len(editsArray)-1 { result.WriteString("\n") } @@ -106,51 +216,129 @@ func (d *DiffRenderer) RenderMultiEditToolArguments(args map[string]any) string return result.String() } -// RenderWriteToolArguments renders Write tool arguments +// RenderWriteToolArguments renders Write tool arguments with diff for existing files func (d *DiffRenderer) RenderWriteToolArguments(args map[string]any) string { filePath, _ := args["file_path"].(string) content, _ := args["content"].(string) var result strings.Builder - result.WriteString(d.fileStyle.Render(filePath)) + existingContent, err := d.readFileIfExists(filePath) + if err == nil && existingContent != "" { + diffInfo := DiffInfo{ + FilePath: filePath, + OldContent: existingContent, + NewContent: content, + Title: "← File will be overwritten →", + } + return d.RenderDiff(diffInfo) + } + + icon := d.getFileIcon(filePath) + header := d.styleProvider.RenderWithColorAndBold(icon+" "+filePath, d.styleProvider.GetThemeColor("accent")) + + result.WriteString(d.styleProvider.RenderBordered(header, 80)) result.WriteString("\n\n") - result.WriteString(d.contextStyle.Render("Content:")) - result.WriteString("\n") - result.WriteString(content) - if !strings.HasSuffix(content, "\n") { - result.WriteString("\n") + + opts := styles.StyleOptions{ + Background: d.styleProvider.GetThemeColor("success"), + Foreground: "#000000", + Bold: true, } + newFileBadge := d.styleProvider.RenderStyledText("NEW FILE", opts) + + result.WriteString(newFileBadge) + result.WriteString("\n\n") + + result.WriteString(d.renderContentPreview(content)) return result.String() } -// RenderDiff renders a unified diff with colors +// readFileIfExists attempts to read a file, returning empty string if not exists +func (d *DiffRenderer) readFileIfExists(filePath string) (string, error) { + content, err := os.ReadFile(filePath) + if err != nil { + return "", err + } + return string(content), nil +} + +// renderContentPreview renders content with line numbers for preview +func (d *DiffRenderer) renderContentPreview(content string) string { + lines := strings.Split(content, "\n") + + var result strings.Builder + maxLineNumWidth := len(fmt.Sprintf("%d", len(lines))) + gutterSep := d.styleProvider.RenderDimText(" │ ") + + for i, line := range lines { + if i >= 50 { + remaining := len(lines) - i + opts := styles.StyleOptions{ + Foreground: d.styleProvider.GetThemeColor("dim"), + Italic: true, + } + result.WriteString(d.styleProvider.RenderStyledText(fmt.Sprintf("\n... %d more lines ...", remaining), opts)) + break + } + + lineNumStr := d.styleProvider.RenderDimText( + fmt.Sprintf("%*d", maxLineNumWidth, i+1)) + result.WriteString(lineNumStr) + result.WriteString(gutterSep) + result.WriteString(line) + if i < len(lines)-1 { + result.WriteString("\n") + } + } + + return d.styleProvider.RenderBordered(result.String(), 80) +} + +// RenderDiff renders a unified diff with colors and modern styling func (d *DiffRenderer) RenderDiff(diffInfo DiffInfo) string { var result strings.Builder + stats := d.calculateDiffStats(diffInfo.OldContent, diffInfo.NewContent) + if diffInfo.Title != "" { - result.WriteString(d.headerStyle.Render(diffInfo.Title)) + opts := styles.StyleOptions{ + Foreground: d.styleProvider.GetThemeColor("accent"), + Bold: true, + } + result.WriteString(d.styleProvider.RenderStyledText(diffInfo.Title, opts)) result.WriteString("\n\n") } - result.WriteString(d.fileStyle.Render(diffInfo.FilePath)) + result.WriteString(d.renderFileHeader(diffInfo.FilePath, stats)) result.WriteString("\n\n") - result.WriteString(d.headerStyle.Render(fmt.Sprintf("--- a/%s", diffInfo.FilePath))) + result.WriteString(d.styleProvider.RenderDimText(fmt.Sprintf("--- a/%s", diffInfo.FilePath))) result.WriteString("\n") - result.WriteString(d.headerStyle.Render(fmt.Sprintf("+++ b/%s", diffInfo.FilePath))) + result.WriteString(d.styleProvider.RenderDimText(fmt.Sprintf("+++ b/%s", diffInfo.FilePath))) result.WriteString("\n") + var diffContent string switch { case diffInfo.OldContent == "" && diffInfo.NewContent != "": - result.WriteString(d.renderNewFileContent(diffInfo.NewContent)) + diffContent = d.renderNewFileContent(diffInfo.NewContent) case diffInfo.OldContent != "" && diffInfo.NewContent == "": - result.WriteString(d.renderDeletedFileContent(diffInfo.OldContent)) + diffContent = d.renderDeletedFileContent(diffInfo.OldContent) default: - result.WriteString(d.renderUnifiedDiff(diffInfo.OldContent, diffInfo.NewContent, 1)) + cleanedOldContent := d.cleanString(diffInfo.OldContent) + cleanedNewContent := d.cleanString(diffInfo.NewContent) + + startLine := 1 + if diffInfo.FilePath != "" && diffInfo.OldContent != "" { + startLine = d.findLineNumber(diffInfo.FilePath, diffInfo.OldContent) + } + + diffContent = d.renderUnifiedDiff(cleanedOldContent, cleanedNewContent, startLine) } + result.WriteString(d.styleProvider.RenderBordered(diffContent, 100)) + return result.String() } @@ -162,18 +350,9 @@ type DiffInfo struct { Title string } -// NewToolDiffRenderer creates a tool diff renderer (alias for DiffRenderer) -func NewToolDiffRenderer() *DiffRenderer { - return &DiffRenderer{ - themeService: nil, - additionStyle: lipgloss.NewStyle().Foreground(colors.DiffAddColor.GetLipglossColor()), - deletionStyle: lipgloss.NewStyle().Foreground(colors.DiffRemoveColor.GetLipglossColor()), - headerStyle: lipgloss.NewStyle().Foreground(colors.HeaderColor.GetLipglossColor()), - fileStyle: lipgloss.NewStyle().Foreground(colors.AccentColor.GetLipglossColor()).Bold(true), - contextStyle: lipgloss.NewStyle().Foreground(colors.DimColor.GetLipglossColor()), - lineNumStyle: lipgloss.NewStyle().Foreground(colors.DimColor.GetLipglossColor()), - chunkStyle: lipgloss.NewStyle().Foreground(colors.StatusColor.GetLipglossColor()).Bold(true), - } +// NewToolDiffRenderer creates a tool diff renderer (alias for NewDiffRenderer) +func NewToolDiffRenderer(styleProvider *styles.Provider) *DiffRenderer { + return NewDiffRenderer(styleProvider) } // RenderColoredDiff renders a simple diff between old and new content (for compatibility) @@ -192,12 +371,23 @@ func (d *DiffRenderer) renderNewFileContent(newContent string) string { var result strings.Builder newLines := strings.Split(newContent, "\n") chunkHeader := fmt.Sprintf("@@ -0,0 +1,%d @@", len(newLines)) - result.WriteString(d.chunkStyle.Render(chunkHeader)) + opts := styles.StyleOptions{ + Foreground: d.styleProvider.GetThemeColor("status"), + Bold: true, + } + result.WriteString(d.styleProvider.RenderStyledText(chunkHeader, opts)) result.WriteString("\n") + maxLineNumWidth := len(fmt.Sprintf("%d", len(newLines))) + gutterSep := d.styleProvider.RenderDimText(" │ ") + for i, line := range newLines { if i < len(newLines)-1 || line != "" { - result.WriteString(d.additionStyle.Render(fmt.Sprintf("+%s", line))) + lineNumStr := d.styleProvider.RenderDimText( + fmt.Sprintf("%*d", maxLineNumWidth, i+1)) + result.WriteString(lineNumStr) + result.WriteString(gutterSep) + result.WriteString(d.styleProvider.RenderDiffAddition(fmt.Sprintf("+%s", line))) result.WriteString("\n") } } @@ -209,12 +399,23 @@ func (d *DiffRenderer) renderDeletedFileContent(oldContent string) string { var result strings.Builder oldLines := strings.Split(oldContent, "\n") chunkHeader := fmt.Sprintf("@@ -1,%d +0,0 @@", len(oldLines)) - result.WriteString(d.chunkStyle.Render(chunkHeader)) + opts := styles.StyleOptions{ + Foreground: d.styleProvider.GetThemeColor("status"), + Bold: true, + } + result.WriteString(d.styleProvider.RenderStyledText(chunkHeader, opts)) result.WriteString("\n") + maxLineNumWidth := len(fmt.Sprintf("%d", len(oldLines))) + gutterSep := d.styleProvider.RenderDimText(" │ ") + for i, line := range oldLines { if i < len(oldLines)-1 || line != "" { - result.WriteString(d.deletionStyle.Render(fmt.Sprintf("-%s", line))) + lineNumStr := d.styleProvider.RenderDimText( + fmt.Sprintf("%*d", maxLineNumWidth, i+1)) + result.WriteString(lineNumStr) + result.WriteString(gutterSep) + result.WriteString(d.styleProvider.RenderDiffRemoval(fmt.Sprintf("-%s", line))) result.WriteString("\n") } } @@ -235,12 +436,128 @@ func (d *DiffRenderer) renderUnifiedDiff(oldContent, newContent string, startLin newCount := len(newLines) chunkHeader := fmt.Sprintf("@@ -%d,%d +%d,%d @@", startLine, oldCount, startLine, newCount) - result.WriteString(d.chunkStyle.Render(chunkHeader)) + opts := styles.StyleOptions{ + Foreground: d.styleProvider.GetThemeColor("status"), + Bold: true, + } + result.WriteString(d.styleProvider.RenderStyledText(chunkHeader, opts)) result.WriteString("\n") - maxLines := max(oldCount, newCount) + type diffLine struct { + oldLineNum int + newLineNum int + content string + isAdd bool + isDelete bool + isContext bool + } + + var diffLines []diffLine + oldIdx := 0 + newIdx := 0 + oldLineNum := startLine + newLineNum := startLine + + for oldIdx < len(oldLines) || newIdx < len(newLines) { + oldLine := "" + newLine := "" + + if oldIdx < len(oldLines) { + oldLine = oldLines[oldIdx] + } + if newIdx < len(newLines) { + newLine = newLines[newIdx] + } + + if oldLine == newLine { + diffLines = append(diffLines, diffLine{ + oldLineNum: oldLineNum, + newLineNum: newLineNum, + content: oldLine, + isContext: true, + }) + oldIdx++ + newIdx++ + oldLineNum++ + newLineNum++ + } else if oldIdx >= len(oldLines) { + diffLines = append(diffLines, diffLine{ + newLineNum: newLineNum, + content: newLine, + isAdd: true, + }) + newIdx++ + newLineNum++ + } else if newIdx >= len(newLines) { + diffLines = append(diffLines, diffLine{ + oldLineNum: oldLineNum, + content: oldLine, + isDelete: true, + }) + oldIdx++ + oldLineNum++ + } else { + diffLines = append(diffLines, diffLine{ + oldLineNum: oldLineNum, + content: oldLine, + isDelete: true, + }) + diffLines = append(diffLines, diffLine{ + newLineNum: newLineNum, + content: newLine, + isAdd: true, + }) + oldIdx++ + newIdx++ + oldLineNum++ + newLineNum++ + } + } + + maxLineNumWidth := len(fmt.Sprintf("%d", max(oldLineNum, newLineNum))) + gutterSep := d.styleProvider.RenderDimText(" │ ") + + for _, line := range diffLines { + var lineNumStr string + if line.isDelete { + lineNumStr = d.styleProvider.RenderDimText(fmt.Sprintf("%*d", maxLineNumWidth, line.oldLineNum)) + } else { + lineNumStr = d.styleProvider.RenderDimText(fmt.Sprintf("%*d", maxLineNumWidth, line.newLineNum)) + } + + result.WriteString(lineNumStr) + result.WriteString(gutterSep) + + if line.isContext { + result.WriteString(d.styleProvider.RenderDimText(fmt.Sprintf(" %s", line.content))) + } else if line.isAdd { + result.WriteString(d.styleProvider.RenderDiffAddition(fmt.Sprintf("+%s", line.content))) + } else if line.isDelete { + result.WriteString(d.styleProvider.RenderDiffRemoval(fmt.Sprintf("-%s", line.content))) + } + result.WriteString("\n") + } + + output := result.String() + return strings.TrimSuffix(output, "\n") +} + +// DiffStats represents statistics about a diff +type DiffStats struct { + LinesAdded int + LinesRemoved int + LinesChanged int +} + +// calculateDiffStats computes statistics from old and new content +func (d *DiffRenderer) calculateDiffStats(oldContent, newContent string) DiffStats { + oldLines := strings.Split(oldContent, "\n") + newLines := strings.Split(newContent, "\n") + + stats := DiffStats{} + maxLines := max(len(oldLines), len(newLines)) - for i := range maxLines { + for i := 0; i < maxLines; i++ { oldLine := "" newLine := "" @@ -252,21 +569,83 @@ func (d *DiffRenderer) renderUnifiedDiff(oldContent, newContent string, startLin } if oldLine == newLine { - result.WriteString(d.contextStyle.Render(fmt.Sprintf(" %s", oldLine))) - result.WriteString("\n") - } else { - if i < len(oldLines) { - result.WriteString(d.deletionStyle.Render(fmt.Sprintf("-%s", oldLine))) - result.WriteString("\n") - } - if i < len(newLines) { - result.WriteString(d.additionStyle.Render(fmt.Sprintf("+%s", newLine))) - result.WriteString("\n") - } + continue + } + + if oldLine != "" && newLine != "" { + stats.LinesChanged++ + } else if oldLine == "" { + stats.LinesAdded++ + } else if newLine == "" { + stats.LinesRemoved++ } } - // Remove trailing newline - output := result.String() - return strings.TrimSuffix(output, "\n") + return stats +} + +// renderDiffStats creates a visual stats summary +func (d *DiffRenderer) renderDiffStats(stats DiffStats) string { + if stats.LinesAdded == 0 && stats.LinesRemoved == 0 && stats.LinesChanged == 0 { + return "" + } + + var parts []string + + if stats.LinesAdded > 0 { + parts = append(parts, d.styleProvider.RenderDiffAddition(fmt.Sprintf("+%d", stats.LinesAdded))) + } + if stats.LinesRemoved > 0 { + parts = append(parts, d.styleProvider.RenderDiffRemoval(fmt.Sprintf("-%d", stats.LinesRemoved))) + } + if stats.LinesChanged > 0 { + parts = append(parts, d.styleProvider.RenderWithColor(fmt.Sprintf("~%d", stats.LinesChanged), d.styleProvider.GetThemeColor("status"))) + } + + return d.styleProvider.RenderDimText("Changes: ") + strings.Join(parts, " ") +} + +// getFileIcon returns an appropriate icon/glyph for a file based on extension +func (d *DiffRenderer) getFileIcon(filePath string) string { + ext := strings.ToLower(filePath) + + switch { + case strings.HasSuffix(ext, ".go"): + return "🐹" + case strings.HasSuffix(ext, ".js"), strings.HasSuffix(ext, ".ts"): + return "📜" + case strings.HasSuffix(ext, ".py"): + return "🐍" + case strings.HasSuffix(ext, ".md"): + return "📝" + case strings.HasSuffix(ext, ".json"), strings.HasSuffix(ext, ".yaml"), strings.HasSuffix(ext, ".yml"): + return "⚙️ " + case strings.HasSuffix(ext, ".html"), strings.HasSuffix(ext, ".css"): + return "🌐" + case strings.HasSuffix(ext, ".rs"): + return "🦀" + case strings.HasSuffix(ext, ".java"): + return "☕" + case strings.HasSuffix(ext, ".sh"), strings.HasSuffix(ext, ".bash"): + return "🔧" + default: + return "📄" + } +} + +// renderFileHeader creates an elegant file header with metadata +func (d *DiffRenderer) renderFileHeader(filePath string, stats DiffStats) string { + icon := d.getFileIcon(filePath) + fileName := d.styleProvider.RenderWithColorAndBold(icon+" "+filePath, d.styleProvider.GetThemeColor("accent")) + + var header strings.Builder + header.WriteString(fileName) + + statsLine := d.renderDiffStats(stats) + if statsLine != "" { + header.WriteString(" ") + header.WriteString(statsLine) + } + + return d.styleProvider.RenderBordered(header.String(), 100) } diff --git a/internal/ui/components/diff_renderer_test.go b/internal/ui/components/diff_renderer_test.go index 421e0e39..88c83d8a 100644 --- a/internal/ui/components/diff_renderer_test.go +++ b/internal/ui/components/diff_renderer_test.go @@ -3,10 +3,15 @@ package components import ( "strings" "testing" + + domain "github.com/inference-gateway/cli/internal/domain" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) func TestDiffRenderer_RenderDiff(t *testing.T) { - renderer := NewDiffRenderer(nil) + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + renderer := NewDiffRenderer(styleProvider) t.Run("New file creation", func(t *testing.T) { diffInfo := DiffInfo{ @@ -89,7 +94,9 @@ func TestDiffRenderer_RenderDiff(t *testing.T) { } func TestDiffRenderer_RenderMultiEditToolArguments(t *testing.T) { - renderer := NewDiffRenderer(nil) + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + renderer := NewDiffRenderer(styleProvider) t.Run("Multiple edits", func(t *testing.T) { args := map[string]any{ @@ -165,7 +172,9 @@ func TestDiffRenderer_RenderMultiEditToolArguments(t *testing.T) { } func TestDiffRenderer_HelperMethods(t *testing.T) { - renderer := NewDiffRenderer(nil) + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + renderer := NewDiffRenderer(styleProvider) t.Run("renderNewFileContent", func(t *testing.T) { content := "line1\nline2\nline3" @@ -234,20 +243,30 @@ func TestDiffRenderer_HelperMethods(t *testing.T) { } func TestDiffRenderer_Styling(t *testing.T) { - renderer := NewDiffRenderer(nil) + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + renderer := NewDiffRenderer(styleProvider) - testContent := "test" + if renderer == nil { + t.Fatal("NewDiffRenderer should not return nil") + } - _ = renderer.additionStyle.Render(testContent) - _ = renderer.deletionStyle.Render(testContent) - _ = renderer.headerStyle.Render(testContent) - _ = renderer.fileStyle.Render(testContent) - _ = renderer.contextStyle.Render(testContent) - _ = renderer.chunkStyle.Render(testContent) + diffInfo := DiffInfo{ + FilePath: "test.go", + OldContent: "old", + NewContent: "new", + Title: "Test", + } + output := renderer.RenderDiff(diffInfo) + if output == "" { + t.Error("Expected non-empty output from RenderDiff") + } } func TestNewToolDiffRenderer(t *testing.T) { - renderer := NewToolDiffRenderer() + themeService := domain.NewThemeProvider() + styleProvider := styles.NewProvider(themeService) + renderer := NewToolDiffRenderer(styleProvider) if renderer == nil { t.Fatal("NewToolDiffRenderer should not return nil") diff --git a/internal/ui/components/file_selection_handler.go b/internal/ui/components/file_selection_handler.go index fc56623a..9426d203 100644 --- a/internal/ui/components/file_selection_handler.go +++ b/internal/ui/components/file_selection_handler.go @@ -7,6 +7,7 @@ import ( tea "github.com/charmbracelet/bubbletea" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // FileSelectionHandler handles file selection logic and state management @@ -15,9 +16,9 @@ type FileSelectionHandler struct { } // NewFileSelectionHandler creates a new file selection handler -func NewFileSelectionHandler(themeService domain.ThemeService) *FileSelectionHandler { +func NewFileSelectionHandler(styleProvider *styles.Provider) *FileSelectionHandler { return &FileSelectionHandler{ - view: NewFileSelectionView(themeService), + view: NewFileSelectionView(styleProvider), } } diff --git a/internal/ui/components/file_selection_view.go b/internal/ui/components/file_selection_view.go index 0fb18dbe..6a382e41 100644 --- a/internal/ui/components/file_selection_view.go +++ b/internal/ui/components/file_selection_view.go @@ -4,19 +4,19 @@ import ( "fmt" "strings" - domain "github.com/inference-gateway/cli/internal/domain" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) type FileSelectionView struct { - themeService domain.ThemeService - maxVisible int - width int + styleProvider *styles.Provider + maxVisible int + width int } -func NewFileSelectionView(themeService domain.ThemeService) *FileSelectionView { +func NewFileSelectionView(styleProvider *styles.Provider) *FileSelectionView { return &FileSelectionView{ - themeService: themeService, - maxVisible: 12, + styleProvider: styleProvider, + maxVisible: 12, } } @@ -85,17 +85,19 @@ func (f *FileSelectionView) renderHeader(b *strings.Builder, files, allFiles []s func (f *FileSelectionView) renderSearchField(b *strings.Builder, searchQuery string) { b.WriteString("🔍 Search: ") if searchQuery != "" { - fmt.Fprintf(b, "%s%s%s│", f.themeService.GetCurrentTheme().GetUserColor(), searchQuery, "\033[0m") + fmt.Fprintf(b, "%s│", f.styleProvider.RenderWithColor(searchQuery, f.styleProvider.GetThemeColor("user"))) } else { - fmt.Fprintf(b, "%stype to filter files...%s│", f.themeService.GetCurrentTheme().GetDimColor(), "\033[0m") + fmt.Fprintf(b, "%s│", f.styleProvider.RenderDimText("type to filter files...")) } b.WriteString("\n\n") } func (f *FileSelectionView) renderNoFilesFound(b *strings.Builder, searchQuery string) string { - fmt.Fprintf(b, "%sNo files match '%s'%s\n\n", f.themeService.GetCurrentTheme().GetErrorColor(), searchQuery, "\033[0m") + errorMsg := fmt.Sprintf("No files match '%s'", searchQuery) + fmt.Fprintf(b, "%s\n\n", f.styleProvider.RenderWithColor(errorMsg, f.styleProvider.GetThemeColor("error"))) + helpText := "Type to search, BACKSPACE to clear search, ESC to cancel" - b.WriteString(f.themeService.GetCurrentTheme().GetDimColor() + helpText + "\033[0m") + b.WriteString(f.styleProvider.RenderDimText(helpText)) return b.String() } @@ -105,9 +107,9 @@ func (f *FileSelectionView) renderFileList(b *strings.Builder, files []string, s for i := startIndex; i < endIndex; i++ { file := files[i] if i == selectedIndex { - fmt.Fprintf(b, "%s▶ %s%s\n", f.themeService.GetCurrentTheme().GetAccentColor(), file, "\033[0m") + fmt.Fprintf(b, "%s\n", f.styleProvider.RenderWithColor("▶ "+file, f.styleProvider.GetThemeColor("accent"))) } else { - fmt.Fprintf(b, "%s %s%s\n", f.themeService.GetCurrentTheme().GetDimColor(), file, "\033[0m") + fmt.Fprintf(b, "%s\n", f.styleProvider.RenderDimText(" "+file)) } } } @@ -129,11 +131,11 @@ func (f *FileSelectionView) renderFooter(b *strings.Builder, files []string, sel if len(files) > f.maxVisible { startIndex, endIndex := f.calculateVisibleRange(len(files), selectedIndex) - fmt.Fprintf(b, "%sShowing %d-%d of %d matches%s\n", - f.themeService.GetCurrentTheme().GetDimColor(), startIndex+1, endIndex, len(files), "\033[0m") + paginationText := fmt.Sprintf("Showing %d-%d of %d matches", startIndex+1, endIndex, len(files)) + fmt.Fprintf(b, "%s\n", f.styleProvider.RenderDimText(paginationText)) b.WriteString("\n") } helpText := "Type to search, ↑↓ to navigate, ENTER to select, BACKSPACE to clear, ESC to cancel" - b.WriteString(f.themeService.GetCurrentTheme().GetDimColor() + helpText + "\033[0m") + b.WriteString(f.styleProvider.RenderDimText(helpText)) } diff --git a/internal/ui/components/help_bar.go b/internal/ui/components/help_bar.go index 131e750e..c778f0c4 100644 --- a/internal/ui/components/help_bar.go +++ b/internal/ui/components/help_bar.go @@ -4,26 +4,25 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // HelpBar displays keyboard shortcuts at the bottom of the screen type HelpBar struct { - enabled bool - width int - shortcuts []shared.KeyShortcut - themeService domain.ThemeService + enabled bool + width int + shortcuts []shared.KeyShortcut + styleProvider *styles.Provider } -func NewHelpBar(themeService domain.ThemeService) *HelpBar { +func NewHelpBar(styleProvider *styles.Provider) *HelpBar { return &HelpBar{ - enabled: false, - width: 80, - shortcuts: make([]shared.KeyShortcut, 0), - themeService: themeService, + enabled: false, + width: 80, + shortcuts: make([]shared.KeyShortcut, 0), + styleProvider: styleProvider, } } @@ -151,20 +150,19 @@ func (hb *HelpBar) renderResponsiveTable() string { } } - cellStyle := lipgloss.NewStyle(). - Width(colWidth). - Align(lipgloss.Left) - cells = append(cells, cellStyle.Render(cellText)) + cellText = hb.styleProvider.RenderStyledText(cellText, styles.StyleOptions{ + Width: colWidth, + }) + cells = append(cells, cellText) } - tableRows = append(tableRows, lipgloss.JoinHorizontal(lipgloss.Left, cells...)) + tableRows = append(tableRows, hb.styleProvider.JoinHorizontal(cells...)) } - dimColor := hb.getDimColor() - tableStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color(dimColor)). - Width(hb.width) - - return tableStyle.Render(strings.Join(tableRows, "\n")) + fullTable := strings.Join(tableRows, "\n") + return hb.styleProvider.RenderStyledText(fullTable, styles.StyleOptions{ + Foreground: hb.styleProvider.GetThemeColor("dim"), + Width: hb.width, + }) } // Bubble Tea interface @@ -183,11 +181,3 @@ func (hb *HelpBar) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return hb, nil } - -// Helper method to get theme colors with fallback -func (hb *HelpBar) getDimColor() string { - if hb.themeService != nil { - return hb.themeService.GetCurrentTheme().GetDimColor() - } - return colors.DimColor.Lipgloss -} diff --git a/internal/ui/components/help_bar_test.go b/internal/ui/components/help_bar_test.go index 60e65358..a5934d96 100644 --- a/internal/ui/components/help_bar_test.go +++ b/internal/ui/components/help_bar_test.go @@ -5,10 +5,16 @@ import ( "testing" shared "github.com/inference-gateway/cli/internal/ui/shared" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) +// createMockStyleProviderForHelpBar creates a mock styles provider for testing +func createMockStyleProviderForHelpBar() *styles.Provider { + return styles.NewProvider(&mockThemeService{}) +} + func TestNewHelpBar(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) if hb.width != 80 { t.Errorf("Expected default width 80, got %d", hb.width) @@ -24,7 +30,7 @@ func TestNewHelpBar(t *testing.T) { } func TestHelpBar_SetShortcuts(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) shortcuts := []shared.KeyShortcut{ {Key: "Enter", Description: "Send message"}, @@ -48,7 +54,7 @@ func TestHelpBar_SetShortcuts(t *testing.T) { } func TestHelpBar_IsEnabled(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) if hb.IsEnabled() { t.Error("Expected help bar to be disabled by default") @@ -56,7 +62,7 @@ func TestHelpBar_IsEnabled(t *testing.T) { } func TestHelpBar_SetEnabled(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(false) if hb.IsEnabled() { @@ -70,7 +76,7 @@ func TestHelpBar_SetEnabled(t *testing.T) { } func TestHelpBar_SetWidth(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetWidth(120) @@ -80,13 +86,13 @@ func TestHelpBar_SetWidth(t *testing.T) { } func TestHelpBar_SetHeight(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetHeight(2) } func TestHelpBar_Render_Disabled(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(false) output := hb.Render() @@ -97,7 +103,7 @@ func TestHelpBar_Render_Disabled(t *testing.T) { } func TestHelpBar_Render_NoShortcuts(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(true) output := hb.Render() @@ -108,7 +114,8 @@ func TestHelpBar_Render_NoShortcuts(t *testing.T) { } func TestHelpBar_Render_WithShortcuts(t *testing.T) { - hb := NewHelpBar(nil) + styleProvider := styles.NewProvider(&mockThemeService{}) + hb := NewHelpBar(styleProvider) hb.SetEnabled(true) shortcuts := []shared.KeyShortcut{ @@ -138,7 +145,7 @@ func TestHelpBar_Render_WithShortcuts(t *testing.T) { } func TestHelpBar_Render_LongShortcuts(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(true) hb.SetWidth(20) @@ -160,7 +167,7 @@ func TestHelpBar_Render_LongShortcuts(t *testing.T) { } func TestHelpBar_Render_EmptyShortcuts(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(true) hb.SetShortcuts([]shared.KeyShortcut{}) @@ -172,7 +179,7 @@ func TestHelpBar_Render_EmptyShortcuts(t *testing.T) { } func TestHelpBar_Render_SingleShortcut(t *testing.T) { - hb := NewHelpBar(nil) + hb := NewHelpBar(createMockStyleProviderForHelpBar()) hb.SetEnabled(true) shortcuts := []shared.KeyShortcut{ diff --git a/internal/ui/components/input_view.go b/internal/ui/components/input_view.go index 8aa09851..c4d06c5e 100644 --- a/internal/ui/components/input_view.go +++ b/internal/ui/components/input_view.go @@ -4,14 +4,13 @@ import ( "fmt" "strings" - "github.com/atotto/clipboard" + clipboard "github.com/atotto/clipboard" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" - "github.com/inference-gateway/cli/internal/domain" - "github.com/inference-gateway/cli/internal/ui/history" - "github.com/inference-gateway/cli/internal/ui/keys" - "github.com/inference-gateway/cli/internal/ui/shared" - "github.com/inference-gateway/cli/internal/ui/styles/colors" + domain "github.com/inference-gateway/cli/internal/domain" + history "github.com/inference-gateway/cli/internal/ui/history" + keys "github.com/inference-gateway/cli/internal/ui/keys" + shared "github.com/inference-gateway/cli/internal/ui/shared" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // InputView handles user input with history and autocomplete @@ -26,6 +25,7 @@ type InputView struct { historyManager *history.HistoryManager isTextSelectionMode bool themeService domain.ThemeService + styleProvider *styles.Provider } func NewInputView(modelService domain.ModelService) *InputView { @@ -59,6 +59,7 @@ func NewInputViewWithConfigDir(modelService domain.ModelService, configDir strin // SetThemeService sets the theme service for this input view func (iv *InputView) SetThemeService(themeService domain.ThemeService) { iv.themeService = themeService + iv.styleProvider = styles.NewProvider(themeService) } func (iv *InputView) GetInput() string { @@ -109,22 +110,17 @@ func (iv *InputView) Render() string { displayText := iv.renderDisplayText() inputContent := fmt.Sprintf("> %s", displayText) - borderColor := iv.getBorderColor(isBashMode, isToolsMode) - inputStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color(borderColor)). - Padding(0, 1). - Width(iv.width - 4) + focused := isBashMode || isToolsMode + borderedInput := iv.styleProvider.RenderInputField(inputContent, iv.width-4, focused) - borderedInput := inputStyle.Render(inputContent) components := []string{borderedInput} components = iv.addModeIndicator(components, isBashMode, isToolsMode) components = iv.addAutocomplete(components) components = iv.addModelDisplay(components, isBashMode, isToolsMode) - return lipgloss.JoinVertical(lipgloss.Left, components...) + return iv.styleProvider.JoinVertical(components...) } func (iv *InputView) renderDisplayText() string { @@ -135,10 +131,7 @@ func (iv *InputView) renderDisplayText() string { } func (iv *InputView) renderPlaceholder() string { - dimColor := iv.getDimColor() - return lipgloss.NewStyle(). - Foreground(lipgloss.Color(dimColor)). - Render(iv.placeholder) + return iv.styleProvider.RenderInputPlaceholder(iv.placeholder) } func (iv *InputView) renderTextWithCursor() string { @@ -177,48 +170,41 @@ func (iv *InputView) buildTextWithCursor(before, after string) string { } func (iv *InputView) createCursorChar(char string) string { - return lipgloss.NewStyle(). - Background(lipgloss.Color(colors.LipglossWhiteBg)). - Foreground(lipgloss.Color(colors.LipglossBlack)). - Render(char) -} - -func (iv *InputView) getBorderColor(isBashMode bool, isToolsMode bool) string { - if isBashMode { - return iv.getStatusColor() - } - if isToolsMode { - return iv.getAccentColor() - } - return iv.getDimColor() + return iv.styleProvider.RenderTextSelectionCursor(char) } func (iv *InputView) addModeIndicator(components []string, isBashMode bool, isToolsMode bool) []string { if iv.height >= 2 { if iv.isTextSelectionMode { - accentColor := iv.getAccentColor() - textSelectionIndicator := lipgloss.NewStyle(). - Foreground(lipgloss.Color(accentColor)). - Bold(true). - Width(iv.width). - Render("TEXT SELECTION MODE - Use vim keys to navigate and select text (Escape to exit)") - components = append(components, textSelectionIndicator) + indicator := iv.styleProvider.RenderStyledText( + "TEXT SELECTION MODE - Use vim keys to navigate and select text (Escape to exit)", + styles.StyleOptions{ + Foreground: iv.styleProvider.GetThemeColor("accent"), + Bold: true, + Width: iv.width, + }, + ) + components = append(components, indicator) } else if isBashMode { - statusColor := iv.getStatusColor() - bashIndicator := lipgloss.NewStyle(). - Foreground(lipgloss.Color(statusColor)). - Bold(true). - Width(iv.width). - Render("BASH MODE - Command will be executed directly") - components = append(components, bashIndicator) + indicator := iv.styleProvider.RenderStyledText( + "BASH MODE - Command will be executed directly", + styles.StyleOptions{ + Foreground: iv.styleProvider.GetThemeColor("status"), + Bold: true, + Width: iv.width, + }, + ) + components = append(components, indicator) } else if isToolsMode { - accentColor := iv.getAccentColor() - toolsIndicator := lipgloss.NewStyle(). - Foreground(lipgloss.Color(accentColor)). - Bold(true). - Width(iv.width). - Render("TOOLS MODE - !!ToolName(arg=\"value\") - Tab for autocomplete") - components = append(components, toolsIndicator) + indicator := iv.styleProvider.RenderStyledText( + "TOOLS MODE - !!ToolName(arg=\"value\") - Tab for autocomplete", + styles.StyleOptions{ + Foreground: iv.styleProvider.GetThemeColor("accent"), + Bold: true, + Width: iv.width, + }, + ) + components = append(components, indicator) } } return components @@ -237,11 +223,6 @@ func (iv *InputView) addModelDisplay(components []string, isBashMode bool, isToo if iv.modelService != nil { currentModel := iv.modelService.GetCurrentModel() if currentModel != "" && iv.height >= 2 && !isBashMode && !isToolsMode { - dimColor := iv.getDimColor() - modelStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color(dimColor)). - Width(iv.width) - displayText := fmt.Sprintf(" Model: %s", currentModel) if iv.themeService != nil { @@ -249,7 +230,10 @@ func (iv *InputView) addModelDisplay(components []string, isBashMode bool, isToo displayText = fmt.Sprintf(" Model: %s • Theme: %s", currentModel, currentTheme) } - modelDisplay := modelStyle.Render(displayText) + modelDisplay := iv.styleProvider.RenderStyledText(displayText, styles.StyleOptions{ + Foreground: iv.styleProvider.GetThemeColor("dim"), + Width: iv.width, + }) components = append(components, modelDisplay) } } @@ -446,25 +430,3 @@ func (iv *InputView) SetTextSelectionMode(enabled bool) { func (iv *InputView) IsTextSelectionMode() bool { return iv.isTextSelectionMode } - -// Helper methods to get theme colors with fallbacks -func (iv *InputView) getDimColor() string { - if iv.themeService != nil { - return iv.themeService.GetCurrentTheme().GetDimColor() - } - return colors.DimColor.Lipgloss -} - -func (iv *InputView) getAccentColor() string { - if iv.themeService != nil { - return iv.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.AccentColor.Lipgloss -} - -func (iv *InputView) getStatusColor() string { - if iv.themeService != nil { - return iv.themeService.GetCurrentTheme().GetStatusColor() - } - return colors.StatusColor.Lipgloss -} diff --git a/internal/ui/components/input_view_test.go b/internal/ui/components/input_view_test.go index 1d34a79e..037e346b 100644 --- a/internal/ui/components/input_view_test.go +++ b/internal/ui/components/input_view_test.go @@ -2,10 +2,11 @@ package components import ( "context" + "strings" "testing" - "github.com/charmbracelet/bubbletea" - "github.com/inference-gateway/cli/internal/domain" + tea "github.com/charmbracelet/bubbletea" + domain "github.com/inference-gateway/cli/internal/domain" ) // mockModelService is a simple mock for testing @@ -33,6 +34,13 @@ func (m *mockModelService) ValidateModel(modelID string) error { return nil } +// createInputViewWithTheme creates an InputView with a mock theme service for testing +func createInputViewWithTheme(modelService domain.ModelService) *InputView { + iv := NewInputView(modelService) + iv.SetThemeService(&mockThemeService{}) + return iv +} + func TestNewInputView(t *testing.T) { mockModelService := &mockModelService{} iv := NewInputView(mockModelService) @@ -171,7 +179,7 @@ func TestInputView_SetHeight(t *testing.T) { func TestInputView_Render(t *testing.T) { mockModelService := &mockModelService{} - iv := NewInputView(mockModelService) + iv := createInputViewWithTheme(mockModelService) output := iv.Render() if output == "" { @@ -224,3 +232,34 @@ func TestInputView_History(t *testing.T) { t.Error("Expected history manager to be initialized") } } + +func TestInputView_BashModeBorderColor(t *testing.T) { + mockModelService := &mockModelService{} + iv := createInputViewWithTheme(mockModelService) + + iv.SetText("normal text") + normalOutput := iv.Render() + if normalOutput == "" { + t.Error("Expected non-empty render output for normal text") + } + + iv.SetText("!") + bashOutput := iv.Render() + if bashOutput == "" { + t.Error("Expected non-empty render output for bash mode") + } + + if !strings.Contains(bashOutput, "BASH MODE") { + t.Error("Expected bash mode output to contain 'BASH MODE' indicator") + } + + iv.SetText("!!") + toolsOutput := iv.Render() + if toolsOutput == "" { + t.Error("Expected non-empty render output for tools mode") + } + + if !strings.Contains(toolsOutput, "TOOLS MODE") { + t.Error("Expected tools mode output to contain 'TOOLS MODE' indicator") + } +} diff --git a/internal/ui/components/model_selection_view.go b/internal/ui/components/model_selection_view.go index 89140f8d..5f2c5072 100644 --- a/internal/ui/components/model_selection_view.go +++ b/internal/ui/components/model_selection_view.go @@ -6,7 +6,7 @@ import ( tea "github.com/charmbracelet/bubbletea" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // ModelSelectorImpl implements model selection UI @@ -16,7 +16,7 @@ type ModelSelectorImpl struct { selected int width int height int - themeService domain.ThemeService + styleProvider *styles.Provider done bool cancelled bool modelService domain.ModelService @@ -25,14 +25,14 @@ type ModelSelectorImpl struct { } // NewModelSelector creates a new model selector -func NewModelSelector(models []string, modelService domain.ModelService, themeService domain.ThemeService) *ModelSelectorImpl { +func NewModelSelector(models []string, modelService domain.ModelService, styleProvider *styles.Provider) *ModelSelectorImpl { m := &ModelSelectorImpl{ models: models, filteredModels: make([]string, len(models)), selected: 0, width: 80, height: 24, - themeService: themeService, + styleProvider: styleProvider, modelService: modelService, searchQuery: "", searchMode: false, @@ -146,25 +146,30 @@ func (m *ModelSelectorImpl) updateSearch() { func (m *ModelSelectorImpl) View() string { var b strings.Builder - b.WriteString(fmt.Sprintf("%sSelect a Model%s\n\n", - m.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset)) + accentColor := m.styleProvider.GetThemeColor("accent") + b.WriteString(m.styleProvider.RenderWithColor("Select a Model", accentColor)) + b.WriteString("\n\n") if m.searchMode { - b.WriteString(fmt.Sprintf("%sSearch: %s%s│%s\n\n", - m.themeService.GetCurrentTheme().GetStatusColor(), m.searchQuery, m.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset)) + statusColor := m.styleProvider.GetThemeColor("status") + b.WriteString(m.styleProvider.RenderWithColor("Search: "+m.searchQuery, statusColor)) + b.WriteString(m.styleProvider.RenderWithColor("│", accentColor)) + b.WriteString("\n\n") } else { - b.WriteString(fmt.Sprintf("%sPress / to search • %d models available%s\n\n", - m.themeService.GetCurrentTheme().GetDimColor(), len(m.models), colors.Reset)) + helpText := fmt.Sprintf("Press / to search • %d models available", len(m.models)) + b.WriteString(m.styleProvider.RenderDimText(helpText)) + b.WriteString("\n\n") } if len(m.filteredModels) == 0 { + errorColor := m.styleProvider.GetThemeColor("error") + if m.searchQuery != "" { - b.WriteString(fmt.Sprintf("%sNo models match '%s'%s\n", - m.themeService.GetCurrentTheme().GetErrorColor(), m.searchQuery, colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor(fmt.Sprintf("No models match '%s'", m.searchQuery), errorColor)) } else { - b.WriteString(fmt.Sprintf("%sNo models available%s\n", - m.themeService.GetCurrentTheme().GetErrorColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor("No models available", errorColor)) } + b.WriteString("\n") return b.String() } @@ -182,27 +187,28 @@ func (m *ModelSelectorImpl) View() string { model := m.filteredModels[i] if i == m.selected { - b.WriteString(fmt.Sprintf("%s▶ %s%s\n", - m.themeService.GetCurrentTheme().GetAccentColor(), model, colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor("▶ "+model, accentColor)) + b.WriteString("\n") } else { b.WriteString(fmt.Sprintf(" %s\n", model)) } } if len(m.filteredModels) > maxVisible { - b.WriteString(fmt.Sprintf("\n%sShowing %d-%d of %d models%s\n", - m.themeService.GetCurrentTheme().GetDimColor(), start+1, start+maxVisible, len(m.filteredModels), colors.Reset)) + paginationText := fmt.Sprintf("Showing %d-%d of %d models", start+1, start+maxVisible, len(m.filteredModels)) + b.WriteString("\n") + b.WriteString(m.styleProvider.RenderDimText(paginationText)) + b.WriteString("\n") } b.WriteString("\n") - b.WriteString(colors.CreateSeparator(m.width, "─")) + b.WriteString(strings.Repeat("─", m.width)) b.WriteString("\n") + if m.searchMode { - b.WriteString(fmt.Sprintf("%sType to search, ↑↓ to navigate, Enter to select, Esc to clear search%s", - m.themeService.GetCurrentTheme().GetDimColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderDimText("Type to search, ↑↓ to navigate, Enter to select, Esc to clear search")) } else { - b.WriteString(fmt.Sprintf("%sUse ↑↓ arrows to navigate, Enter to select, / to search, Esc/Ctrl+C to cancel%s", - m.themeService.GetCurrentTheme().GetDimColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderDimText("Use ↑↓ arrows to navigate, Enter to select, / to search, Esc/Ctrl+C to cancel")) } return b.String() diff --git a/internal/ui/components/parallel_tools_renderer.go b/internal/ui/components/parallel_tools_renderer.go index 52ff953b..2e3eb0f2 100644 --- a/internal/ui/components/parallel_tools_renderer.go +++ b/internal/ui/components/parallel_tools_renderer.go @@ -5,10 +5,9 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" constants "github.com/inference-gateway/cli/internal/constants" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) @@ -33,74 +32,21 @@ type ToolExecutionState struct { } type ParallelToolsRenderer struct { - tools map[string]*ToolExecutionState - styles *parallelToolStyles - blinkState bool - visible bool - spinnerStep int -} - -type parallelToolStyles struct { - executing lipgloss.Style - queued lipgloss.Style - complete lipgloss.Style - failed lipgloss.Style - toolName lipgloss.Style - container lipgloss.Style - message lipgloss.Style - duration lipgloss.Style - statusLabel lipgloss.Style - toolBadge lipgloss.Style - separator lipgloss.Style + tools map[string]*ToolExecutionState + styleProvider *styles.Provider + blinkState bool + visible bool + spinnerStep int } type TickMsg struct{} -func NewParallelToolsRenderer() *ParallelToolsRenderer { - styles := ¶llelToolStyles{ - executing: lipgloss.NewStyle(). - Foreground(colors.AccentColor.GetLipglossColor()). - Bold(true), - queued: lipgloss.NewStyle(). - Foreground(colors.WarningColor.GetLipglossColor()), - complete: lipgloss.NewStyle(). - Foreground(colors.SuccessColor.GetLipglossColor()). - Bold(true), - failed: lipgloss.NewStyle(). - Foreground(colors.ErrorColor.GetLipglossColor()). - Bold(true), - toolName: lipgloss.NewStyle(). - Bold(false). - Foreground(colors.AssistantColor.GetLipglossColor()), - container: lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(colors.BorderColor.GetLipglossColor()). - Padding(0, 1). - Margin(0, 0, 1, 0), - message: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - Italic(true), - duration: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()), - statusLabel: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - Bold(true). - Padding(0, 1), - toolBadge: lipgloss.NewStyle(). - Padding(0, 1). - Margin(0, 1, 0, 0). - Border(lipgloss.RoundedBorder()). - BorderForeground(colors.BorderColor.GetLipglossColor()), - separator: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - Padding(0, 1), - } - +func NewParallelToolsRenderer(styleProvider *styles.Provider) *ParallelToolsRenderer { return &ParallelToolsRenderer{ - tools: make(map[string]*ToolExecutionState), - styles: styles, - visible: false, - spinnerStep: 0, + tools: make(map[string]*ToolExecutionState), + styleProvider: styleProvider, + visible: false, + spinnerStep: 0, } } @@ -215,51 +161,47 @@ func (r *ParallelToolsRenderer) Render() string { return "" } - label := r.styles.statusLabel.Render("Tools:") + opts := styles.StyleOptions{ + Foreground: r.styleProvider.GetThemeColor("dim"), + Bold: true, + } + label := r.styleProvider.RenderStyledText("Tools:", opts) toolsContent := strings.Join(toolDisplays, " ") content := label + " " + toolsContent - return r.styles.container.Render(content) + return r.styleProvider.RenderBordered(content, 100) } func (r *ParallelToolsRenderer) renderToolBadge(tool *ToolExecutionState) string { var icon string - var badgeStyle lipgloss.Style + var colorName string switch tool.Status { case ToolStatusRunning, ToolStatusStarting, ToolStatusSaving: icon = icons.GetSpinnerFrame(r.spinnerStep) - badgeStyle = r.styles.toolBadge. - BorderForeground(colors.AccentColor.GetLipglossColor()). - Foreground(colors.AccentColor.GetLipglossColor()) + colorName = "accent" case ToolStatusQueued: icon = icons.QueuedIcon - badgeStyle = r.styles.toolBadge. - BorderForeground(colors.WarningColor.GetLipglossColor()). - Foreground(colors.WarningColor.GetLipglossColor()) + colorName = "warning" case ToolStatusComplete: icon = icons.CheckMark - badgeStyle = r.styles.toolBadge. - BorderForeground(colors.SuccessColor.GetLipglossColor()). - Foreground(colors.SuccessColor.GetLipglossColor()) + colorName = "success" case ToolStatusFailed: icon = icons.CrossMark - badgeStyle = r.styles.toolBadge. - BorderForeground(colors.ErrorColor.GetLipglossColor()). - Foreground(colors.ErrorColor.GetLipglossColor()) + colorName = "error" default: icon = icons.BulletIcon - badgeStyle = r.styles.toolBadge. - BorderForeground(colors.DimColor.GetLipglossColor()). - Foreground(colors.DimColor.GetLipglossColor()) + colorName = "dim" } - badgeContent := icon + " " + r.styles.toolName.Render(tool.ToolName) - return badgeStyle.Render(badgeContent) + toolNameText := r.styleProvider.RenderWithColor(tool.ToolName, r.styleProvider.GetThemeColor("assistant")) + badgeContent := icon + " " + toolNameText + + return r.styleProvider.RenderWithColor(badgeContent, r.styleProvider.GetThemeColor(colorName)) } func (r *ParallelToolsRenderer) Clear() { diff --git a/internal/ui/components/queue_box_view.go b/internal/ui/components/queue_box_view.go index 292c1899..fb34e3bd 100644 --- a/internal/ui/components/queue_box_view.go +++ b/internal/ui/components/queue_box_view.go @@ -5,21 +5,20 @@ import ( "strings" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) type QueueBoxView struct { - width int - themeService domain.ThemeService + width int + styleProvider *styles.Provider } -func NewQueueBoxView(themeService domain.ThemeService) *QueueBoxView { +func NewQueueBoxView(styleProvider *styles.Provider) *QueueBoxView { return &QueueBoxView{ - width: 80, - themeService: themeService, + width: 80, + styleProvider: styleProvider, } } @@ -46,23 +45,16 @@ func (qv *QueueBoxView) Render(queuedMessages []domain.QueuedMessage, background } separator := strings.Repeat("─", qv.width-4) - dimSeparator := lipgloss.NewStyle(). - Foreground(lipgloss.Color(colors.DimColor.Lipgloss)). - Render(separator) + dimColor := qv.styleProvider.GetThemeColor("dim") + dimSeparator := qv.styleProvider.RenderWithColor(separator, dimColor) contentText := strings.Join(sections, "\n"+dimSeparator+"\n") - boxStyle := lipgloss.NewStyle(). - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color(colors.DimColor.Lipgloss)) - - return boxStyle.Render(contentText) + return qv.styleProvider.RenderBorderedBox(contentText, dimColor, 0, 1) } func (qv *QueueBoxView) renderBackgroundTasks(backgroundTasks []domain.TaskPollingState) string { - accentColor := qv.getAccentColor() - dimColor := qv.getDimColor() + accentColor := qv.styleProvider.GetThemeColor("accent") count := len(backgroundTasks) taskWord := "task" @@ -70,45 +62,32 @@ func (qv *QueueBoxView) renderBackgroundTasks(backgroundTasks []domain.TaskPolli taskWord = "tasks" } - titleStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color(accentColor)). - Bold(true) - - hintStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color(dimColor)). - Italic(true) - titleText := fmt.Sprintf("Background Tasks (%d)", count) hintText := fmt.Sprintf(" %d active %s running • Type /tasks to view details", count, taskWord) - return titleStyle.Render(titleText) + "\n" + hintStyle.Render(hintText) + return qv.styleProvider.RenderWithColorAndBold(titleText, accentColor) + "\n" + qv.styleProvider.RenderDimText(hintText) } func (qv *QueueBoxView) renderQueuedMessages(queuedMessages []domain.QueuedMessage) string { - accentColor := qv.getAccentColor() + accentColor := qv.styleProvider.GetThemeColor("accent") titleText := fmt.Sprintf("Queued Messages (%d)", len(queuedMessages)) - titleStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color(accentColor)). - Bold(true) var messageLines []string for _, queuedMsg := range queuedMessages { messageLines = append(messageLines, qv.formatQueuedMessage(queuedMsg)) } - return titleStyle.Render(titleText) + "\n" + strings.Join(messageLines, "\n") + return qv.styleProvider.RenderWithColorAndBold(titleText, accentColor) + "\n" + strings.Join(messageLines, "\n") } func (qv *QueueBoxView) formatQueuedMessage(queuedMsg domain.QueuedMessage) string { - accentColor := qv.getAccentColor() + accentColor := qv.styleProvider.GetThemeColor("accent") + dimColor := qv.styleProvider.GetThemeColor("dim") preview := qv.formatMessagePreview(queuedMsg) - arrowStyle := lipgloss.NewStyle().Foreground(lipgloss.Color(accentColor)) - previewStyle := lipgloss.NewStyle().Foreground(lipgloss.Color(colors.QueuedMessageColor.Lipgloss)) - formattedLine := fmt.Sprintf(" %s %s", - arrowStyle.Render("→"), - previewStyle.Render(preview), + qv.styleProvider.RenderWithColor("→", accentColor), + qv.styleProvider.RenderWithColor(preview, dimColor), ) return formattedLine @@ -150,14 +129,3 @@ func (qv *QueueBoxView) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return qv, nil } - -func (qv *QueueBoxView) getAccentColor() string { - if qv.themeService != nil { - return qv.themeService.GetCurrentTheme().GetAccentColor() - } - return colors.AccentColor.Lipgloss -} - -func (qv *QueueBoxView) getDimColor() string { - return colors.DimColor.Lipgloss -} diff --git a/internal/ui/components/status_view.go b/internal/ui/components/status_view.go index aebecc71..939f6275 100644 --- a/internal/ui/components/status_view.go +++ b/internal/ui/components/status_view.go @@ -6,28 +6,27 @@ import ( spinner "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" shared "github.com/inference-gateway/cli/internal/ui/shared" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) // StatusView handles status messages, errors, and loading spinners type StatusView struct { - message string - isError bool - isSpinner bool - spinner spinner.Model - startTime time.Time - tokenUsage string - baseMessage string - debugInfo string - width int - statusType domain.StatusType - progress *domain.StatusProgress - savedState *StatusState - themeService domain.ThemeService + message string + isError bool + isSpinner bool + spinner spinner.Model + startTime time.Time + tokenUsage string + baseMessage string + debugInfo string + width int + statusType domain.StatusType + progress *domain.StatusProgress + savedState *StatusState + styleProvider *styles.Provider } // StatusState represents a saved status state @@ -42,16 +41,16 @@ type StatusState struct { progress *domain.StatusProgress } -func NewStatusView(themeService domain.ThemeService) *StatusView { +func NewStatusView(styleProvider *styles.Provider) *StatusView { s := spinner.New() s.Spinner = spinner.Dot - s.Style = lipgloss.NewStyle().Foreground(colors.SpinnerColor.GetLipglossColor()) + s.Style = styleProvider.GetSpinnerStyle() return &StatusView{ - message: "", - isError: false, - isSpinner: false, - spinner: s, - themeService: themeService, + message: "", + isError: false, + isSpinner: false, + spinner: s, + styleProvider: styleProvider, } } @@ -216,7 +215,8 @@ func (sv *StatusView) Render() string { } } - return fmt.Sprintf("%s%s %s%s", color, prefix, displayMessage, colors.Reset) + statusLine := fmt.Sprintf("%s %s", prefix, displayMessage) + return sv.styleProvider.RenderWithColor(statusLine, color) } // getStatusIcon returns the appropriate icon for the current status type @@ -269,7 +269,7 @@ func (sv *StatusView) createProgressBar() string { } func (sv *StatusView) formatErrorStatus() (string, string, string) { - errorColor := sv.getErrorColor() + errorColor := sv.styleProvider.GetThemeColor("error") return icons.CrossMarkStyle.Render(icons.CrossMark), errorColor, sv.message } @@ -286,13 +286,13 @@ func (sv *StatusView) formatSpinnerStatus() (string, string, string) { baseMsg := sv.formatStatusWithType(sv.baseMessage) displayMessage := fmt.Sprintf("%s (%ds) - Press ESC to interrupt", baseMsg, seconds) - statusColor := sv.getStatusColor() + statusColor := sv.styleProvider.GetThemeColor("status") return prefix, statusColor, displayMessage } func (sv *StatusView) formatNormalStatus() (string, string, string) { prefix := sv.getStatusIcon() - statusColor := sv.getStatusColor() + statusColor := sv.styleProvider.GetThemeColor("status") displayMessage := sv.formatStatusWithType(sv.message) if sv.tokenUsage != "" { @@ -343,18 +343,3 @@ func (sv *StatusView) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return sv, cmd } - -// Helper methods to get theme colors with fallbacks -func (sv *StatusView) getErrorColor() string { - if sv.themeService != nil { - return sv.themeService.GetCurrentTheme().GetErrorColor() - } - return colors.ErrorColor.ANSI -} - -func (sv *StatusView) getStatusColor() string { - if sv.themeService != nil { - return sv.themeService.GetCurrentTheme().GetStatusColor() - } - return colors.StatusColor.ANSI -} diff --git a/internal/ui/components/status_view_test.go b/internal/ui/components/status_view_test.go index a091e078..f3b89dad 100644 --- a/internal/ui/components/status_view_test.go +++ b/internal/ui/components/status_view_test.go @@ -3,10 +3,17 @@ package components import ( "strings" "testing" + + styles "github.com/inference-gateway/cli/internal/ui/styles" ) +// createMockStyleProviderForStatus creates a mock styles provider for testing +func createMockStyleProviderForStatus() *styles.Provider { + return styles.NewProvider(&mockThemeService{}) +} + func TestNewStatusView(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) if sv.width != 0 { t.Errorf("Expected default width 0, got %d", sv.width) @@ -26,7 +33,7 @@ func TestNewStatusView(t *testing.T) { } func TestStatusView_ShowStatus(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) testMessage := "Processing request..." sv.ShowStatus(testMessage) @@ -45,7 +52,7 @@ func TestStatusView_ShowStatus(t *testing.T) { } func TestStatusView_ShowError(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) testError := "Connection failed" sv.ShowError(testError) @@ -64,7 +71,7 @@ func TestStatusView_ShowError(t *testing.T) { } func TestStatusView_ShowSpinner(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) testMessage := "Loading..." sv.ShowSpinner(testMessage) @@ -83,7 +90,7 @@ func TestStatusView_ShowSpinner(t *testing.T) { } func TestStatusView_ClearStatus(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) sv.ShowError("Some error") sv.SetTokenUsage("100 tokens") @@ -108,7 +115,7 @@ func TestStatusView_ClearStatus(t *testing.T) { } func TestStatusView_IsShowingError(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) if sv.IsShowingError() { t.Error("Expected IsShowingError to be false initially") @@ -128,7 +135,7 @@ func TestStatusView_IsShowingError(t *testing.T) { } func TestStatusView_IsShowingSpinner(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) if sv.IsShowingSpinner() { t.Error("Expected IsShowingSpinner to be false initially") @@ -148,7 +155,7 @@ func TestStatusView_IsShowingSpinner(t *testing.T) { } func TestStatusView_SetTokenUsage(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) testUsage := "150 tokens used" sv.SetTokenUsage(testUsage) @@ -159,7 +166,7 @@ func TestStatusView_SetTokenUsage(t *testing.T) { } func TestStatusView_SetWidth(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) sv.SetWidth(120) @@ -169,13 +176,13 @@ func TestStatusView_SetWidth(t *testing.T) { } func TestStatusView_SetHeight(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) sv.SetHeight(4) } func TestStatusView_Render(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) output := sv.Render() if output != "" { @@ -213,7 +220,7 @@ func TestStatusView_Render(t *testing.T) { } func TestStatusView_StateTransitions(t *testing.T) { - sv := NewStatusView(nil) + sv := NewStatusView(createMockStyleProviderForStatus()) sv.ShowStatus("Normal") sv.ShowError("Error occurred") diff --git a/internal/ui/components/text_selection_view.go b/internal/ui/components/text_selection_view.go index 781bbd32..cf4b286c 100644 --- a/internal/ui/components/text_selection_view.go +++ b/internal/ui/components/text_selection_view.go @@ -6,9 +6,8 @@ import ( clipboard "github.com/atotto/clipboard" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // TextSelectionView provides vim-like text selection mode @@ -24,6 +23,7 @@ type TextSelectionView struct { height int scrollOffset int copiedText string + styleProvider *styles.Provider } // Position represents a position in the text @@ -33,15 +33,16 @@ type Position struct { } // NewTextSelectionView creates a new text selection view -func NewTextSelectionView() *TextSelectionView { +func NewTextSelectionView(styleProvider *styles.Provider) *TextSelectionView { return &TextSelectionView{ - lines: []string{}, - cursorLine: 0, - cursorCol: 0, - selecting: false, - width: 80, - height: 20, - scrollOffset: 0, + lines: []string{}, + cursorLine: 0, + cursorCol: 0, + selecting: false, + width: 80, + height: 20, + scrollOffset: 0, + styleProvider: styleProvider, } } @@ -385,10 +386,6 @@ func (v *TextSelectionView) Render() string { var b strings.Builder - headerStyle := lipgloss.NewStyle(). - Foreground(colors.AccentColor.GetLipglossColor()). - Bold(true) - mode := "SELECTION MODE" if v.selecting { mode = "VISUAL" @@ -396,7 +393,8 @@ func (v *TextSelectionView) Render() string { mode = "VISUAL LINE" } - header := headerStyle.Render(fmt.Sprintf("-- %s --", mode)) + accentColor := v.styleProvider.GetThemeColor("accent") + header := v.styleProvider.RenderWithColorAndBold(fmt.Sprintf("-- %s --", mode), accentColor) b.WriteString(header) b.WriteString("\n") @@ -413,10 +411,6 @@ func (v *TextSelectionView) Render() string { b.WriteString("\n") } - posStyle := lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - Italic(true) - debugInfo := "" if v.selecting { debugInfo = fmt.Sprintf(" | Visual: %d,%d -> %d,%d", @@ -428,7 +422,7 @@ func (v *TextSelectionView) Render() string { } position := fmt.Sprintf("Line %d/%d, Col %d%s", v.cursorLine+1, len(v.lines), v.cursorCol+1, debugInfo) - b.WriteString(posStyle.Render(position)) + b.WriteString(v.styleProvider.RenderDimText(position)) return b.String() } @@ -462,10 +456,6 @@ func (v *TextSelectionView) renderLineWithSelection(lineIdx int, line string) st start, end = end, start } - highlightStyle := lipgloss.NewStyle(). - Background(colors.AccentColor.GetLipglossColor()). - Foreground(colors.TextSelectionForeground.GetLipglossColor()) - if lineIdx < start.Line || lineIdx > end.Line { return line } @@ -486,7 +476,7 @@ func (v *TextSelectionView) renderLineWithSelection(lineIdx int, line string) st } if shouldHighlight { - result.WriteString(highlightStyle.Render(string(line[i]))) + result.WriteString(v.styleProvider.RenderTextSelection(string(line[i]))) } else { result.WriteByte(line[i]) } @@ -496,18 +486,18 @@ func (v *TextSelectionView) renderLineWithSelection(lineIdx int, line string) st } // renderCharAtPosition renders a character at a specific position with appropriate styling -func (v *TextSelectionView) renderCharAtPosition(i, lineLen int, line string, isCursor, shouldHighlight bool, cursorStyle, highlightStyle lipgloss.Style) string { +func (v *TextSelectionView) renderCharAtPosition(i, lineLen int, line string, isCursor, shouldHighlight bool) string { if isCursor { if i < lineLen { - return cursorStyle.Render(string(line[i])) + return v.styleProvider.RenderCursor(string(line[i])) } - return cursorStyle.Render(" ") + return v.styleProvider.RenderCursor(" ") } if i < lineLen { char := string(line[i]) if shouldHighlight { - return highlightStyle.Render(char) + return v.styleProvider.RenderTextSelection(char) } return char } @@ -526,10 +516,7 @@ func (v *TextSelectionView) renderDisplayLine(lineIdx int, line string, isSelect } if v.visualLineMode { - highlightStyle := lipgloss.NewStyle(). - Background(colors.AccentColor.GetLipglossColor()). - Foreground(lipgloss.Color("#000000")) - return highlightStyle.Render(line) + return v.styleProvider.RenderVisualLineSelection(line) } return v.renderLineWithSelection(lineIdx, line) @@ -576,14 +563,6 @@ func (v *TextSelectionView) shouldHighlightChar(lineIdx, charIdx int, isSelected func (v *TextSelectionView) renderLineWithCursor(lineIdx int, line string, isSelected bool) string { var result strings.Builder - cursorStyle := lipgloss.NewStyle(). - Background(colors.TextSelectionCursor.GetLipglossColor()). - Foreground(colors.TextSelectionForeground.GetLipglossColor()) - - highlightStyle := lipgloss.NewStyle(). - Background(colors.AccentColor.GetLipglossColor()). - Foreground(colors.TextSelectionForeground.GetLipglossColor()) - lineLen := len(line) displayCursorCol := v.cursorCol if displayCursorCol > lineLen { @@ -592,11 +571,9 @@ func (v *TextSelectionView) renderLineWithCursor(lineIdx int, line string, isSel for i := 0; i <= lineLen; i++ { isCursor := i == displayCursorCol - shouldHighlight := false - - shouldHighlight = v.shouldHighlightChar(lineIdx, i, isSelected, lineLen) + shouldHighlight := v.shouldHighlightChar(lineIdx, i, isSelected, lineLen) - charRendered := v.renderCharAtPosition(i, lineLen, line, isCursor, shouldHighlight, cursorStyle, highlightStyle) + charRendered := v.renderCharAtPosition(i, lineLen, line, isCursor, shouldHighlight) if charRendered != "" { result.WriteString(charRendered) } diff --git a/internal/ui/components/theme_selection_view.go b/internal/ui/components/theme_selection_view.go index c8d44d91..703a1907 100644 --- a/internal/ui/components/theme_selection_view.go +++ b/internal/ui/components/theme_selection_view.go @@ -6,7 +6,7 @@ import ( tea "github.com/charmbracelet/bubbletea" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" ) // ThemeSelectorImpl implements theme selection UI @@ -19,12 +19,13 @@ type ThemeSelectorImpl struct { done bool cancelled bool themeService domain.ThemeService + styleProvider *styles.Provider searchQuery string searchMode bool } // NewThemeSelector creates a new theme selector -func NewThemeSelector(themeService domain.ThemeService) *ThemeSelectorImpl { +func NewThemeSelector(themeService domain.ThemeService, styleProvider *styles.Provider) *ThemeSelectorImpl { themes := themeService.ListThemes() m := &ThemeSelectorImpl{ themes: themes, @@ -33,6 +34,7 @@ func NewThemeSelector(themeService domain.ThemeService) *ThemeSelectorImpl { width: 80, height: 24, themeService: themeService, + styleProvider: styleProvider, searchQuery: "", searchMode: false, } @@ -168,25 +170,26 @@ func (m *ThemeSelectorImpl) updateSearch() { func (m *ThemeSelectorImpl) View() string { var b strings.Builder - b.WriteString(fmt.Sprintf("%sSelect a Theme%s\n\n", - m.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor("Select a Theme", m.styleProvider.GetThemeColor("accent"))) + b.WriteString("\n\n") if m.searchMode { - b.WriteString(fmt.Sprintf("%sSearch: %s%s│%s\n\n", - m.themeService.GetCurrentTheme().GetStatusColor(), m.searchQuery, m.themeService.GetCurrentTheme().GetAccentColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor("Search: "+m.searchQuery, m.styleProvider.GetThemeColor("status"))) + b.WriteString(m.styleProvider.RenderWithColor("│", m.styleProvider.GetThemeColor("accent"))) + b.WriteString("\n\n") } else { - b.WriteString(fmt.Sprintf("%sPress / to search • %d themes available%s\n\n", - m.themeService.GetCurrentTheme().GetDimColor(), len(m.themes), colors.Reset)) + helpText := fmt.Sprintf("Press / to search • %d themes available", len(m.themes)) + b.WriteString(m.styleProvider.RenderDimText(helpText)) + b.WriteString("\n\n") } if len(m.filteredThemes) == 0 { if m.searchQuery != "" { - b.WriteString(fmt.Sprintf("%sNo themes match '%s'%s\n", - m.themeService.GetCurrentTheme().GetErrorColor(), m.searchQuery, colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor(fmt.Sprintf("No themes match '%s'", m.searchQuery), m.styleProvider.GetThemeColor("error"))) } else { - b.WriteString(fmt.Sprintf("%sNo themes available%s\n", - m.themeService.GetCurrentTheme().GetErrorColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderWithColor("No themes available", m.styleProvider.GetThemeColor("error"))) } + b.WriteString("\n") return b.String() } @@ -205,40 +208,43 @@ func (m *ThemeSelectorImpl) View() string { for i := start; i < start+maxVisible && i < len(m.filteredThemes); i++ { themeName := m.filteredThemes[i] - // Format the theme item based on selection and current theme prefix := " " suffix := "" - color := "" if i == m.selected { prefix = "▶ " - color = m.themeService.GetCurrentTheme().GetAccentColor() } if themeName == currentTheme { suffix = " ✓" - if i != m.selected { - color = m.themeService.GetCurrentTheme().GetStatusColor() - } } - b.WriteString(fmt.Sprintf("%s%s%s%s%s\n", color, prefix, themeName, suffix, colors.Reset)) + line := prefix + themeName + suffix + if i == m.selected { + b.WriteString(m.styleProvider.RenderWithColor(line, m.styleProvider.GetThemeColor("accent"))) + } else if themeName == currentTheme { + b.WriteString(m.styleProvider.RenderWithColor(line, m.styleProvider.GetThemeColor("status"))) + } else { + b.WriteString(line) + } + b.WriteString("\n") } if len(m.filteredThemes) > maxVisible { - b.WriteString(fmt.Sprintf("\n%sShowing %d-%d of %d themes%s\n", - m.themeService.GetCurrentTheme().GetDimColor(), start+1, start+maxVisible, len(m.filteredThemes), colors.Reset)) + paginationText := fmt.Sprintf("Showing %d-%d of %d themes", start+1, start+maxVisible, len(m.filteredThemes)) + b.WriteString("\n") + b.WriteString(m.styleProvider.RenderDimText(paginationText)) + b.WriteString("\n") } b.WriteString("\n") - b.WriteString(colors.CreateSeparator(m.width, "─")) + b.WriteString(strings.Repeat("─", m.width)) b.WriteString("\n") + if m.searchMode { - b.WriteString(fmt.Sprintf("%sType to search, ↑↓ to navigate, Enter to select, Esc to clear search%s", - m.themeService.GetCurrentTheme().GetDimColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderDimText("Type to search, ↑↓ to navigate, Enter to select, Esc to clear search")) } else { - b.WriteString(fmt.Sprintf("%sUse ↑↓ arrows to navigate, Enter to select, / to search, Esc/Ctrl+C to cancel%s", - m.themeService.GetCurrentTheme().GetDimColor(), colors.Reset)) + b.WriteString(m.styleProvider.RenderDimText("Use ↑↓ arrows to navigate, Enter to select, / to search, Esc/Ctrl+C to cancel")) } return b.String() diff --git a/internal/ui/components/tool_call_renderer.go b/internal/ui/components/tool_call_renderer.go index 82ca40d6..b645726f 100644 --- a/internal/ui/components/tool_call_renderer.go +++ b/internal/ui/components/tool_call_renderer.go @@ -7,10 +7,9 @@ import ( spinner "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" - lipgloss "github.com/charmbracelet/lipgloss" constants "github.com/inference-gateway/cli/internal/constants" domain "github.com/inference-gateway/cli/internal/domain" - colors "github.com/inference-gateway/cli/internal/ui/styles/colors" + styles "github.com/inference-gateway/cli/internal/ui/styles" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" sdk "github.com/inference-gateway/sdk" ) @@ -21,7 +20,7 @@ type ToolCallRenderer struct { spinner spinner.Model toolPreviews map[string]*domain.ToolCallPreviewEvent toolPreviewsOrder []string - styles *toolRenderStyles + styleProvider *styles.Provider lastUpdate time.Time parallelTools map[string]*ParallelToolState parallelToolsOrder []string @@ -39,64 +38,20 @@ type ParallelToolState struct { MinShowTime time.Duration } -type toolRenderStyles struct { - statusStreaming lipgloss.Style - statusComplete lipgloss.Style - statusReady lipgloss.Style - statusDefault lipgloss.Style - toolCallMeta lipgloss.Style - toolCallArgs lipgloss.Style - spinner lipgloss.Style - toolName lipgloss.Style - argsContainer lipgloss.Style -} - type ToolInfo struct { Name string Prefix string } -func NewToolCallRenderer() *ToolCallRenderer { +func NewToolCallRenderer(styleProvider *styles.Provider) *ToolCallRenderer { s := spinner.New() s.Spinner = spinner.Dot - styles := &toolRenderStyles{ - statusStreaming: lipgloss.NewStyle(). - Foreground(colors.SpinnerColor.GetLipglossColor()). - Bold(true), - statusComplete: lipgloss.NewStyle(). - Foreground(colors.SuccessColor.GetLipglossColor()). - Bold(true), - statusReady: lipgloss.NewStyle(). - Foreground(colors.WarningColor.GetLipglossColor()). - Bold(true), - statusDefault: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()), - toolCallMeta: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - Italic(true), - toolCallArgs: lipgloss.NewStyle(). - Foreground(colors.DimColor.GetLipglossColor()). - MarginLeft(2), - spinner: lipgloss.NewStyle(). - Foreground(colors.SpinnerColor.GetLipglossColor()), - toolName: lipgloss.NewStyle(). - Foreground(colors.AccentColor.GetLipglossColor()). - Bold(true), - argsContainer: lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder(), false, false, false, true). - BorderForeground(colors.DimColor.GetLipglossColor()). - PaddingLeft(2). - MarginTop(1), - } - - s.Style = styles.spinner - return &ToolCallRenderer{ spinner: s, toolPreviews: make(map[string]*domain.ToolCallPreviewEvent), parallelTools: make(map[string]*ParallelToolState), - styles: styles, + styleProvider: styleProvider, width: 80, } } @@ -190,8 +145,7 @@ func (r *ToolCallRenderer) SetWidth(width int) { } func (r *ToolCallRenderer) updateArgsContainerWidth() { - r.styles.argsContainer = r.styles.argsContainer.Width(r.width - 6) - r.styles.toolCallArgs = r.styles.toolCallArgs.Width(r.width - 8) + // Width is now handled dynamically by styleProvider methods } func (r *ToolCallRenderer) RenderPreviews() string { @@ -255,42 +209,37 @@ func (r *ToolCallRenderer) shouldShowPreview(*domain.ToolCallPreviewEvent) bool func (r *ToolCallRenderer) renderToolPreview(preview *domain.ToolCallPreviewEvent) string { var statusIcon string var statusText string - var statusStyle lipgloss.Style + var colorName string switch preview.Status { case domain.ToolCallStreamStatusStreaming: statusIcon = icons.GetSpinnerFrame(r.spinnerStep) statusText = "executing" - statusStyle = r.styles.statusStreaming + colorName = "spinner" case domain.ToolCallStreamStatusComplete: statusIcon = icons.CheckMark statusText = "completed" - statusStyle = r.styles.statusComplete + colorName = "success" case domain.ToolCallStreamStatusReady: statusIcon = icons.QueuedIcon statusText = "ready" - statusStyle = r.styles.statusDefault.Foreground(colors.DimColor.GetLipglossColor()) + colorName = "dim" default: statusIcon = icons.BulletIcon statusText = "unknown" - statusStyle = r.styles.statusDefault + colorName = "dim" } toolInfo := r.parseToolName(preview.ToolName) argsPreview := r.formatArgsPreview(preview.Arguments) - header := lipgloss.JoinHorizontal( - lipgloss.Left, - statusStyle.Render(fmt.Sprintf("%s %s:%s", statusIcon, toolInfo.Prefix, toolInfo.Name)), - r.styles.toolCallMeta.Render(fmt.Sprintf(" (%s)", statusText)), - ) + statusPart := r.styleProvider.RenderWithColor(fmt.Sprintf("%s %s:%s", statusIcon, toolInfo.Prefix, toolInfo.Name), r.styleProvider.GetThemeColor(colorName)) + metaPart := r.styleProvider.RenderDimText(fmt.Sprintf(" (%s)", statusText)) + header := statusPart + metaPart if argsPreview != "" { - return lipgloss.JoinVertical( - lipgloss.Left, - header, - r.styles.toolCallArgs.Render(fmt.Sprintf(" %s", argsPreview)), - ) + argsPart := r.styleProvider.RenderDimText(fmt.Sprintf(" %s", argsPreview)) + return r.styleProvider.JoinVertical(header, argsPart) } return header @@ -323,18 +272,17 @@ func (r *ToolCallRenderer) renderToolCallContent(toolInfo ToolInfo, arguments, s statusText = status } + toolNameColor := r.styleProvider.GetThemeColor("accent") var header string if toolInfo.Prefix != "TOOL" { - header = fmt.Sprintf("%s %s:%s (%s)", - statusIcon, - r.styles.toolName.Render(toolInfo.Prefix), - r.styles.toolName.Render(toolInfo.Name), - r.styles.toolCallMeta.Render(statusText)) + prefixPart := r.styleProvider.RenderWithColorAndBold(toolInfo.Prefix, toolNameColor) + namePart := r.styleProvider.RenderWithColorAndBold(toolInfo.Name, toolNameColor) + metaPart := r.styleProvider.RenderDimText(statusText) + header = fmt.Sprintf("%s %s:%s (%s)", statusIcon, prefixPart, namePart, metaPart) } else { - header = fmt.Sprintf("%s %s (%s)", - statusIcon, - r.styles.toolName.Render(toolInfo.Name), - r.styles.toolCallMeta.Render(statusText)) + namePart := r.styleProvider.RenderWithColorAndBold(toolInfo.Name, toolNameColor) + metaPart := r.styleProvider.RenderDimText(statusText) + header = fmt.Sprintf("%s %s (%s)", statusIcon, namePart, metaPart) } if arguments != "" && arguments != "{}" { @@ -343,8 +291,8 @@ func (r *ToolCallRenderer) renderToolCallContent(toolInfo ToolInfo, arguments, s args = args[:197] + "..." } - formattedArgs := r.styles.toolCallArgs.Render(args) - return fmt.Sprintf("%s\n%s", header, r.styles.argsContainer.Render(formattedArgs)) + formattedArgs := r.styleProvider.RenderDimText(args) + return r.styleProvider.JoinVertical(header, formattedArgs) } return header @@ -394,38 +342,35 @@ func (r *ToolCallRenderer) hasActiveParallelTools() bool { func (r *ToolCallRenderer) renderParallelTool(tool *ParallelToolState) string { var statusIcon string var statusText string - var statusStyle lipgloss.Style + var colorName string switch tool.Status { case "queued": statusIcon = icons.QueuedIcon statusText = "queued" - statusStyle = r.styles.statusDefault.Foreground(colors.DimColor.GetLipglossColor()) + colorName = "dim" case "running", "starting", "saving": statusIcon = icons.GetSpinnerFrame(r.spinnerStep) statusText = "executing" - statusStyle = r.styles.statusStreaming + colorName = "spinner" case "complete": statusIcon = icons.CheckMark statusText = "completed" - statusStyle = r.styles.statusComplete + colorName = "success" case "failed": statusIcon = icons.CrossMark statusText = "failed" - statusStyle = r.styles.statusComplete.Foreground(colors.ErrorColor.GetLipglossColor()) + colorName = "error" default: statusIcon = icons.BulletIcon statusText = tool.Status - statusStyle = r.styles.statusDefault + colorName = "dim" } toolInfo := r.parseToolName(tool.ToolName) - header := lipgloss.JoinHorizontal( - lipgloss.Left, - statusStyle.Render(fmt.Sprintf("%s %s:%s", statusIcon, toolInfo.Prefix, toolInfo.Name)), - r.styles.toolCallMeta.Render(fmt.Sprintf(" (%s)", statusText)), - ) + statusPart := r.styleProvider.RenderWithColor(fmt.Sprintf("%s %s:%s", statusIcon, toolInfo.Prefix, toolInfo.Name), r.styleProvider.GetThemeColor(colorName)) + metaPart := r.styleProvider.RenderDimText(fmt.Sprintf(" (%s)", statusText)) - return header + return statusPart + metaPart } diff --git a/internal/ui/components_factory_test.go b/internal/ui/components_factory_test.go index 0f961afd..d6e5d188 100644 --- a/internal/ui/components_factory_test.go +++ b/internal/ui/components_factory_test.go @@ -8,8 +8,43 @@ import ( shortcuts "github.com/inference-gateway/cli/internal/shortcuts" ) +// mockTheme implements domain.Theme for testing +type mockTheme struct{} + +func (m *mockTheme) GetUserColor() string { return "#00FF00" } +func (m *mockTheme) GetAssistantColor() string { return "#0000FF" } +func (m *mockTheme) GetErrorColor() string { return "#FF0000" } +func (m *mockTheme) GetSuccessColor() string { return "#00FF00" } +func (m *mockTheme) GetStatusColor() string { return "#FFFF00" } +func (m *mockTheme) GetAccentColor() string { return "#FF00FF" } +func (m *mockTheme) GetDimColor() string { return "#808080" } +func (m *mockTheme) GetBorderColor() string { return "#FFFFFF" } +func (m *mockTheme) GetDiffAddColor() string { return "#00FF00" } +func (m *mockTheme) GetDiffRemoveColor() string { return "#FF0000" } + +// mockThemeService implements domain.ThemeService for testing +type mockThemeService struct{} + +var _ domain.ThemeService = (*mockThemeService)(nil) + +func (m *mockThemeService) ListThemes() []string { + return []string{"default"} +} + +func (m *mockThemeService) GetCurrentTheme() domain.Theme { + return &mockTheme{} +} + +func (m *mockThemeService) GetCurrentThemeName() string { + return "default" +} + +func (m *mockThemeService) SetTheme(themeName string) error { + return nil +} + func TestCreateConversationView(t *testing.T) { - cv := CreateConversationView(nil) + cv := CreateConversationView(&mockThemeService{}) if cv == nil { t.Fatal("Expected CreateConversationView to return non-nil component") @@ -59,7 +94,7 @@ func TestCreateInputView(t *testing.T) { } func TestCreateStatusView(t *testing.T) { - sv := CreateStatusView(nil) + sv := CreateStatusView(&mockThemeService{}) if sv == nil { t.Fatal("Expected CreateStatusView to return non-nil component") @@ -67,7 +102,7 @@ func TestCreateStatusView(t *testing.T) { } func TestCreateHelpBar(t *testing.T) { - hb := CreateHelpBar(nil) + hb := CreateHelpBar(&mockThemeService{}) if hb == nil { t.Fatal("Expected CreateHelpBar to return non-nil component") diff --git a/internal/ui/components_test.go b/internal/ui/components_test.go index bea7da4a..af7afcdd 100644 --- a/internal/ui/components_test.go +++ b/internal/ui/components_test.go @@ -5,11 +5,12 @@ import ( "testing" domain "github.com/inference-gateway/cli/internal/domain" + components "github.com/inference-gateway/cli/internal/ui/components" sdk "github.com/inference-gateway/sdk" ) func TestConversationViewBasic(t *testing.T) { - cv := CreateConversationView(nil) + cv := CreateConversationView(&mockThemeService{}) cv.SetWidth(80) cv.SetHeight(5) @@ -37,6 +38,9 @@ func TestConversationViewBasic(t *testing.T) { func TestInputViewBasic(t *testing.T) { iv := CreateInputView(nil, nil) + if inputView, ok := iv.(*components.InputView); ok { + inputView.SetThemeService(&mockThemeService{}) + } iv.SetWidth(80) iv.SetHeight(5) @@ -60,7 +64,7 @@ func TestInputViewBasic(t *testing.T) { } func TestStatusViewBasic(t *testing.T) { - sv := CreateStatusView(nil) + sv := CreateStatusView(&mockThemeService{}) sv.SetWidth(80) sv.ShowStatus("Test status") @@ -91,7 +95,7 @@ func TestStatusViewBasic(t *testing.T) { } func TestHelpBarBasic(t *testing.T) { - hb := CreateHelpBar(nil) + hb := CreateHelpBar(&mockThemeService{}) hb.SetWidth(80) shortcuts := []KeyShortcut{ diff --git a/internal/ui/interfaces.go b/internal/ui/interfaces.go index 5c0059ca..3a62ac40 100644 --- a/internal/ui/interfaces.go +++ b/internal/ui/interfaces.go @@ -22,6 +22,7 @@ func NewDefaultTheme() *DefaultTheme { return &DefaultTheme{} } func (t *DefaultTheme) GetUserColor() string { return colors.UserColor.ANSI } func (t *DefaultTheme) GetAssistantColor() string { return colors.AssistantColor.ANSI } func (t *DefaultTheme) GetErrorColor() string { return colors.ErrorColor.ANSI } +func (t *DefaultTheme) GetSuccessColor() string { return colors.SuccessColor.ANSI } func (t *DefaultTheme) GetStatusColor() string { return colors.StatusColor.ANSI } func (t *DefaultTheme) GetAccentColor() string { return colors.AccentColor.ANSI } func (t *DefaultTheme) GetDimColor() string { return colors.DimColor.ANSI } diff --git a/internal/ui/keybinding/actions.go b/internal/ui/keybinding/actions.go index 74793fa5..1d00ed43 100644 --- a/internal/ui/keybinding/actions.go +++ b/internal/ui/keybinding/actions.go @@ -18,8 +18,9 @@ func (r *Registry) registerDefaultBindings() { globalActions := r.createGlobalActions() chatActions := r.createChatActions() scrollActions := r.createScrollActions() + approvalActions := r.createApprovalActions() - r.registerActionsToLayers(globalActions, chatActions, scrollActions) + r.registerActionsToLayers(globalActions, chatActions, scrollActions, approvalActions) } // createGlobalActions creates global key actions available in all views @@ -376,10 +377,65 @@ func (r *Registry) createScrollActions() []*KeyAction { } } +// createApprovalActions creates key actions specific to approval view +func (r *Registry) createApprovalActions() []*KeyAction { + return []*KeyAction{ + { + ID: "approval_left", + Keys: []string{"left", "h"}, + Description: "move selection left", + Category: "approval", + Handler: handleApprovalLeft, + Priority: 150, + Enabled: true, + Context: KeyContext{ + Views: []domain.ViewState{domain.ViewStateToolApproval}, + }, + }, + { + ID: "approval_right", + Keys: []string{"right", "l"}, + Description: "move selection right", + Category: "approval", + Handler: handleApprovalRight, + Priority: 150, + Enabled: true, + Context: KeyContext{ + Views: []domain.ViewState{domain.ViewStateToolApproval}, + }, + }, + { + ID: "approval_approve", + Keys: []string{"enter", "y"}, + Description: "approve tool execution", + Category: "approval", + Handler: handleApprovalApprove, + Priority: 150, + Enabled: true, + Context: KeyContext{ + Views: []domain.ViewState{domain.ViewStateToolApproval}, + }, + }, + { + ID: "approval_reject", + Keys: []string{"n"}, + Description: "reject tool execution", + Category: "approval", + Handler: handleApprovalReject, + Priority: 150, + Enabled: true, + Context: KeyContext{ + Views: []domain.ViewState{domain.ViewStateToolApproval}, + }, + }, + } +} + // registerActionsToLayers registers actions to their appropriate layers -func (r *Registry) registerActionsToLayers(globalActions, chatActions, scrollActions []*KeyAction) { +func (r *Registry) registerActionsToLayers(globalActions, chatActions, scrollActions, approvalActions []*KeyAction) { allActions := append(globalActions, chatActions...) allActions = append(allActions, scrollActions...) + allActions = append(allActions, approvalActions...) for _, action := range allActions { if err := r.Register(action); err != nil { @@ -398,6 +454,10 @@ func (r *Registry) registerActionsToLayers(globalActions, chatActions, scrollAct for _, action := range scrollActions { _ = r.addActionToLayer("chat_view", action) } + + for _, action := range approvalActions { + _ = r.addActionToLayer("approval_view", action) + } } // Handler implementations @@ -412,6 +472,20 @@ func handleCancel(app KeyHandlerContext, keyMsg tea.KeyMsg) tea.Cmd { } stateManager := app.GetStateManager() + + // If we're in approval view, reject the approval and transition back + if stateManager.GetCurrentView() == domain.ViewStateToolApproval { + approvalState := stateManager.GetApprovalUIState() + if approvalState != nil && approvalState.ResponseChan != nil { + return func() tea.Msg { + return domain.ToolApprovalResponseEvent{ + Action: domain.ApprovalReject, + ToolCall: *approvalState.PendingToolCall, + } + } + } + } + if chatSession := stateManager.GetChatSession(); chatSession != nil { agentService := app.GetAgentService() if agentService != nil { @@ -921,3 +995,70 @@ func handlePasteEvent(app KeyHandlerContext, pastedText string) tea.Cmd { return nil } + +// Approval handlers +func handleApprovalLeft(app KeyHandlerContext, keyMsg tea.KeyMsg) tea.Cmd { + stateManager := app.GetStateManager() + approvalState := stateManager.GetApprovalUIState() + if approvalState == nil { + return nil + } + + selectedIndex := approvalState.SelectedIndex + if selectedIndex > int(domain.ApprovalApprove) { + selectedIndex-- + stateManager.SetApprovalSelectedIndex(selectedIndex) + } + return nil +} + +func handleApprovalRight(app KeyHandlerContext, keyMsg tea.KeyMsg) tea.Cmd { + stateManager := app.GetStateManager() + approvalState := stateManager.GetApprovalUIState() + if approvalState == nil { + return nil + } + + selectedIndex := approvalState.SelectedIndex + if selectedIndex < int(domain.ApprovalReject) { + selectedIndex++ + stateManager.SetApprovalSelectedIndex(selectedIndex) + } + return nil +} + +func handleApprovalApprove(app KeyHandlerContext, keyMsg tea.KeyMsg) tea.Cmd { + stateManager := app.GetStateManager() + approvalState := stateManager.GetApprovalUIState() + if approvalState == nil { + return nil + } + + // If user is on "Approve" or presses enter/y, approve the tool + action := domain.ApprovalAction(approvalState.SelectedIndex) + if action == domain.ApprovalApprove || keyMsg.String() == "y" { + action = domain.ApprovalApprove + } + + return func() tea.Msg { + return domain.ToolApprovalResponseEvent{ + Action: action, + ToolCall: *approvalState.PendingToolCall, + } + } +} + +func handleApprovalReject(app KeyHandlerContext, keyMsg tea.KeyMsg) tea.Cmd { + return func() tea.Msg { + stateManager := app.GetStateManager() + approvalState := stateManager.GetApprovalUIState() + if approvalState == nil { + return nil + } + + return domain.ToolApprovalResponseEvent{ + Action: domain.ApprovalReject, + ToolCall: *approvalState.PendingToolCall, + } + } +} diff --git a/internal/ui/keybinding/registry.go b/internal/ui/keybinding/registry.go index a43cdb8f..a7c9846e 100644 --- a/internal/ui/keybinding/registry.go +++ b/internal/ui/keybinding/registry.go @@ -261,6 +261,15 @@ func (r *Registry) initializeLayers() { }, }) + r.AddLayer(&KeyLayer{ + Name: "approval_view", + Priority: 150, + Bindings: make(map[string]*KeyAction), + Matcher: func(app KeyHandlerContext) bool { + return app.GetStateManager().GetCurrentView() == domain.ViewStateToolApproval + }, + }) + r.AddLayer(&KeyLayer{ Name: "global", Priority: 100, diff --git a/internal/ui/shared/interfaces.go b/internal/ui/shared/interfaces.go index 362ebfb7..c5bdaa23 100644 --- a/internal/ui/shared/interfaces.go +++ b/internal/ui/shared/interfaces.go @@ -23,6 +23,7 @@ type Theme interface { GetUserColor() string GetAssistantColor() string GetErrorColor() string + GetSuccessColor() string GetStatusColor() string GetAccentColor() string GetDimColor() string diff --git a/internal/ui/styles/colors/colors.go b/internal/ui/styles/colors/colors.go index 2dda10d3..8224f6d5 100644 --- a/internal/ui/styles/colors/colors.go +++ b/internal/ui/styles/colors/colors.go @@ -90,6 +90,7 @@ var ( GithubUserColor = Color{ANSI: "\033[38;2;3;102;214m", Lipgloss: GithubBlue} GithubAssistantColor = Color{ANSI: "\033[38;2;36;41;46m", Lipgloss: GithubDarkGray} GithubErrorColor = Color{ANSI: "\033[38;2;215;58;73m", Lipgloss: GithubRed} + GithubSuccessColor = Color{ANSI: "\033[38;2;40;167;69m", Lipgloss: GithubGreen} GithubStatusColor = Color{ANSI: "\033[38;2;130;87;223m", Lipgloss: GithubPurple} GithubAccentColor = Color{ANSI: "\033[38;2;3;102;214m", Lipgloss: GithubBlue} GithubDimColor = Color{ANSI: "\033[38;2;88;96;105m", Lipgloss: GithubGray} @@ -101,6 +102,7 @@ var ( DraculaUserColor = Color{ANSI: "\033[38;2;139;233;253m", Lipgloss: DraculaCyan} DraculaAssistantColor = Color{ANSI: "\033[38;2;248;248;242m", Lipgloss: DraculaForeground} DraculaErrorColor = Color{ANSI: "\033[38;2;255;85;85m", Lipgloss: DraculaRed} + DraculaSuccessColor = Color{ANSI: "\033[38;2;80;250;123m", Lipgloss: DraculaGreen} DraculaStatusColor = Color{ANSI: "\033[38;2;189;147;249m", Lipgloss: DraculaPurple} DraculaAccentColor = Color{ANSI: "\033[38;2;255;121;198m", Lipgloss: DraculaPink} DraculaDimColor = Color{ANSI: "\033[38;2;98;114;164m", Lipgloss: DraculaComment} diff --git a/internal/ui/styles/provider.go b/internal/ui/styles/provider.go new file mode 100644 index 00000000..ab4d8fd2 --- /dev/null +++ b/internal/ui/styles/provider.go @@ -0,0 +1,548 @@ +package styles + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/inference-gateway/cli/internal/domain" +) + +// Provider centralizes all styling logic and provides complete abstraction from Lipgloss. +// Components should NEVER import lipgloss directly - they interact with styling through this provider. +type Provider struct { + themeService domain.ThemeService +} + +// NewProvider creates a new style provider +func NewProvider(themeService domain.ThemeService) *Provider { + return &Provider{ + themeService: themeService, + } +} + +// Modal styles + +// RenderModal renders a modal with rounded border +func (p *Provider) RenderModal(content string, width int) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(theme.GetBorderColor())). + Padding(1, 2). + Width(width) + return style.Render(content) +} + +// RenderModalTitle renders a modal title with emphasis +func (p *Provider) RenderModalTitle(title string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAccentColor())). + Bold(true). + Padding(0, 1) + return style.Render(title) +} + +// List/Selection styles + +// RenderListItem renders a list item (selected or unselected) +func (p *Provider) RenderListItem(content string, selected bool) string { + theme := p.themeService.GetCurrentTheme() + + if selected { + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAccentColor())). + Bold(true) + return "▶ " + style.Render(content) + } + + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())) + return " " + style.Render(content) +} + +// RenderListItemWithDescription renders a list item with a description +func (p *Provider) RenderListItemWithDescription(title, description string, selected bool) string { + theme := p.themeService.GetCurrentTheme() + + var titleStyle, descStyle lipgloss.Style + prefix := " " + + if selected { + titleStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAccentColor())). + Bold(true) + descStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())) + prefix = "▶ " + } else { + titleStyle = lipgloss.NewStyle() + descStyle = lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())) + } + + return prefix + titleStyle.Render(title) + "\n " + descStyle.Render(description) +} + +// Input styles + +// RenderInputField renders an input field with border +func (p *Provider) RenderInputField(content string, width int, focused bool) string { + theme := p.themeService.GetCurrentTheme() + + borderColor := theme.GetBorderColor() + if focused { + borderColor = theme.GetAccentColor() + } + + style := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(borderColor)). + Padding(0, 1). + Width(width) + + return style.Render(content) +} + +// RenderInputPlaceholder renders placeholder text +func (p *Provider) RenderInputPlaceholder(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())). + Italic(true) + return style.Render(text) +} + +// Button/Option styles + +// RenderButton renders a button or selectable option +func (p *Provider) RenderButton(text string, selected bool) string { + theme := p.themeService.GetCurrentTheme() + + if selected { + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAccentColor())). + Background(lipgloss.Color(theme.GetBorderColor())). + Bold(true). + Padding(0, 2) + return style.Render(text) + } + + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(theme.GetBorderColor())). + Padding(0, 2) + return style.Render(text) +} + +// RenderApprovalButton renders an approval-style button with custom colors +func (p *Provider) RenderApprovalButton(text string, selected bool, isApprove bool) string { + theme := p.themeService.GetCurrentTheme() + + borderColor := theme.GetAccentColor() + if !isApprove { + borderColor = theme.GetErrorColor() + } + + if selected { + bgColor := borderColor + fgColor := "#000000" + if !isApprove { + fgColor = "#ffffff" + } + + style := lipgloss.NewStyle(). + Padding(0, 2). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(borderColor)). + Background(lipgloss.Color(bgColor)). + Foreground(lipgloss.Color(fgColor)). + Bold(true) + return style.Render(text) + } + + style := lipgloss.NewStyle(). + Padding(0, 2). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(borderColor)) + return style.Render(text) +} + +// Text styles + +// RenderUserText renders text in the user color +func (p *Provider) RenderUserText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetUserColor())) + return style.Render(text) +} + +// RenderAssistantText renders text in the assistant color +func (p *Provider) RenderAssistantText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAssistantColor())) + return style.Render(text) +} + +// RenderErrorText renders text in the error color +func (p *Provider) RenderErrorText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetErrorColor())). + Bold(true) + return style.Render(text) +} + +// RenderSuccessText renders text in the success color +func (p *Provider) RenderSuccessText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetSuccessColor())). + Bold(true) + return style.Render(text) +} + +// RenderWarningText renders text in the warning/status color +func (p *Provider) RenderWarningText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetStatusColor())) + return style.Render(text) +} + +// RenderDimText renders text in a dimmed style +func (p *Provider) RenderDimText(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())) + return style.Render(text) +} + +// RenderBoldText renders bold text +func (p *Provider) RenderBoldText(text string) string { + style := lipgloss.NewStyle().Bold(true) + return style.Render(text) +} + +// Layout/Structure styles + +// RenderSeparator renders a horizontal separator line +func (p *Provider) RenderSeparator(width int, char string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDimColor())) + return style.Render(strings.Repeat(char, width)) +} + +// RenderHeader renders a centered header +func (p *Provider) RenderHeader(text string, width int) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAccentColor())). + Bold(true). + Width(width). + Align(lipgloss.Center) + return style.Render(text) +} + +// RenderBordered renders content with a border +func (p *Provider) RenderBordered(content string, width int) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Border(lipgloss.NormalBorder()). + BorderForeground(lipgloss.Color(theme.GetBorderColor())). + Padding(1, 2). + Width(width) + return style.Render(content) +} + +// Diff/Code styles + +// RenderDiffAddition renders a diff addition line +func (p *Provider) RenderDiffAddition(content string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDiffAddColor())) + return style.Render("+ " + content) +} + +// RenderDiffRemoval renders a diff removal line +func (p *Provider) RenderDiffRemoval(content string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetDiffRemoveColor())) + return style.Render("- " + content) +} + +// RenderCodeBlock renders a code block with subtle background +func (p *Provider) RenderCodeBlock(code string, width int) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetAssistantColor())). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(theme.GetBorderColor())). + Padding(1, 2). + Width(width) + return style.Render(code) +} + +// Status/Badge styles + +// RenderStatusBadge renders a status badge (e.g., "ENABLED", "DISABLED") +func (p *Provider) RenderStatusBadge(text string, positive bool) string { + theme := p.themeService.GetCurrentTheme() + + var color string + if positive { + color = theme.GetSuccessColor() + } else { + color = theme.GetErrorColor() + } + + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(color)). + Bold(true) + return style.Render(text) +} + +// RenderSpinner renders a spinner with status color +func (p *Provider) RenderSpinner(frame string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(theme.GetStatusColor())) + return style.Render(frame) +} + +// Utility methods + +// GetThemeColor returns a theme color for custom styling (use sparingly) +func (p *Provider) GetThemeColor(colorName string) string { + theme := p.themeService.GetCurrentTheme() + + switch colorName { + case "user": + return theme.GetUserColor() + case "assistant": + return theme.GetAssistantColor() + case "error": + return theme.GetErrorColor() + case "success": + return theme.GetSuccessColor() + case "status": + return theme.GetStatusColor() + case "accent": + return theme.GetAccentColor() + case "dim": + return theme.GetDimColor() + case "border": + return theme.GetBorderColor() + case "diffAdd": + return theme.GetDiffAddColor() + case "diffRemove": + return theme.GetDiffRemoveColor() + default: + return theme.GetAssistantColor() + } +} + +// Layout utilities to avoid components depending on lipgloss + +// JoinVertical joins strings vertically +func (p *Provider) JoinVertical(strs ...string) string { + return lipgloss.JoinVertical(lipgloss.Left, strs...) +} + +// JoinHorizontal joins strings horizontally +func (p *Provider) JoinHorizontal(strs ...string) string { + return lipgloss.JoinHorizontal(lipgloss.Top, strs...) +} + +// PlaceCenter places content in the center of the given dimensions +func (p *Provider) PlaceCenter(width, height int, content string) string { + return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, content) +} + +// PlaceCenterTop places content in the center-top of the given dimensions +func (p *Provider) PlaceCenterTop(width, height int, content string) string { + return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Top, content) +} + +// GetHeight returns the rendered height of a string +func (p *Provider) GetHeight(s string) int { + return lipgloss.Height(s) +} + +// GetWidth returns the rendered width of a string +func (p *Provider) GetWidth(s string) int { + return lipgloss.Width(s) +} + +// Custom rendering - for complex styling needs + +// RenderTextSelectionCursor renders a cursor character for text selection mode +func (p *Provider) RenderTextSelectionCursor(char string) string { + style := lipgloss.NewStyle(). + Background(lipgloss.Color("#ffffff")). + Foreground(lipgloss.Color("#000000")) + return style.Render(char) +} + +// RenderWithColor renders text with a specific hex color +func (p *Provider) RenderWithColor(text, hexColor string) string { + style := lipgloss.NewStyle().Foreground(lipgloss.Color(hexColor)) + return style.Render(text) +} + +// RenderWithColorAndBold renders text with color and bold +func (p *Provider) RenderWithColorAndBold(text, hexColor string) string { + style := lipgloss.NewStyle(). + Foreground(lipgloss.Color(hexColor)). + Bold(true) + return style.Render(text) +} + +// RenderBold renders text with bold styling +func (p *Provider) RenderBold(text string) string { + return lipgloss.NewStyle().Bold(true).Render(text) +} + +// RenderStyledText renders text with custom Lipgloss-compatible styling +// This is an escape hatch for complex styling not covered by other methods +func (p *Provider) RenderStyledText(text string, opts StyleOptions) string { + style := lipgloss.NewStyle() + + if opts.Foreground != "" { + style = style.Foreground(lipgloss.Color(opts.Foreground)) + } + if opts.Background != "" { + style = style.Background(lipgloss.Color(opts.Background)) + } + if opts.Bold { + style = style.Bold(true) + } + if opts.Italic { + style = style.Italic(true) + } + if opts.Faint { + style = style.Faint(true) + } + if opts.Width > 0 { + style = style.Width(opts.Width) + } + if opts.Padding[0] > 0 || opts.Padding[1] > 0 { + style = style.Padding(opts.Padding[0], opts.Padding[1]) + } + if opts.MarginBottom > 0 { + style = style.MarginBottom(opts.MarginBottom) + } + if opts.MarginTop > 0 { + style = style.MarginTop(opts.MarginTop) + } + + return style.Render(text) +} + +// StyleOptions provides options for custom text styling +type StyleOptions struct { + Foreground string + Background string + Bold bool + Italic bool + Faint bool + Width int + Padding [2]int + MarginBottom int + MarginTop int +} + +// RenderTextSelection renders text with selection highlighting (accent background) +func (p *Provider) RenderTextSelection(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Background(lipgloss.Color(theme.GetAccentColor())). + Foreground(lipgloss.Color("#ffffff")) + return style.Render(text) +} + +// RenderCursor renders text with cursor styling +func (p *Provider) RenderCursor(text string) string { + style := lipgloss.NewStyle(). + Background(lipgloss.Color("#00FFFF")). + Foreground(lipgloss.Color("#000000")) + return style.Render(text) +} + +// RenderVisualLineSelection renders a full line with visual line selection styling +func (p *Provider) RenderVisualLineSelection(text string) string { + theme := p.themeService.GetCurrentTheme() + style := lipgloss.NewStyle(). + Background(lipgloss.Color(theme.GetAccentColor())). + Foreground(lipgloss.Color("#000000")) + return style.Render(text) +} + +// RenderBorderedBox renders text inside a rounded border with padding +func (p *Provider) RenderBorderedBox(text, borderColor string, paddingV, paddingH int) string { + style := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder(), true). + BorderForeground(lipgloss.Color(borderColor)). + Padding(paddingV, paddingH) + return style.Render(text) +} + +// RenderCenteredBoldWithColor renders text centered, bold, and with a specific color +func (p *Provider) RenderCenteredBoldWithColor(text, hexColor string, width int) string { + style := lipgloss.NewStyle(). + Width(width). + Align(lipgloss.Center). + Foreground(lipgloss.Color(hexColor)). + Bold(true). + Padding(0, 1) + return style.Render(text) +} + +// RenderCenteredBorderedBox renders text in a centered bordered box with specified dimensions +func (p *Provider) RenderCenteredBorderedBox(text, borderColor string, width, height, paddingV, paddingH int) string { + style := lipgloss.NewStyle(). + Width(width). + Height(height). + Align(lipgloss.Center, lipgloss.Center). + Border(lipgloss.RoundedBorder(), true). + BorderForeground(lipgloss.Color(borderColor)). + Padding(paddingV, paddingH) + return style.Render(text) +} + +// RenderLeftAlignedBorderedBox renders text in a left-aligned bordered box with specified dimensions +func (p *Provider) RenderLeftAlignedBorderedBox(text, borderColor string, width, height, paddingV, paddingH int) string { + style := lipgloss.NewStyle(). + Width(width). + Height(height). + Align(lipgloss.Left, lipgloss.Center). + Border(lipgloss.RoundedBorder(), true). + BorderForeground(lipgloss.Color(borderColor)). + Padding(paddingV, paddingH) + return style.Render(text) +} + +// RenderTopAlignedBorderedBox renders text in a top-left aligned bordered box with specified dimensions +func (p *Provider) RenderTopAlignedBorderedBox(text, borderColor string, width, height, paddingV, paddingH int) string { + style := lipgloss.NewStyle(). + Width(width). + Height(height). + Align(lipgloss.Left, lipgloss.Top). + Border(lipgloss.RoundedBorder(), true). + BorderForeground(lipgloss.Color(borderColor)). + Padding(paddingV, paddingH) + return style.Render(text) +} + +// GetSpinnerStyle returns a lipgloss.Style for use with third-party components like Bubbles spinner +// This is an exception to complete abstraction, needed for library compatibility +func (p *Provider) GetSpinnerStyle() lipgloss.Style { + theme := p.themeService.GetCurrentTheme() + return lipgloss.NewStyle().Foreground(lipgloss.Color(theme.GetStatusColor())) +} diff --git a/internal/ui/themes.go b/internal/ui/themes.go index 4f39e962..d1e79569 100644 --- a/internal/ui/themes.go +++ b/internal/ui/themes.go @@ -91,15 +91,16 @@ func NewTokyoNightTheme() *TokyoNightTheme { return &TokyoNightTheme{} } -func (t *TokyoNightTheme) GetUserColor() string { return colors.UserColor.ANSI } -func (t *TokyoNightTheme) GetAssistantColor() string { return colors.AssistantColor.ANSI } -func (t *TokyoNightTheme) GetErrorColor() string { return colors.ErrorColor.ANSI } -func (t *TokyoNightTheme) GetStatusColor() string { return colors.StatusColor.ANSI } -func (t *TokyoNightTheme) GetAccentColor() string { return colors.AccentColor.ANSI } -func (t *TokyoNightTheme) GetDimColor() string { return colors.DimColor.ANSI } -func (t *TokyoNightTheme) GetBorderColor() string { return colors.BorderColor.ANSI } -func (t *TokyoNightTheme) GetDiffAddColor() string { return colors.DiffAddColor.ANSI } -func (t *TokyoNightTheme) GetDiffRemoveColor() string { return colors.DiffRemoveColor.ANSI } +func (t *TokyoNightTheme) GetUserColor() string { return colors.UserColor.Lipgloss } +func (t *TokyoNightTheme) GetAssistantColor() string { return colors.AssistantColor.Lipgloss } +func (t *TokyoNightTheme) GetErrorColor() string { return colors.ErrorColor.Lipgloss } +func (t *TokyoNightTheme) GetSuccessColor() string { return colors.SuccessColor.Lipgloss } +func (t *TokyoNightTheme) GetStatusColor() string { return colors.StatusColor.Lipgloss } +func (t *TokyoNightTheme) GetAccentColor() string { return colors.AccentColor.Lipgloss } +func (t *TokyoNightTheme) GetDimColor() string { return colors.DimColor.Lipgloss } +func (t *TokyoNightTheme) GetBorderColor() string { return colors.BorderColor.Lipgloss } +func (t *TokyoNightTheme) GetDiffAddColor() string { return colors.DiffAddColor.Lipgloss } +func (t *TokyoNightTheme) GetDiffRemoveColor() string { return colors.DiffRemoveColor.Lipgloss } // GithubLightTheme provides a light theme similar to GitHub's interface type GithubLightTheme struct{} @@ -111,6 +112,7 @@ func NewGithubLightTheme() *GithubLightTheme { func (t *GithubLightTheme) GetUserColor() string { return colors.GithubUserColor.Lipgloss } func (t *GithubLightTheme) GetAssistantColor() string { return colors.GithubAssistantColor.Lipgloss } func (t *GithubLightTheme) GetErrorColor() string { return colors.GithubErrorColor.Lipgloss } +func (t *GithubLightTheme) GetSuccessColor() string { return colors.GithubSuccessColor.Lipgloss } func (t *GithubLightTheme) GetStatusColor() string { return colors.GithubStatusColor.Lipgloss } func (t *GithubLightTheme) GetAccentColor() string { return colors.GithubAccentColor.Lipgloss } func (t *GithubLightTheme) GetDimColor() string { return colors.GithubDimColor.Lipgloss } @@ -128,6 +130,7 @@ func NewDraculaTheme() *DraculaTheme { func (t *DraculaTheme) GetUserColor() string { return colors.DraculaUserColor.Lipgloss } func (t *DraculaTheme) GetAssistantColor() string { return colors.DraculaAssistantColor.Lipgloss } func (t *DraculaTheme) GetErrorColor() string { return colors.DraculaErrorColor.Lipgloss } +func (t *DraculaTheme) GetSuccessColor() string { return colors.DraculaSuccessColor.Lipgloss } func (t *DraculaTheme) GetStatusColor() string { return colors.DraculaStatusColor.Lipgloss } func (t *DraculaTheme) GetAccentColor() string { return colors.DraculaAccentColor.Lipgloss } func (t *DraculaTheme) GetDimColor() string { return colors.DraculaDimColor.Lipgloss } diff --git a/tests/mocks/generated/fake_config_service.go b/tests/mocks/generated/fake_config_service.go index d669e854..1521f978 100644 --- a/tests/mocks/generated/fake_config_service.go +++ b/tests/mocks/generated/fake_config_service.go @@ -90,6 +90,17 @@ type FakeConfigService struct { isApprovalRequiredReturnsOnCall map[int]struct { result1 bool } + IsBashCommandWhitelistedStub func(string) bool + isBashCommandWhitelistedMutex sync.RWMutex + isBashCommandWhitelistedArgsForCall []struct { + arg1 string + } + isBashCommandWhitelistedReturns struct { + result1 bool + } + isBashCommandWhitelistedReturnsOnCall map[int]struct { + result1 bool + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -526,6 +537,67 @@ func (fake *FakeConfigService) IsApprovalRequiredReturnsOnCall(i int, result1 bo }{result1} } +func (fake *FakeConfigService) IsBashCommandWhitelisted(arg1 string) bool { + fake.isBashCommandWhitelistedMutex.Lock() + ret, specificReturn := fake.isBashCommandWhitelistedReturnsOnCall[len(fake.isBashCommandWhitelistedArgsForCall)] + fake.isBashCommandWhitelistedArgsForCall = append(fake.isBashCommandWhitelistedArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.IsBashCommandWhitelistedStub + fakeReturns := fake.isBashCommandWhitelistedReturns + fake.recordInvocation("IsBashCommandWhitelisted", []interface{}{arg1}) + fake.isBashCommandWhitelistedMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeConfigService) IsBashCommandWhitelistedCallCount() int { + fake.isBashCommandWhitelistedMutex.RLock() + defer fake.isBashCommandWhitelistedMutex.RUnlock() + return len(fake.isBashCommandWhitelistedArgsForCall) +} + +func (fake *FakeConfigService) IsBashCommandWhitelistedCalls(stub func(string) bool) { + fake.isBashCommandWhitelistedMutex.Lock() + defer fake.isBashCommandWhitelistedMutex.Unlock() + fake.IsBashCommandWhitelistedStub = stub +} + +func (fake *FakeConfigService) IsBashCommandWhitelistedArgsForCall(i int) string { + fake.isBashCommandWhitelistedMutex.RLock() + defer fake.isBashCommandWhitelistedMutex.RUnlock() + argsForCall := fake.isBashCommandWhitelistedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeConfigService) IsBashCommandWhitelistedReturns(result1 bool) { + fake.isBashCommandWhitelistedMutex.Lock() + defer fake.isBashCommandWhitelistedMutex.Unlock() + fake.IsBashCommandWhitelistedStub = nil + fake.isBashCommandWhitelistedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeConfigService) IsBashCommandWhitelistedReturnsOnCall(i int, result1 bool) { + fake.isBashCommandWhitelistedMutex.Lock() + defer fake.isBashCommandWhitelistedMutex.Unlock() + fake.IsBashCommandWhitelistedStub = nil + if fake.isBashCommandWhitelistedReturnsOnCall == nil { + fake.isBashCommandWhitelistedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isBashCommandWhitelistedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + func (fake *FakeConfigService) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() diff --git a/tests/mocks/generated/fake_state_manager.go b/tests/mocks/generated/fake_state_manager.go index c5effe32..b99fff70 100644 --- a/tests/mocks/generated/fake_state_manager.go +++ b/tests/mocks/generated/fake_state_manager.go @@ -15,6 +15,10 @@ type FakeStateManager struct { arg1 domain.Message arg2 string } + ClearApprovalUIStateStub func() + clearApprovalUIStateMutex sync.RWMutex + clearApprovalUIStateArgsForCall []struct { + } ClearFileSelectionStateStub func() clearFileSelectionStateMutex sync.RWMutex clearFileSelectionStateArgsForCall []struct { @@ -53,6 +57,16 @@ type FakeStateManager struct { failCurrentToolReturnsOnCall map[int]struct { result1 error } + GetApprovalUIStateStub func() *domain.ApprovalUIState + getApprovalUIStateMutex sync.RWMutex + getApprovalUIStateArgsForCall []struct { + } + getApprovalUIStateReturns struct { + result1 *domain.ApprovalUIState + } + getApprovalUIStateReturnsOnCall map[int]struct { + result1 *domain.ApprovalUIState + } GetChatSessionStub func() *domain.ChatSession getChatSessionMutex sync.RWMutex getChatSessionArgsForCall []struct { @@ -135,6 +149,11 @@ type FakeStateManager struct { popQueuedMessageReturnsOnCall map[int]struct { result1 *domain.QueuedMessage } + SetApprovalSelectedIndexStub func(int) + setApprovalSelectedIndexMutex sync.RWMutex + setApprovalSelectedIndexArgsForCall []struct { + arg1 int + } SetDimensionsStub func(int, int) setDimensionsMutex sync.RWMutex setDimensionsArgsForCall []struct { @@ -146,6 +165,12 @@ type FakeStateManager struct { setFileSelectedIndexArgsForCall []struct { arg1 int } + SetupApprovalUIStateStub func(*sdk.ChatCompletionMessageToolCall, chan domain.ApprovalAction) + setupApprovalUIStateMutex sync.RWMutex + setupApprovalUIStateArgsForCall []struct { + arg1 *sdk.ChatCompletionMessageToolCall + arg2 chan domain.ApprovalAction + } SetupFileSelectionStub func([]string) setupFileSelectionMutex sync.RWMutex setupFileSelectionArgsForCall []struct { @@ -239,6 +264,30 @@ func (fake *FakeStateManager) AddQueuedMessageArgsForCall(i int) (domain.Message return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeStateManager) ClearApprovalUIState() { + fake.clearApprovalUIStateMutex.Lock() + fake.clearApprovalUIStateArgsForCall = append(fake.clearApprovalUIStateArgsForCall, struct { + }{}) + stub := fake.ClearApprovalUIStateStub + fake.recordInvocation("ClearApprovalUIState", []interface{}{}) + fake.clearApprovalUIStateMutex.Unlock() + if stub != nil { + fake.ClearApprovalUIStateStub() + } +} + +func (fake *FakeStateManager) ClearApprovalUIStateCallCount() int { + fake.clearApprovalUIStateMutex.RLock() + defer fake.clearApprovalUIStateMutex.RUnlock() + return len(fake.clearApprovalUIStateArgsForCall) +} + +func (fake *FakeStateManager) ClearApprovalUIStateCalls(stub func()) { + fake.clearApprovalUIStateMutex.Lock() + defer fake.clearApprovalUIStateMutex.Unlock() + fake.ClearApprovalUIStateStub = stub +} + func (fake *FakeStateManager) ClearFileSelectionState() { fake.clearFileSelectionStateMutex.Lock() fake.clearFileSelectionStateArgsForCall = append(fake.clearFileSelectionStateArgsForCall, struct { @@ -457,6 +506,59 @@ func (fake *FakeStateManager) FailCurrentToolReturnsOnCall(i int, result1 error) }{result1} } +func (fake *FakeStateManager) GetApprovalUIState() *domain.ApprovalUIState { + fake.getApprovalUIStateMutex.Lock() + ret, specificReturn := fake.getApprovalUIStateReturnsOnCall[len(fake.getApprovalUIStateArgsForCall)] + fake.getApprovalUIStateArgsForCall = append(fake.getApprovalUIStateArgsForCall, struct { + }{}) + stub := fake.GetApprovalUIStateStub + fakeReturns := fake.getApprovalUIStateReturns + fake.recordInvocation("GetApprovalUIState", []interface{}{}) + fake.getApprovalUIStateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStateManager) GetApprovalUIStateCallCount() int { + fake.getApprovalUIStateMutex.RLock() + defer fake.getApprovalUIStateMutex.RUnlock() + return len(fake.getApprovalUIStateArgsForCall) +} + +func (fake *FakeStateManager) GetApprovalUIStateCalls(stub func() *domain.ApprovalUIState) { + fake.getApprovalUIStateMutex.Lock() + defer fake.getApprovalUIStateMutex.Unlock() + fake.GetApprovalUIStateStub = stub +} + +func (fake *FakeStateManager) GetApprovalUIStateReturns(result1 *domain.ApprovalUIState) { + fake.getApprovalUIStateMutex.Lock() + defer fake.getApprovalUIStateMutex.Unlock() + fake.GetApprovalUIStateStub = nil + fake.getApprovalUIStateReturns = struct { + result1 *domain.ApprovalUIState + }{result1} +} + +func (fake *FakeStateManager) GetApprovalUIStateReturnsOnCall(i int, result1 *domain.ApprovalUIState) { + fake.getApprovalUIStateMutex.Lock() + defer fake.getApprovalUIStateMutex.Unlock() + fake.GetApprovalUIStateStub = nil + if fake.getApprovalUIStateReturnsOnCall == nil { + fake.getApprovalUIStateReturnsOnCall = make(map[int]struct { + result1 *domain.ApprovalUIState + }) + } + fake.getApprovalUIStateReturnsOnCall[i] = struct { + result1 *domain.ApprovalUIState + }{result1} +} + func (fake *FakeStateManager) GetChatSession() *domain.ChatSession { fake.getChatSessionMutex.Lock() ret, specificReturn := fake.getChatSessionReturnsOnCall[len(fake.getChatSessionArgsForCall)] @@ -884,6 +986,38 @@ func (fake *FakeStateManager) PopQueuedMessageReturnsOnCall(i int, result1 *doma }{result1} } +func (fake *FakeStateManager) SetApprovalSelectedIndex(arg1 int) { + fake.setApprovalSelectedIndexMutex.Lock() + fake.setApprovalSelectedIndexArgsForCall = append(fake.setApprovalSelectedIndexArgsForCall, struct { + arg1 int + }{arg1}) + stub := fake.SetApprovalSelectedIndexStub + fake.recordInvocation("SetApprovalSelectedIndex", []interface{}{arg1}) + fake.setApprovalSelectedIndexMutex.Unlock() + if stub != nil { + fake.SetApprovalSelectedIndexStub(arg1) + } +} + +func (fake *FakeStateManager) SetApprovalSelectedIndexCallCount() int { + fake.setApprovalSelectedIndexMutex.RLock() + defer fake.setApprovalSelectedIndexMutex.RUnlock() + return len(fake.setApprovalSelectedIndexArgsForCall) +} + +func (fake *FakeStateManager) SetApprovalSelectedIndexCalls(stub func(int)) { + fake.setApprovalSelectedIndexMutex.Lock() + defer fake.setApprovalSelectedIndexMutex.Unlock() + fake.SetApprovalSelectedIndexStub = stub +} + +func (fake *FakeStateManager) SetApprovalSelectedIndexArgsForCall(i int) int { + fake.setApprovalSelectedIndexMutex.RLock() + defer fake.setApprovalSelectedIndexMutex.RUnlock() + argsForCall := fake.setApprovalSelectedIndexArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeStateManager) SetDimensions(arg1 int, arg2 int) { fake.setDimensionsMutex.Lock() fake.setDimensionsArgsForCall = append(fake.setDimensionsArgsForCall, struct { @@ -949,6 +1083,39 @@ func (fake *FakeStateManager) SetFileSelectedIndexArgsForCall(i int) int { return argsForCall.arg1 } +func (fake *FakeStateManager) SetupApprovalUIState(arg1 *sdk.ChatCompletionMessageToolCall, arg2 chan domain.ApprovalAction) { + fake.setupApprovalUIStateMutex.Lock() + fake.setupApprovalUIStateArgsForCall = append(fake.setupApprovalUIStateArgsForCall, struct { + arg1 *sdk.ChatCompletionMessageToolCall + arg2 chan domain.ApprovalAction + }{arg1, arg2}) + stub := fake.SetupApprovalUIStateStub + fake.recordInvocation("SetupApprovalUIState", []interface{}{arg1, arg2}) + fake.setupApprovalUIStateMutex.Unlock() + if stub != nil { + fake.SetupApprovalUIStateStub(arg1, arg2) + } +} + +func (fake *FakeStateManager) SetupApprovalUIStateCallCount() int { + fake.setupApprovalUIStateMutex.RLock() + defer fake.setupApprovalUIStateMutex.RUnlock() + return len(fake.setupApprovalUIStateArgsForCall) +} + +func (fake *FakeStateManager) SetupApprovalUIStateCalls(stub func(*sdk.ChatCompletionMessageToolCall, chan domain.ApprovalAction)) { + fake.setupApprovalUIStateMutex.Lock() + defer fake.setupApprovalUIStateMutex.Unlock() + fake.SetupApprovalUIStateStub = stub +} + +func (fake *FakeStateManager) SetupApprovalUIStateArgsForCall(i int) (*sdk.ChatCompletionMessageToolCall, chan domain.ApprovalAction) { + fake.setupApprovalUIStateMutex.RLock() + defer fake.setupApprovalUIStateMutex.RUnlock() + argsForCall := fake.setupApprovalUIStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + func (fake *FakeStateManager) SetupFileSelection(arg1 []string) { var arg1Copy []string if arg1 != nil { diff --git a/tests/mocks/generated/fake_theme.go b/tests/mocks/generated/fake_theme.go index 156ada9a..aa32d467 100644 --- a/tests/mocks/generated/fake_theme.go +++ b/tests/mocks/generated/fake_theme.go @@ -88,6 +88,16 @@ type FakeTheme struct { getStatusColorReturnsOnCall map[int]struct { result1 string } + GetSuccessColorStub func() string + getSuccessColorMutex sync.RWMutex + getSuccessColorArgsForCall []struct { + } + getSuccessColorReturns struct { + result1 string + } + getSuccessColorReturnsOnCall map[int]struct { + result1 string + } GetUserColorStub func() string getUserColorMutex sync.RWMutex getUserColorArgsForCall []struct { @@ -526,6 +536,59 @@ func (fake *FakeTheme) GetStatusColorReturnsOnCall(i int, result1 string) { }{result1} } +func (fake *FakeTheme) GetSuccessColor() string { + fake.getSuccessColorMutex.Lock() + ret, specificReturn := fake.getSuccessColorReturnsOnCall[len(fake.getSuccessColorArgsForCall)] + fake.getSuccessColorArgsForCall = append(fake.getSuccessColorArgsForCall, struct { + }{}) + stub := fake.GetSuccessColorStub + fakeReturns := fake.getSuccessColorReturns + fake.recordInvocation("GetSuccessColor", []interface{}{}) + fake.getSuccessColorMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeTheme) GetSuccessColorCallCount() int { + fake.getSuccessColorMutex.RLock() + defer fake.getSuccessColorMutex.RUnlock() + return len(fake.getSuccessColorArgsForCall) +} + +func (fake *FakeTheme) GetSuccessColorCalls(stub func() string) { + fake.getSuccessColorMutex.Lock() + defer fake.getSuccessColorMutex.Unlock() + fake.GetSuccessColorStub = stub +} + +func (fake *FakeTheme) GetSuccessColorReturns(result1 string) { + fake.getSuccessColorMutex.Lock() + defer fake.getSuccessColorMutex.Unlock() + fake.GetSuccessColorStub = nil + fake.getSuccessColorReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeTheme) GetSuccessColorReturnsOnCall(i int, result1 string) { + fake.getSuccessColorMutex.Lock() + defer fake.getSuccessColorMutex.Unlock() + fake.GetSuccessColorStub = nil + if fake.getSuccessColorReturnsOnCall == nil { + fake.getSuccessColorReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getSuccessColorReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakeTheme) GetUserColor() string { fake.getUserColorMutex.Lock() ret, specificReturn := fake.getUserColorReturnsOnCall[len(fake.getUserColorArgsForCall)]