diff --git a/.gitignore b/.gitignore index ce3f2ae..4616b30 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ examples/snowflake_trainer/output/ .dsgo **/.agents +examples/yaml_program/yaml_program diff --git a/dsgo.go b/dsgo.go index a1d6ab7..aac4cd2 100644 --- a/dsgo.go +++ b/dsgo.go @@ -131,6 +131,14 @@ type ( MCPSSETransport = mcp.SSETransport // MCPStdioTransport implements Transport over stdio of a subprocess. MCPStdioTransport = mcp.StdioTransport + // MCPLocalTransport implements Transport over an in-process handler. + MCPLocalTransport = mcp.LocalTransport + // MCPLocalHandler routes JSON-RPC requests in-process. + MCPLocalHandler = mcp.LocalHandler + // MCPShellServer is a built-in MCP server exposing shell tools. + MCPShellServer = mcp.ShellServer + // MCPShellServerConfig configures the built-in shell server. + MCPShellServerConfig = mcp.ShellServerConfig ) // Re-export typed generic type @@ -241,10 +249,18 @@ var ( NewMCPFilesystemClient = mcp.NewFilesystemClient // NewMCPHTTPTransport creates a new HTTP transport for MCP communication. NewMCPHTTPTransport = mcp.NewHTTPTransport + // NewMCPHTTPTransportWithTimeout creates a new HTTP transport with a custom timeout. + NewMCPHTTPTransportWithTimeout = mcp.NewHTTPTransportWithTimeout // NewMCPSSETransport creates a new SSE transport for MCP communication. NewMCPSSETransport = mcp.NewSSETransport + // NewMCPSSETransportWithTimeouts creates a new SSE transport with custom timeouts. + NewMCPSSETransportWithTimeouts = mcp.NewSSETransportWithTimeouts // NewMCPStdioTransport creates a new stdio transport for MCP communication. NewMCPStdioTransport = mcp.NewStdioTransport + // NewMCPLocalTransport creates a local MCP transport. + NewMCPLocalTransport = mcp.NewLocalTransport + // NewMCPShellServer creates a built-in shell MCP server. + NewMCPShellServer = mcp.NewShellServer // ConvertMCPToolsToDSGo converts MCP tool schemas to DSGo tools. 
ConvertMCPToolsToDSGo = mcp.ConvertMCPToolsToDSGo // NewMCPError creates a new MCP error with the given code and message. diff --git a/examples/README.md b/examples/README.md index e9a0e5f..7080f6a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,6 +14,7 @@ Runnable examples for DSGo. | `package_analysis/` | Analyze a Go package | `cd examples/package_analysis && go run main.go` | Signatures, parsing | | `security_scan/` | Security scanning workflow | `cd examples/security_scan && go run main.go` | Program patterns | | `snowflake_trainer/` | Trainer experiment | `cd examples/snowflake_trainer && go run main.go` | Trainers | +| `yaml_program/` | Declarative pipeline builder | `cd examples/yaml_program && go run main.go` | YAML config, MCP | Notes: - Use provider-prefixed model IDs (e.g. `openai/gpt-4o-mini`, `openrouter/anthropic/claude-3-opus`). diff --git a/examples/yaml_program/README.md b/examples/yaml_program/README.md new file mode 100644 index 0000000..e731eb0 --- /dev/null +++ b/examples/yaml_program/README.md @@ -0,0 +1,291 @@ +# YAML Pipeline Builder + +Demonstrates how to define DSGo pipelines declaratively using YAML configuration files. 
+ +## Overview + +This example shows how to: +- Define signatures (inputs/outputs) in YAML +- Configure modules (Predict, ChainOfThought, ReAct) declaratively +- Define MCP clients for external tool access (Exa, Jina, Tavily, Filesystem) +- Configure custom function/filesystem tools separately from MCP +- Specify MCP tools per-module with tool filtering +- Compose multi-stage pipelines without writing Go code for the pipeline structure +- Execute pipelines with automatic data flow between stages + +## Files + +| File | Purpose | +|------|---------| +| `pipeline.yaml` | Basic YAML pipeline definition (Predict, ChainOfThought) | +| `pipeline_react.yaml` | ReAct pipeline with filesystem MCP tools | +| `pipeline_functions.yaml` | ReAct with native function tools + filesystem MCP | +| `pipeline_mcp.yaml` | ReAct pipeline with MCP web search tools (requires API key) | +| `pipeline_deep_researcher.yaml` | Multi-stage deep researcher (requires API key) | +| `pipeline_deep_review.yaml` | 2-iteration implement/test/review loop (requires TAVILY_API_KEY) | +| `pipeline_todo.yaml` | Task -> web research -> codebase analysis -> numbered TODO list (requires TAVILY_API_KEY) | +| `builder.go` | YAML parsing and validation | +| `converter.go` | YAML to DSGo Signature conversion | +| `factory.go` | Module creation from specs | +| `tools.go` | MCP client, function, and tool registries | +| `program.go` | Program composition | +| `main.go` | Entry point and execution | + +## YAML Schema + +### Basic Pipeline + +```yaml +name: pipeline_name +description: Pipeline description + +settings: + temperature: 0.7 + max_tokens: 1024 + +signatures: + signature_name: + description: What this signature does + inputs: + - name: field_name + type: string|int|float|bool|json|class|image|datetime + description: Field description + optional: false + outputs: + - name: field_name + type: class + description: Field description + classes: [option1, option2] + +modules: + module_name: + type: 
Predict|ChainOfThought + signature: signature_name + options: + temperature: 0.3 + max_tokens: 512 + +pipeline: + - module: module_name_1 + - module: module_name_2 +``` + +### Custom Tools (Function and Filesystem) + +Custom tools are defined in the top-level `tools:` section and referenced by modules: + +```yaml +# Custom tools (function/filesystem types only) +tools: + get_datetime: + type: function + name: current_datetime + + calculator: + type: function + name: calculate + + list_files: + type: filesystem + name: list_files + +modules: + assistant: + type: ReAct + signature: assistant + options: + max_iterations: 6 + tools: # Reference custom tools here + - get_datetime + - calculator + - list_files +``` + +### MCP Tools (Per-Module Configuration) + +MCP clients are defined globally, but tool selection is configured per-module: + +```yaml +# Define MCP clients (global registry) +mcp: + tavily: + type: tavily + # api_key: optional, defaults to env var (TAVILY_API_KEY) + filesystem: + type: filesystem + +modules: + researcher: + type: ReAct + signature: research + options: + temperature: 0.5 + max_iterations: 10 + mcp: + tavily: + tools: + - "*" # Use all tools from tavily + filesystem: + tools: + - read_file + - list_directory + + fact_checker: + type: ReAct + signature: fact_check + options: + max_iterations: 6 + mcp: + tavily: + tools: + - tavily_search # Use only specific tool +``` + +### Combining Custom Tools and MCP Tools + +Modules can use both custom tools and MCP tools: + +```yaml +mcp: + filesystem: + type: filesystem + +tools: + get_datetime: + type: function + name: current_datetime + calculator: + type: function + name: calculate + +modules: + assistant: + type: ReAct + signature: assistant + options: + max_iterations: 8 + tools: # Custom tools + - get_datetime + - calculator + mcp: # MCP tools + filesystem: + tools: + - "*" +``` + +## Module Types + +| Type | Description | Options | +|------|-------------|---------| +| `Predict` | Basic prediction 
| `temperature`, `max_tokens` | +| `ChainOfThought` | Step-by-step reasoning | `temperature`, `max_tokens` | +| `ReAct` | Reasoning + Acting with tools | `temperature`, `max_tokens`, `max_iterations`, `tools` (custom), `mcp` (per-module) | + +## Tool Types (Custom) + +| Type | Description | Fields | +|------|-------------|--------| +| `filesystem` | Built-in file tools | `name`: `list_files`, `read_file`, `search_files` | +| `function` | Native Go function tools | `name`: function name (see below) | + +## Function Tools + +Native Go function tools that work without any API keys: + +| Name | Description | +|------|-------------| +| `current_datetime` | Get current date/time with timezone support | +| `calculate` | Basic arithmetic (add, subtract, multiply, divide) | +| `random_number` | Generate random numbers in a range | +| `string_length` | Get string length and character counts | +| `word_count` | Count words and analyze text statistics | +| `environment_info` | Get environment information | + +## MCP Client Types + +| Type | Env Variable | Description | +|------|--------------|-------------| +| `exa` | `EXA_API_KEY` | Exa search and web content | +| `jina` | `JINA_API_KEY` | Jina URL reading and extraction | +| `tavily` | `TAVILY_API_KEY` | Tavily web search and extraction | +| `filesystem` | - | Local filesystem operations | +| `shell` | - | Shell command execution (apply_patch, shell_run) | +| `custom` | - | Custom MCP server (requires `url` field) | + +## Run + +```bash +cd examples/yaml_program + +# Basic pipeline (Predict, ChainOfThought) +go run . + +# ReAct with filesystem tools +go run . pipeline_react.yaml + +# ReAct with native function tools (no API keys needed!) +go run . pipeline_functions.yaml + +# ReAct with MCP web search (requires TAVILY_API_KEY) +TAVILY_API_KEY=your-key go run . pipeline_mcp.yaml + +# Deep researcher pipeline (requires TAVILY_API_KEY) +TAVILY_API_KEY=your-key go run . 
pipeline_deep_researcher.yaml + +# Deep review pipeline (requires TAVILY_API_KEY) +TAVILY_API_KEY=your-key go run . pipeline_deep_review.yaml + +# Or with custom YAML file: +go run . path/to/custom.yaml +``` + +## Pipeline Flow + +### Basic Pipeline +``` +┌─────────────────────┐ +│ Input: text │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ sentiment_analyzer │ Predict +│ → sentiment │ +│ → confidence │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ key_points_extractor│ ChainOfThought +│ → key_points │ +│ → word_count │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ summary_generator │ ChainOfThought +│ → summary │ +│ → recommendations │ +└─────────────────────┘ +``` + +### ReAct Pipeline with MCP +``` +┌─────────────────────┐ +│ Input: question │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ code_explorer │ ReAct +│ ┌─────────────┐ │ +│ │ MCP: │ │ +│ │ filesystem │ │ ← MCP tools (per-module) +│ │ - "*" │ │ (all filesystem tools) +│ └─────────────┘ │ +│ → answer │ +│ → files_examined │ +└─────────────────────┘ +``` + +Each module receives all outputs from previous modules plus the original inputs. diff --git a/examples/yaml_program/builder.go b/examples/yaml_program/builder.go new file mode 100644 index 0000000..80c5047 --- /dev/null +++ b/examples/yaml_program/builder.go @@ -0,0 +1,437 @@ +package main + +import ( + "fmt" + "os" + "strconv" + "time" + + "gopkg.in/yaml.v3" +) + +// Duration is a YAML-friendly duration type. +// +// It accepts: +// - Go duration strings (e.g., "30s", "5m", "2h") +// - Integers interpreted as seconds +// +// Zero values mean "not set". 
+type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalYAML(value *yaml.Node) error { + if value == nil { + return nil + } + if value.Kind != yaml.ScalarNode { + return fmt.Errorf("duration must be a scalar value") + } + if value.Value == "" { + d.Duration = 0 + return nil + } + + // Allow integer values as seconds for convenience. + if value.Tag == "!!int" { + sec, err := strconv.ParseInt(value.Value, 10, 64) + if err != nil { + return fmt.Errorf("invalid duration seconds %q: %w", value.Value, err) + } + if sec < 0 { + return fmt.Errorf("duration must be >= 0") + } + d.Duration = time.Duration(sec) * time.Second + return nil + } + + dur, err := time.ParseDuration(value.Value) + if err != nil { + return fmt.Errorf("invalid duration %q: %w", value.Value, err) + } + if dur < 0 { + return fmt.Errorf("duration must be >= 0") + } + d.Duration = dur + return nil +} + +// TimeoutSettings configures timeouts for the YAML runner. +type TimeoutSettings struct { + // Pipeline is the overall runtime timeout for the whole pipeline. + Pipeline Duration `yaml:"pipeline,omitempty"` + // LMHTTP controls provider HTTP client timeouts (openai/openrouter). + LMHTTP Duration `yaml:"lm_http,omitempty"` + // MCPHTTP controls MCP HTTP transport request timeouts. + MCPHTTP Duration `yaml:"mcp_http,omitempty"` + // MCPSSEPost controls the POST-side timeout for MCP SSE transports. + MCPSSEPost Duration `yaml:"mcp_sse_post,omitempty"` + // MCPSSEWait controls how long to wait for an SSE response. + MCPSSEWait Duration `yaml:"mcp_sse_wait,omitempty"` +} + +// ModelSettings are model generation defaults applied to modules unless overridden. +type ModelSettings struct { + // Name is the default model identifier in "provider/model" form. + Name string `yaml:"name,omitempty"` + Temperature float64 `yaml:"temperature"` + MaxTokens int `yaml:"max_tokens"` +} + +// DSGoSettings are runtime settings for the DSGo pipeline runner. 
+type DSGoSettings struct { + Timeouts TimeoutSettings `yaml:"timeouts,omitempty"` +} + +// PipelineConfig represents the top-level YAML structure. +// +// Backward compatibility: +// - `settings:` is the legacy key for model defaults, and may also include `timeouts:`. +// - Prefer `model:` for model defaults and `dsgo:` for runtime settings. +type PipelineConfig struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + + // Legacy. + Settings GlobalSettings `yaml:"settings,omitempty"` + + // Preferred. + Model ModelSettings `yaml:"model,omitempty"` + DSGo DSGoSettings `yaml:"dsgo,omitempty"` + + MCP map[string]MCPSpec `yaml:"mcp,omitempty"` + Tools map[string]ToolSpec `yaml:"tools,omitempty"` + Signatures map[string]SignatureSpec `yaml:"signatures"` + Modules map[string]ModuleSpec `yaml:"modules"` + Pipeline []PipelineStep `yaml:"pipeline"` + Inputs map[string]any `yaml:"inputs,omitempty"` +} + +// GlobalSettings represents legacy global pipeline settings. +// Prefer `model:` and `dsgo:` instead. 
+type GlobalSettings struct { + Temperature float64 `yaml:"temperature"` + MaxTokens int `yaml:"max_tokens"` + Timeouts TimeoutSettings `yaml:"timeouts,omitempty"` +} + +// SignatureSpec represents a signature definition in YAML +type SignatureSpec struct { + Description string `yaml:"description"` + Inputs []FieldSpec `yaml:"inputs"` + Outputs []FieldSpec `yaml:"outputs"` +} + +// FieldSpec represents a field definition in YAML +type FieldSpec struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + Description string `yaml:"description"` + Optional bool `yaml:"optional"` + Classes []string `yaml:"classes,omitempty"` // For class/enum types +} + +// ModuleSpec represents a module definition in YAML +type ModuleSpec struct { + Type string `yaml:"type"` + Model string `yaml:"model,omitempty"` + Signature string `yaml:"signature"` + Options ModuleOptions `yaml:"options"` + MCP map[string]ModuleMCPSpec `yaml:"mcp,omitempty"` // Per-module MCP tool configuration +} + +// ModuleOptions represents module-specific options +type ModuleOptions struct { + Temperature float64 `yaml:"temperature"` + MaxTokens int `yaml:"max_tokens"` + MaxIterations int `yaml:"max_iterations,omitempty"` + Verbose bool `yaml:"verbose,omitempty"` + Tools []string `yaml:"tools,omitempty"` // References to custom tools defined in top-level tools section +} + +// ModuleMCPSpec represents per-module MCP configuration +type ModuleMCPSpec struct { + Tools []string `yaml:"tools"` // "*" for all tools, or list of specific tool names +} + +// MCPSpec represents an MCP client configuration +type MCPSpec struct { + Type string `yaml:"type"` + APIKey string `yaml:"api_key,omitempty"` + URL string `yaml:"url,omitempty"` + AllowedDirs []string `yaml:"allowed_dirs,omitempty"` +} + +// ToolSpec represents a tool definition +type ToolSpec struct { + Type string `yaml:"type"` + Source string `yaml:"source,omitempty"` + Name string `yaml:"name,omitempty"` + Description string 
`yaml:"description,omitempty"` + Parameters []ToolParamSpec `yaml:"parameters,omitempty"` +} + +// ToolParamSpec represents a tool parameter definition +type ToolParamSpec struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + Description string `yaml:"description"` + Required bool `yaml:"required"` +} + +// PipelineStep represents a step in the pipeline +type PipelineStep struct { + Module string `yaml:"module"` +} + +func (c *PipelineConfig) EffectiveModelSettings() ModelSettings { + settings := ModelSettings{ + Temperature: c.Settings.Temperature, + MaxTokens: c.Settings.MaxTokens, + } + if c.Model.Name != "" { + settings.Name = c.Model.Name + } + if c.Model.Temperature > 0 { + settings.Temperature = c.Model.Temperature + } + if c.Model.MaxTokens > 0 { + settings.MaxTokens = c.Model.MaxTokens + } + return settings +} + +func (c *PipelineConfig) EffectiveTimeouts() TimeoutSettings { + settings := c.Settings.Timeouts + + if c.DSGo.Timeouts.Pipeline.Duration > 0 { + settings.Pipeline = c.DSGo.Timeouts.Pipeline + } + if c.DSGo.Timeouts.LMHTTP.Duration > 0 { + settings.LMHTTP = c.DSGo.Timeouts.LMHTTP + } + if c.DSGo.Timeouts.MCPHTTP.Duration > 0 { + settings.MCPHTTP = c.DSGo.Timeouts.MCPHTTP + } + if c.DSGo.Timeouts.MCPSSEPost.Duration > 0 { + settings.MCPSSEPost = c.DSGo.Timeouts.MCPSSEPost + } + if c.DSGo.Timeouts.MCPSSEWait.Duration > 0 { + settings.MCPSSEWait = c.DSGo.Timeouts.MCPSSEWait + } + + return settings +} + +// LoadPipelineConfig loads a pipeline configuration from a YAML file +func LoadPipelineConfig(filename string) (*PipelineConfig, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", filename, err) + } + + var config PipelineConfig + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + + if err := validateConfig(&config); err != nil { + return nil, fmt.Errorf("configuration validation failed: %w", err) + } + + 
return &config, nil +} + +// validateConfig validates the pipeline configuration +func validateConfig(config *PipelineConfig) error { + if config.Name == "" { + return fmt.Errorf("pipeline name is required") + } + + if len(config.Signatures) == 0 { + return fmt.Errorf("at least one signature must be defined") + } + + if len(config.Modules) == 0 { + return fmt.Errorf("at least one module must be defined") + } + + if len(config.Pipeline) == 0 { + return fmt.Errorf("pipeline must have at least one step") + } + + // Validate signatures + for name, sig := range config.Signatures { + if err := validateSignature(name, sig); err != nil { + return err + } + } + + // Validate MCP configurations + for name, mcpSpec := range config.MCP { + if err := validateMCP(name, mcpSpec); err != nil { + return err + } + } + + // Validate tool definitions (custom tools only) + for name, toolSpec := range config.Tools { + if err := validateTool(name, toolSpec); err != nil { + return err + } + } + + // Validate modules reference existing signatures and MCP clients + for name, mod := range config.Modules { + if err := validateModule(name, mod, config.Signatures, config.Tools, config.MCP); err != nil { + return err + } + } + + // Validate pipeline steps reference existing modules + for i, step := range config.Pipeline { + if _, exists := config.Modules[step.Module]; !exists { + return fmt.Errorf("pipeline step %d references undefined module: %s", i+1, step.Module) + } + } + + return nil +} + +// validateMCP validates an MCP client configuration +func validateMCP(name string, spec MCPSpec) error { + validTypes := map[string]bool{ + "exa": true, + "jina": true, + "tavily": true, + "custom": true, + "shell": true, + "filesystem": true, + } + + if !validTypes[spec.Type] { + return fmt.Errorf("MCP '%s' has invalid type: %s (valid: exa, jina, tavily, custom, shell, filesystem)", name, spec.Type) + } + + if spec.Type == "custom" && spec.URL == "" { + return fmt.Errorf("MCP '%s' is custom type but has 
no URL defined", name) + } + + return nil +} + +// validateTool validates a tool definition (custom tools only, not MCP) +func validateTool(name string, spec ToolSpec) error { + validTypes := map[string]bool{ + "filesystem": true, + "function": true, + } + + if !validTypes[spec.Type] { + return fmt.Errorf("tool '%s' has invalid type: %s (valid: filesystem, function)", name, spec.Type) + } + + if spec.Type == "function" { + validFunctions := map[string]bool{ + "current_datetime": true, + "calculate": true, + "random_number": true, + "string_length": true, + "word_count": true, + "environment_info": true, + } + funcName := spec.Name + if funcName == "" { + funcName = name + } + if !validFunctions[funcName] { + return fmt.Errorf("tool '%s' references unknown function: %s (valid: current_datetime, calculate, random_number, string_length, word_count, environment_info)", name, funcName) + } + } + + return nil +} + +// validateSignature validates a signature definition +func validateSignature(name string, sig SignatureSpec) error { + if sig.Description == "" { + return fmt.Errorf("signature '%s' must have a description", name) + } + + if len(sig.Inputs) == 0 { + return fmt.Errorf("signature '%s' must have at least one input", name) + } + + if len(sig.Outputs) == 0 { + return fmt.Errorf("signature '%s' must have at least one output", name) + } + + // Validate field types + validTypes := map[string]bool{ + "string": true, + "int": true, + "float": true, + "bool": true, + "json": true, + "class": true, + "image": true, + "datetime": true, + } + + for _, field := range sig.Inputs { + if !validTypes[field.Type] { + return fmt.Errorf("signature '%s' input '%s' has invalid type: %s", name, field.Name, field.Type) + } + } + + for _, field := range sig.Outputs { + if !validTypes[field.Type] { + return fmt.Errorf("signature '%s' output '%s' has invalid type: %s", name, field.Name, field.Type) + } + if field.Type == "class" && len(field.Classes) == 0 { + return fmt.Errorf("signature 
'%s' output '%s' is class type but has no classes defined", name, field.Name) + } + } + + return nil +} + +// validateModule validates a module definition +func validateModule(name string, mod ModuleSpec, signatures map[string]SignatureSpec, tools map[string]ToolSpec, mcpClients map[string]MCPSpec) error { + validTypes := map[string]bool{ + "Predict": true, + "ChainOfThought": true, + "ReAct": true, + } + + if !validTypes[mod.Type] { + return fmt.Errorf("module '%s' has invalid type: %s (valid: Predict, ChainOfThought, ReAct)", name, mod.Type) + } + + if _, exists := signatures[mod.Signature]; !exists { + return fmt.Errorf("module '%s' references undefined signature: %s", name, mod.Signature) + } + + // For ReAct, require either tools or mcp configuration + hasMCP := len(mod.MCP) > 0 + hasTools := len(mod.Options.Tools) > 0 + if mod.Type == "ReAct" && !hasMCP && !hasTools { + return fmt.Errorf("module '%s' is ReAct type but has no tools or mcp defined", name) + } + + // Validate custom tool references + for _, toolRef := range mod.Options.Tools { + if _, exists := tools[toolRef]; !exists { + return fmt.Errorf("module '%s' references undefined tool: %s", name, toolRef) + } + } + + // Validate per-module MCP references + for mcpName := range mod.MCP { + if _, exists := mcpClients[mcpName]; !exists { + return fmt.Errorf("module '%s' references undefined MCP client: %s", name, mcpName) + } + } + + return nil +} diff --git a/examples/yaml_program/converter.go b/examples/yaml_program/converter.go new file mode 100644 index 0000000..73ce2d3 --- /dev/null +++ b/examples/yaml_program/converter.go @@ -0,0 +1,99 @@ +package main + +import ( + "fmt" + + "github.com/assagman/dsgo" +) + +// ConvertSignature converts a YAML SignatureSpec to a DSGo Signature +func ConvertSignature(name string, spec SignatureSpec) (*dsgo.Signature, error) { + sig := dsgo.NewSignature(spec.Description) + + // Add input fields + for _, field := range spec.Inputs { + fieldType, err := 
parseFieldType(field.Type) + if err != nil { + return nil, fmt.Errorf("input field '%s': %w", field.Name, err) + } + + if field.Optional { + sig.AddOptionalInput(field.Name, fieldType, field.Description) + } else { + sig.AddInput(field.Name, fieldType, field.Description) + } + } + + // Add output fields + for _, field := range spec.Outputs { + fieldType, err := parseFieldType(field.Type) + if err != nil { + return nil, fmt.Errorf("output field '%s': %w", field.Name, err) + } + + if field.Type == "class" && len(field.Classes) > 0 { + sig.AddClassOutput(field.Name, field.Classes, field.Description) + } else if field.Optional { + sig.AddOptionalOutput(field.Name, fieldType, field.Description) + } else { + sig.AddOutput(field.Name, fieldType, field.Description) + } + } + + return sig, nil +} + +// parseFieldType converts a string type to DSGo FieldType +func parseFieldType(typeStr string) (dsgo.FieldType, error) { + switch typeStr { + case "string": + return dsgo.FieldTypeString, nil + case "int": + return dsgo.FieldTypeInt, nil + case "float": + return dsgo.FieldTypeFloat, nil + case "bool": + return dsgo.FieldTypeBool, nil + case "json": + return dsgo.FieldTypeJSON, nil + case "class": + return dsgo.FieldTypeClass, nil + case "image": + return dsgo.FieldTypeImage, nil + case "datetime": + return dsgo.FieldTypeDatetime, nil + default: + return "", fmt.Errorf("unknown field type: %s", typeStr) + } +} + +// SignatureRegistry holds converted signatures +type SignatureRegistry struct { + signatures map[string]*dsgo.Signature +} + +// NewSignatureRegistry creates a new registry from YAML specs +func NewSignatureRegistry(specs map[string]SignatureSpec) (*SignatureRegistry, error) { + registry := &SignatureRegistry{ + signatures: make(map[string]*dsgo.Signature), + } + + for name, spec := range specs { + sig, err := ConvertSignature(name, spec) + if err != nil { + return nil, fmt.Errorf("failed to convert signature '%s': %w", name, err) + } + registry.signatures[name] = sig 
+ } + + return registry, nil +} + +// Get returns a signature by name +func (r *SignatureRegistry) Get(name string) (*dsgo.Signature, error) { + sig, exists := r.signatures[name] + if !exists { + return nil, fmt.Errorf("signature not found: %s", name) + } + return sig, nil +} diff --git a/examples/yaml_program/factory.go b/examples/yaml_program/factory.go new file mode 100644 index 0000000..8103f8a --- /dev/null +++ b/examples/yaml_program/factory.go @@ -0,0 +1,192 @@ +package main + +import ( + "context" + "fmt" + + "github.com/assagman/dsgo" +) + +// ModuleFactory creates DSGo modules from YAML specifications. +// +// It supports per-module model overrides by creating (and caching) LM instances +// per requested model. +type ModuleFactory struct { + ctx context.Context + defaultModel string + defaultLM dsgo.LM + lmsByModel map[string]dsgo.LM + + sigRegistry *SignatureRegistry + toolRegistry *ToolRegistry + modelSettings ModelSettings +} + +// NewModuleFactory creates a new module factory. 
+func NewModuleFactory(ctx context.Context, defaultModel string, defaultLM dsgo.LM, sigRegistry *SignatureRegistry, toolRegistry *ToolRegistry, settings ModelSettings) *ModuleFactory { + f := &ModuleFactory{ + ctx: ctx, + defaultModel: defaultModel, + defaultLM: defaultLM, + lmsByModel: make(map[string]dsgo.LM), + sigRegistry: sigRegistry, + toolRegistry: toolRegistry, + + modelSettings: settings, + } + + if defaultModel != "" && defaultLM != nil { + f.lmsByModel[defaultModel] = defaultLM + } + + return f +} + +func (f *ModuleFactory) getLM(model string) (dsgo.LM, error) { + if model == "" { + return f.defaultLM, nil + } + if f.defaultModel != "" && model == f.defaultModel { + return f.defaultLM, nil + } + if lm, ok := f.lmsByModel[model]; ok { + return lm, nil + } + + lm, err := dsgo.NewLM(f.ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to create LM for model %q: %w", model, err) + } + f.lmsByModel[model] = lm + return lm, nil +} + +// CreateModule creates a DSGo module from a YAML specification +func (f *ModuleFactory) CreateModule(name string, spec ModuleSpec) (dsgo.Module, error) { + lm, err := f.getLM(spec.Model) + if err != nil { + return nil, fmt.Errorf("module '%s': %w", name, err) + } + + sig, err := f.sigRegistry.Get(spec.Signature) + if err != nil { + return nil, fmt.Errorf("module '%s': %w", name, err) + } + + // Build options with defaults from global settings, overridden by module-specific + options := f.buildOptions(spec) + + switch spec.Type { + case "Predict": + return f.createPredict(sig, lm, options), nil + case "ChainOfThought": + return f.createChainOfThought(sig, lm, options), nil + case "ReAct": + return f.createReAct(sig, lm, options, spec) + default: + return nil, fmt.Errorf("unsupported module type: %s", spec.Type) + } +} + +// createPredict creates a Predict module +func (f *ModuleFactory) createPredict(sig *dsgo.Signature, lm dsgo.LM, options *dsgo.GenerateOptions) dsgo.Module { + return dsgo.NewPredict(sig, lm). 
+ WithOptions(options). + WithAdapter(dsgo.NewFallbackAdapter()) +} + +// createChainOfThought creates a ChainOfThought module +func (f *ModuleFactory) createChainOfThought(sig *dsgo.Signature, lm dsgo.LM, options *dsgo.GenerateOptions) dsgo.Module { + return dsgo.NewChainOfThought(sig, lm). + WithOptions(options). + WithAdapter(dsgo.NewFallbackAdapter().WithReasoning(true)) +} + +// createReAct creates a ReAct module with tools +func (f *ModuleFactory) createReAct(sig *dsgo.Signature, lm dsgo.LM, options *dsgo.GenerateOptions, spec ModuleSpec) (dsgo.Module, error) { + var allTools []dsgo.Tool + + // Get custom tools from options.tools + if len(spec.Options.Tools) > 0 { + customTools, err := f.toolRegistry.GetMultiple(spec.Options.Tools) + if err != nil { + return nil, fmt.Errorf("failed to resolve custom tools: %w", err) + } + allTools = append(allTools, customTools...) + } + + // Get MCP tools from per-module mcp configuration + if len(spec.MCP) > 0 { + mcpTools, err := f.toolRegistry.GetAllMCPToolsForModule(spec.MCP) + if err != nil { + return nil, fmt.Errorf("failed to resolve MCP tools: %w", err) + } + allTools = append(allTools, mcpTools...) + } + + react := dsgo.NewReAct(sig, lm, allTools). + WithOptions(options). 
+ WithAdapter(dsgo.NewFallbackAdapter()) + + if spec.Options.MaxIterations > 0 { + react.WithMaxIterations(spec.Options.MaxIterations) + } + + react.WithVerbose(spec.Options.Verbose) + + return react, nil +} + +// buildOptions builds GenerateOptions with proper defaults +func (f *ModuleFactory) buildOptions(spec ModuleSpec) *dsgo.GenerateOptions { + options := dsgo.DefaultGenerateOptions() + + // Apply model defaults first + if f.modelSettings.Temperature > 0 { + options.Temperature = f.modelSettings.Temperature + } + if f.modelSettings.MaxTokens > 0 { + options.MaxTokens = f.modelSettings.MaxTokens + } + + // Apply module-specific options (highest priority) + if spec.Options.Temperature > 0 { + options.Temperature = spec.Options.Temperature + } + if spec.Options.MaxTokens > 0 { + options.MaxTokens = spec.Options.MaxTokens + } + + return options +} + +// ModuleRegistry holds created modules +type ModuleRegistry struct { + modules map[string]dsgo.Module +} + +// NewModuleRegistry creates all modules from YAML specs +func NewModuleRegistry(factory *ModuleFactory, specs map[string]ModuleSpec) (*ModuleRegistry, error) { + registry := &ModuleRegistry{ + modules: make(map[string]dsgo.Module), + } + + for name, spec := range specs { + module, err := factory.CreateModule(name, spec) + if err != nil { + return nil, fmt.Errorf("failed to create module '%s': %w", name, err) + } + registry.modules[name] = module + } + + return registry, nil +} + +// Get returns a module by name +func (r *ModuleRegistry) Get(name string) (dsgo.Module, error) { + mod, exists := r.modules[name] + if !exists { + return nil, fmt.Errorf("module not found: %s", name) + } + return mod, nil +} diff --git a/examples/yaml_program/go.mod b/examples/yaml_program/go.mod new file mode 100644 index 0000000..087d4f7 --- /dev/null +++ b/examples/yaml_program/go.mod @@ -0,0 +1,18 @@ +module github.com/assagman/dsgo/examples/yaml_program + +go 1.25 + +require ( + github.com/assagman/dsgo v0.0.0 + gopkg.in/yaml.v3 
v3.0.1 +) + +require ( + github.com/openai/openai-go/v3 v3.13.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect +) + +replace github.com/assagman/dsgo => ../.. diff --git a/examples/yaml_program/go.sum b/examples/yaml_program/go.sum new file mode 100644 index 0000000..7710f4a --- /dev/null +++ b/examples/yaml_program/go.sum @@ -0,0 +1,16 @@ +github.com/openai/openai-go/v3 v3.13.0 h1:arSFmVHcBHNVYG5iqspPJrLoin0Qqn2JcCLWWcTcM1Q= +github.com/openai/openai-go/v3 v3.13.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/yaml_program/main.go b/examples/yaml_program/main.go new file mode 100644 index 
0000000..25b8c55 --- /dev/null +++ b/examples/yaml_program/main.go @@ -0,0 +1,185 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + "time" + + "github.com/assagman/dsgo" +) + +const defaultTimeout = 5 * time.Minute + +func getModelName(config *PipelineConfig) string { + if config.Model.Name != "" { + return config.Model.Name + } + if model := os.Getenv("EXAMPLES_DEFAULT_MODEL"); model != "" { + return model + } + return "openrouter/z-ai/glm-4.6" +} + +func main() { + ctx := context.Background() + + fmt.Println("╔═══════════════════════════════════════════════════════════════╗") + fmt.Println("║ DSGo YAML Pipeline Builder Example ║") + fmt.Println("╚═══════════════════════════════════════════════════════════════╝") + fmt.Println() + + // Load pipeline configuration + configPath := "pipeline.yaml" + if len(os.Args) > 1 { + configPath = os.Args[1] + } + + fmt.Printf("📄 Loading pipeline configuration from: %s\n", configPath) + config, err := LoadPipelineConfig(configPath) + if err != nil { + log.Fatalf("❌ Failed to load pipeline configuration: %v", err) + } + + fmt.Printf("✅ Loaded pipeline: %s\n", config.Name) + fmt.Printf(" Description: %s\n", config.Description) + fmt.Printf(" Signatures: %d\n", len(config.Signatures)) + fmt.Printf(" Modules: %d\n", len(config.Modules)) + fmt.Printf(" Pipeline steps: %d\n", len(config.Pipeline)) + fmt.Println() + + // Apply timeout overrides from YAML (if any) + timeouts := config.EffectiveTimeouts() + if timeouts.LMHTTP.Duration > 0 { + if err := os.Setenv("DSGO_HTTP_TIMEOUT_MS", fmt.Sprintf("%d", timeouts.LMHTTP.Milliseconds())); err != nil { + log.Fatalf("❌ Failed to set DSGO_HTTP_TIMEOUT_MS: %v", err) + } + } + + pipelineTimeout := defaultTimeout + if timeouts.Pipeline.Duration > 0 { + pipelineTimeout = timeouts.Pipeline.Duration + } + fmt.Printf("⏱️ Pipeline timeout: %v\n", pipelineTimeout) + fmt.Println() + + // Initialize LM + modelName := getModelName(config) + fmt.Printf("🤖 Initializing LM: 
%s\n", modelName) + + lm, err := dsgo.NewLM(ctx, modelName) + if err != nil { + log.Fatalf("❌ Failed to create LM: %v", err) + } + fmt.Println("✅ LM initialized successfully") + fmt.Println() + + // Build program from YAML + fmt.Println("🔧 Building program from YAML configuration...") + builder, err := NewProgramBuilder(ctx, config, modelName, lm) + if err != nil { + log.Fatalf("❌ Failed to create program builder: %v", err) + } + + program, err := builder.Build() + if err != nil { + log.Fatalf("❌ Failed to build program: %v", err) + } + fmt.Printf("✅ Program built successfully with %d modules\n", program.ModuleCount()) + fmt.Println() + + // Display pipeline structure + displayPipelineStructure(config) + + // Get inputs from YAML config + inputs := config.Inputs + if len(inputs) == 0 { + log.Fatalf("❌ No inputs defined in YAML configuration") + } + fmt.Println("📝 Inputs:") + fmt.Println(strings.Repeat("-", 60)) + for k, v := range inputs { + fmt.Printf(" %s: %v\n", k, v) + } + fmt.Println(strings.Repeat("-", 60)) + fmt.Println() + + // Run the pipeline + fmt.Println("🚀 Executing pipeline...") + fmt.Println() + + ctx, cancel := context.WithTimeout(ctx, pipelineTimeout) + defer cancel() + + startTime := time.Now() + prediction, err := program.Forward(ctx, inputs) + elapsed := time.Since(startTime) + + if err != nil { + log.Fatalf("❌ Pipeline execution failed: %v", err) + } + + // Display results + fmt.Println("═══════════════════════════════════════════════════════════════") + fmt.Println(" PIPELINE RESULTS ") + fmt.Println("═══════════════════════════════════════════════════════════════") + fmt.Println() + + displayResults(prediction) + + // Display execution stats + fmt.Println() + fmt.Println("📊 Execution Statistics:") + fmt.Println(strings.Repeat("-", 40)) + fmt.Printf(" Duration: %v\n", elapsed.Round(time.Millisecond)) + fmt.Printf(" Tokens used: %d\n", prediction.Usage.TotalTokens) + fmt.Printf(" Cost: $%.6f\n", prediction.Usage.Cost) + fmt.Println() + + 
fmt.Println("✅ Pipeline executed successfully!") +} + +func displayPipelineStructure(config *PipelineConfig) { + fmt.Println("📋 Pipeline Structure:") + fmt.Println(strings.Repeat("-", 60)) + + for i, step := range config.Pipeline { + mod := config.Modules[step.Module] + sig := config.Signatures[mod.Signature] + + arrow := " │" + if i == len(config.Pipeline)-1 { + arrow = " └" + } + + fmt.Printf(" %d. %s (%s)\n", i+1, step.Module, mod.Type) + fmt.Printf("%s── Signature: %s\n", arrow, mod.Signature) + fmt.Printf("%s── Inputs: %s\n", arrow, formatFields(sig.Inputs)) + fmt.Printf("%s── Outputs: %s\n", arrow, formatFields(sig.Outputs)) + if mod.Options.Temperature > 0 { + fmt.Printf("%s── Temperature: %.2f\n", arrow, mod.Options.Temperature) + } + fmt.Println() + } +} + +func formatFields(fields []FieldSpec) string { + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return strings.Join(names, ", ") +} + +func displayResults(prediction *dsgo.Prediction) { + outputs := prediction.Outputs + + for key, value := range outputs { + fmt.Printf("📌 %s:\n", key) + fmt.Println(strings.Repeat("-", 40)) + fmt.Printf("%v\n", value) + fmt.Println() + } +} diff --git a/examples/yaml_program/pipeline.yaml b/examples/yaml_program/pipeline.yaml new file mode 100644 index 0000000..be80727 --- /dev/null +++ b/examples/yaml_program/pipeline.yaml @@ -0,0 +1,107 @@ +# DSGo YAML Pipeline Definition +# This file defines a multi-stage pipeline for text processing +# demonstrating the YAML-based program builder pattern. 
+ +name: text_analysis_pipeline +description: A pipeline that analyzes text for sentiment and generates a summary + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define reusable signatures +signatures: + sentiment_analysis: + description: Analyze the sentiment of the provided text + inputs: + - name: text + type: string + description: The text to analyze for sentiment + outputs: + - name: sentiment + type: class + description: The detected sentiment + classes: [positive, negative, neutral] + - name: confidence + type: float + description: Confidence score between 0 and 1 + + key_points_extraction: + description: Extract key points from the analyzed text + inputs: + - name: text + type: string + description: The original text + - name: sentiment + type: string + description: The detected sentiment from previous analysis + outputs: + - name: key_points + type: string + description: Bullet-pointed list of key points + - name: word_count + type: int + description: Number of words in original text + + summary_generation: + description: Generate a comprehensive summary based on sentiment and key points + inputs: + - name: text + type: string + description: The original text + - name: sentiment + type: string + description: The detected sentiment + - name: key_points + type: string + description: The extracted key points + outputs: + - name: summary + type: string + description: A comprehensive summary of the text + - name: recommendations + type: string + description: Actionable recommendations based on the analysis + +# Define modules that use the signatures +modules: + sentiment_analyzer: + type: Predict + signature: sentiment_analysis + options: + temperature: 0.3 + + key_points_extractor: + type: ChainOfThought + signature: key_points_extraction + options: + temperature: 0.5 + max_tokens: 512 + + summary_generator: + type: ChainOfThought + signature: summary_generation + options: + temperature: 0.7 + max_tokens: 1024 + +# 
Define the pipeline as a sequence of modules +pipeline: + - module: sentiment_analyzer + - module: key_points_extractor + - module: summary_generator + +# Pipeline inputs +inputs: + text: | + The new AI-powered code review system has significantly improved our development workflow. + The automated suggestions are remarkably accurate, catching potential bugs before they reach production. + Team productivity has increased by approximately 30% since implementation. + + However, there are some concerns about the learning curve for junior developers who may become + overly reliant on AI suggestions without fully understanding the underlying code patterns. + Additionally, the system occasionally generates false positives, which can slow down the review process. + + Overall, the benefits outweigh the drawbacks, and we recommend continued adoption with appropriate + training and guidelines for the development team. diff --git a/examples/yaml_program/pipeline_codebase_analysis.yaml b/examples/yaml_program/pipeline_codebase_analysis.yaml new file mode 100644 index 0000000..71d47fa --- /dev/null +++ b/examples/yaml_program/pipeline_codebase_analysis.yaml @@ -0,0 +1,86 @@ +# DSGo YAML Pipeline Definition with ReAct and Tools +# This file demonstrates the ReAct module with filesystem MCP tools +# for code analysis tasks. 
+# +# Requires: Node.js (npx) or Bun (bunx) for filesystem MCP server + +name: code_analysis_pipeline +description: A pipeline that analyzes codebase using ReAct with filesystem MCP + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define MCP clients (global registry) +mcp: + filesystem: + type: filesystem + +# Define reusable signatures +signatures: + code_exploration: + description: Explore and understand a codebase structure + inputs: + - name: question + type: string + description: The question about the codebase to answer + outputs: + - name: answer + type: string + description: Detailed answer based on codebase exploration + - name: files_examined + type: string + description: List of files that were examined + + code_review: + description: Review code quality and suggest improvements + inputs: + - name: filepath + type: string + description: Path to the file to review + - name: focus_areas + type: string + description: Specific areas to focus the review on + outputs: + - name: issues + type: string + description: List of identified issues + - name: suggestions + type: string + description: Improvement suggestions + - name: quality_score + type: int + description: Overall quality score from 1-10 + +# Define modules that use the signatures +modules: + code_explorer: + type: ReAct + signature: code_exploration + options: + temperature: 0.5 + max_iterations: 8 + mcp: + filesystem: + tools: + - "*" # Use all filesystem tools + + code_reviewer: + type: ReAct + signature: code_review + options: + temperature: 0.3 + max_iterations: 5 + mcp: + filesystem: + tools: + - read_file + +# Define the pipeline as a sequence of modules +pipeline: + - module: code_explorer + +# Pipeline inputs +inputs: + question: "Analyze codebase carefully and provide project architecture outline" diff --git a/examples/yaml_program/pipeline_deep_researcher.yaml b/examples/yaml_program/pipeline_deep_researcher.yaml new file mode 100644 index 
0000000..85a5347 --- /dev/null +++ b/examples/yaml_program/pipeline_deep_researcher.yaml @@ -0,0 +1,177 @@ +# DSGo YAML Pipeline: Deep Researcher +# Multi-stage research workflow: +# topic -> research -> initialResearchResults -> deepSearchTopics -> deepResearchResults -> article +# +# Requires: +# - TAVILY_API_KEY for web search/extraction (MCP) +# +# Run: +# export TAVILY_API_KEY=your-key +# go run . pipeline_deep_researcher.yaml + +name: deep_researcher_pipeline +description: Topic-driven deep research that ends in a clear article + +model: + name: openrouter/z-ai/glm-4.6 + temperature: 0.6 + max_tokens: 10240 + +dsgo: + timeouts: + # Very high timeouts for deep research runs. + pipeline: 2h + lm_http: 2h + mcp_http: 2h + mcp_sse_post: 2h + mcp_sse_wait: 2h + +# Define MCP clients (global registry) +mcp: + tavily: + type: tavily + +signatures: + initial_research: + description: Do broad research and capture initial findings + inputs: + - name: topic + type: string + description: The topic to research + outputs: + - name: research + type: string + description: High-level research summary (2-5 paragraphs) + - name: initialResearchResults + type: string + description: Structured research notes (definitions, key facts, debates, timelines) + - name: sources + type: string + description: Sources consulted (URLs + 1-line relevance note) + + deep_search_planning: + description: Generate deep-level follow-up search topics and queries + inputs: + - name: topic + type: string + description: The topic being researched + - name: initialResearchResults + type: string + description: Initial research notes from the previous stage + - name: research + type: string + description: High-level summary from the previous stage + outputs: + - name: deepSearchTopics + type: string + description: 8-15 targeted deep-dive search queries grouped by theme + + deep_research: + description: Perform deeper research using the planned deepSearchTopics + inputs: + - name: topic + type: string + 
description: The topic being researched + - name: deepSearchTopics + type: string + description: Deep search queries to investigate + - name: initialResearchResults + type: string + description: Initial research notes for continuity + optional: true + outputs: + - name: deepResearchResults + type: string + description: Deeper findings with citations, contrasts, and edge cases + - name: deepSources + type: string + description: Additional sources consulted (URLs + 1-line relevance note) + + write_article: + description: Write a clear article from research notes and deep research + inputs: + - name: topic + type: string + description: The topic of the article + - name: research + type: string + description: Initial summary + - name: initialResearchResults + type: string + description: Initial research notes + - name: deepResearchResults + type: string + description: Deep research notes + - name: deepSearchTopics + type: string + description: Deep search topics used (for transparency) + optional: true + - name: sources + type: string + description: Sources from initial research + optional: true + - name: deepSources + type: string + description: Sources from deep research + optional: true + outputs: + - name: article + type: string + description: A clear, well-structured article with headings and citations + - name: bibliography + type: string + description: Deduplicated bibliography with URLs + +modules: + researcher: + type: ReAct + signature: initial_research + options: + temperature: 0.4 + max_tokens: 4096 + max_iterations: 10 + verbose: true + mcp: + tavily: + tools: + - "*" + + deep_topic_planner: + type: ChainOfThought + signature: deep_search_planning + options: + temperature: 0.4 + max_tokens: 1536 + + deep_researcher: + type: ReAct + signature: deep_research + options: + temperature: 0.4 + max_tokens: 6144 + max_iterations: 12 + verbose: true + mcp: + tavily: + tools: + - "*" + + article_writer: + type: ChainOfThought + signature: write_article + options: + 
temperature: 0.5 + max_tokens: 8192 + +pipeline: + - module: researcher + - module: deep_topic_planner + - module: deep_researcher + - module: article_writer + +inputs: + topic: | + Snowflake data loading features, Ingestion of structured, + semi-structured, and unstructured data. Implementation of + stages and file formats. causes of ingestion errors. + resolutions for ingestion errors diff --git a/examples/yaml_program/pipeline_deep_review.yaml b/examples/yaml_program/pipeline_deep_review.yaml new file mode 100644 index 0000000..3850c70 --- /dev/null +++ b/examples/yaml_program/pipeline_deep_review.yaml @@ -0,0 +1,302 @@ +# DSGo YAML Pipeline: Deep Review (2-pass) +# +# Input: task +# Flow: +# task -> enhanced_task -> plan_prompt -> plan +# -> todo_1 -> test_1 -> review_1 +# -> todo_2 -> test_2 -> review_2 +# -> deliver +# +# Requires: +# - TAVILY_API_KEY for web search/extraction +# +# Built-in MCP: +# - shell (whitelisted command runner + apply_patch) +# +# Run: +# export TAVILY_API_KEY=your-key +# go run . 
pipeline_deep_review.yaml + +name: deep_review_pipeline +description: Two-iteration implement/test/review loop that can patch code and run tests + +model: + name: openrouter/z-ai/glm-4.6 + temperature: 0.5 + max_tokens: 10240 + +dsgo: + timeouts: + pipeline: 2h + lm_http: 2h + mcp_http: 2h + +# Define MCP clients (global registry) +mcp: + tavily: + type: tavily + shell: + type: shell + filesystem: + type: filesystem + +signatures: + enhance_task: + description: Rewrite the user task into an explicit, repo-aware implementation prompt + inputs: + - name: task + type: string + description: Task to perform in a codebase + outputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt with constraints and acceptance criteria + + plan_prompt: + description: Produce a planning prompt for a two-iteration coding loop + inputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt + outputs: + - name: plan_prompt + type: string + description: Prompt that asks for a practical plan with steps + + plan: + description: Produce a concrete plan and test strategy + inputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt + - name: plan_prompt + type: string + description: Planning prompt to follow + outputs: + - name: plan + type: string + description: Concrete implementation plan + - name: test_strategy + type: string + description: Specific commands to run and what success looks like + + todo_iteration: + description: Implement changes using tools; ONLY modify repo via apply_patch (no direct writes) + inputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt + - name: plan + type: string + description: Implementation plan + - name: review_feedback + type: string + description: Feedback from the previous review (empty for first iteration) + optional: true + outputs: + - name: changes_made + type: string + description: Summary of changes made + - 
name: patches_applied + type: string + description: What patches were applied (high level) + - name: files_touched + type: string + description: Files changed/added + + test_run: + description: Run tests/build/lint using shell_run (prefer make targets) and report results + inputs: + - name: test_strategy + type: string + description: Commands to run and expected outcomes + - name: changes_made + type: string + description: Context about what changed + outputs: + - name: test_status + type: class + classes: [pass, fail] + description: Whether tests passed + - name: test_output + type: string + description: Combined test output and summary + - name: commands_run + type: string + description: Commands executed + + review: + description: Review changes and test output, then decide pass/fail + inputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt + - name: plan + type: string + description: Plan that was followed + - name: changes_made + type: string + description: Summary of changes made + - name: test_status + type: string + description: pass/fail + - name: test_output + type: string + description: Test output + outputs: + - name: review_status + type: class + classes: [pass, fail] + description: Whether the review passed + - name: review_feedback + type: string + description: Actionable feedback to address in the next iteration if needed + + deliver: + description: Produce final deliverable; use shell_run to run git diff and validation commands + inputs: + - name: enhanced_task + type: string + description: Enhanced implementation prompt + - name: changes_made + type: string + description: Final summary of changes + - name: review_status + type: string + description: pass/fail + outputs: + - name: final_diff + type: string + description: git diff of final changes + - name: run_instructions + type: string + description: Commands to run to validate the result + +modules: + enhancer: + type: ChainOfThought + signature: enhance_task 
+ options: + temperature: 0.4 + max_tokens: 2048 + + planner_prompt: + type: ChainOfThought + signature: plan_prompt + options: + temperature: 0.4 + max_tokens: 1024 + + planner: + type: ChainOfThought + signature: plan + options: + temperature: 0.4 + max_tokens: 2048 + + todo_1: + type: ReAct + signature: todo_iteration + options: + temperature: 0.3 + max_tokens: 4096 + max_iterations: 12 + verbose: true + mcp: + tavily: + tools: + - "*" + shell: + tools: + - "*" + filesystem: + tools: + - "*" + + test_1: + type: ReAct + signature: test_run + options: + temperature: 0.2 + max_tokens: 4096 + max_iterations: 8 + verbose: true + mcp: + shell: + tools: + - shell_run + + review_1: + type: ChainOfThought + signature: review + options: + temperature: 0.3 + max_tokens: 2048 + + todo_2: + type: ReAct + signature: todo_iteration + options: + temperature: 0.2 + max_tokens: 4096 + max_iterations: 12 + verbose: true + mcp: + tavily: + tools: + - "*" + shell: + tools: + - "*" + filesystem: + tools: + - "*" + + test_2: + type: ReAct + signature: test_run + options: + temperature: 0.2 + max_tokens: 4096 + max_iterations: 8 + verbose: true + mcp: + shell: + tools: + - shell_run + + review_2: + type: ChainOfThought + signature: review + options: + temperature: 0.3 + max_tokens: 2048 + + deliver: + type: ReAct + signature: deliver + options: + temperature: 0.2 + max_tokens: 4096 + max_iterations: 6 + verbose: false + mcp: + shell: + tools: + - shell_run + +pipeline: + - module: enhancer + - module: planner_prompt + - module: planner + - module: todo_1 + - module: test_1 + - module: review_1 + - module: todo_2 + - module: test_2 + - module: review_2 + - module: deliver + +inputs: + task: | + moonshotai provider support in providers package diff --git a/examples/yaml_program/pipeline_exa.yaml b/examples/yaml_program/pipeline_exa.yaml new file mode 100644 index 0000000..74300fd --- /dev/null +++ b/examples/yaml_program/pipeline_exa.yaml @@ -0,0 +1,53 @@ +# DSGo YAML Pipeline Definition 
with Exa MCP Integration +# Requires: EXA_API_KEY environment variable +# +# To run this example: +# export EXA_API_KEY=your-api-key +# go run . pipeline_exa.yaml + +name: exa_research_pipeline +description: A pipeline that researches topics using Exa web search + +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define MCP clients (global registry) +mcp: + exa: + type: exa + +signatures: + code_research: + description: Research programming topics and find code examples + inputs: + - name: question + type: string + description: Programming question or topic to research + outputs: + - name: answer + type: string + description: Comprehensive answer with code examples + - name: sources + type: string + description: Sources consulted + +modules: + code_researcher: + type: ReAct + signature: code_research + options: + temperature: 0.3 + max_iterations: 8 + mcp: + exa: + tools: + - web_search_exa + - get_code_context_exa + +pipeline: + - module: code_researcher + +# Pipeline inputs +inputs: + question: "How do I implement a ReAct agent in Go using DSGo framework?" diff --git a/examples/yaml_program/pipeline_functions.yaml b/examples/yaml_program/pipeline_functions.yaml new file mode 100644 index 0000000..074bf9b --- /dev/null +++ b/examples/yaml_program/pipeline_functions.yaml @@ -0,0 +1,109 @@ +# DSGo YAML Pipeline Definition with Function Tools +# This file demonstrates ReAct with native Go function tools and filesystem MCP. 
+# +# Requires: Node.js (npx) or Bun (bunx) for filesystem MCP server + +name: assistant_pipeline +description: A pipeline that answers questions using built-in utility tools + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 4096 + +# Define MCP clients (global registry) +mcp: + filesystem: + type: filesystem + +# Define function tools (native Go implementations) +tools: + get_datetime: + type: function + name: current_datetime + + calculator: + type: function + name: calculate + + random: + type: function + name: random_number + + text_analyzer: + type: function + name: word_count + + env_info: + type: function + name: environment_info + +# Define reusable signatures +signatures: + assistant: + description: Answer questions using available tools + inputs: + - name: question + type: string + description: The question to answer + outputs: + - name: answer + type: string + description: The answer to the question + - name: tools_used + type: string + description: Which tools were used to find the answer + + analysis: + description: Analyze text or numbers + inputs: + - name: task + type: string + description: The analysis task to perform + - name: data + type: string + description: The data to analyze + outputs: + - name: result + type: string + description: Analysis result + - name: method + type: string + description: Method used for analysis + +# Define modules that use the signatures +modules: + assistant: + type: ReAct + signature: assistant + options: + temperature: 0.5 + max_iterations: 6 + tools: + - get_datetime + - calculator + - random + - text_analyzer + - env_info + mcp: + filesystem: + tools: + - "*" # Use all filesystem MCP tools + + analyzer: + type: ReAct + signature: analysis + options: + temperature: 0.3 + max_iterations: 5 + tools: + - calculator + - text_analyzer + +# Define the pipeline as a sequence of modules +pipeline: + - module: assistant + +# Pipeline inputs +inputs: + question: "What is the current date and 
time? Also calculate 25 * 4 + 100." diff --git a/examples/yaml_program/pipeline_idea_to_implementation.yaml b/examples/yaml_program/pipeline_idea_to_implementation.yaml new file mode 100644 index 0000000..bf2d12a --- /dev/null +++ b/examples/yaml_program/pipeline_idea_to_implementation.yaml @@ -0,0 +1,157 @@ +# DSGo YAML Pipeline: Idea to Implementation +# This pipeline takes a high-level idea and progressively breaks it down +# into tasks, plans, subtasks, and implementation details. + +name: idea_to_implementation +description: Progressive breakdown of ideas into actionable implementations + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define reusable signatures +signatures: + idea_to_task: + description: Convert a high-level idea into a concrete task definition + inputs: + - name: idea + type: string + description: High-level idea or concept to implement + - name: context + type: string + description: Additional context about the project or domain + optional: true + outputs: + - name: task + type: string + description: Clear, actionable task definition with specific goals + - name: success_criteria + type: string + description: Measurable criteria for task completion + - name: scope + type: string + description: What's in scope and what's out of scope + + task_to_plan: + description: Create a strategic plan from a task definition + inputs: + - name: task + type: string + description: The task definition to plan for + - name: success_criteria + type: string + description: Success criteria from previous stage + - name: scope + type: string + description: Scope boundaries from previous stage + outputs: + - name: plan + type: string + description: Strategic plan with approach and methodology + - name: phases + type: string + description: High-level phases or stages + - name: dependencies + type: string + description: Key dependencies and prerequisites + - name: risks + type: string + description: Potential risks and 
mitigation strategies + + plan_to_subtasks: + description: Break down a plan into specific subtasks + inputs: + - name: plan + type: string + description: The strategic plan + - name: phases + type: string + description: High-level phases from the plan + - name: dependencies + type: string + description: Dependencies to consider + outputs: + - name: subtasks + type: string + description: List of concrete subtasks with priorities + - name: sequence + type: string + description: Recommended execution sequence + - name: effort_estimate + type: string + description: Rough effort estimates for each subtask + + subtasks_to_implementation: + description: Generate detailed implementation guidance for subtasks + inputs: + - name: subtasks + type: string + description: List of subtasks to implement + - name: sequence + type: string + description: Recommended execution sequence + - name: effort_estimate + type: string + description: Effort estimates + - name: task + type: string + description: Original task definition for context + outputs: + - name: implementation + type: string + description: Detailed implementation steps and code structure + - name: technical_approach + type: string + description: Technical approach and architecture decisions + - name: testing_strategy + type: string + description: Testing approach and validation methods + - name: next_steps + type: string + description: Immediate next steps to begin implementation + +# Define modules that use the signatures +modules: + idea_analyzer: + type: ChainOfThought + signature: idea_to_task + options: + temperature: 0.6 + max_tokens: 2048 + + task_planner: + type: ChainOfThought + signature: task_to_plan + options: + temperature: 0.7 + max_tokens: 3072 + + subtask_breaker: + type: ChainOfThought + signature: plan_to_subtasks + options: + temperature: 0.5 + max_tokens: 3072 + + implementation_guide: + type: ChainOfThought + signature: subtasks_to_implementation + options: + temperature: 0.6 + max_tokens: 4096 + +# 
Define the pipeline as a sequence of modules +pipeline: + - module: idea_analyzer + - module: task_planner + - module: subtask_breaker + - module: implementation_guide + +# Pipeline inputs +inputs: + idea: | + Build a hello world http rest api in golang. + + context: | + experimentation diff --git a/examples/yaml_program/pipeline_mcp.yaml b/examples/yaml_program/pipeline_mcp.yaml new file mode 100644 index 0000000..ab8557c --- /dev/null +++ b/examples/yaml_program/pipeline_mcp.yaml @@ -0,0 +1,96 @@ +# DSGo YAML Pipeline Definition with MCP Integration +# This file demonstrates ReAct with MCP-based web search tools. +# Requires: TAVILY_API_KEY environment variable for web search +# +# To run this example: +# export TAVILY_API_KEY=your-api-key +# go run . pipeline_mcp.yaml +# +# Get a free Tavily API key at: https://tavily.com + +name: dsgo_research_pipeline +description: Deep code analysis for repo https://github.com/assagman/dsgo + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define MCP clients (global registry) +mcp: + tavily: + type: tavily + # api_key can be set here or via TAVILY_API_KEY env var + +# Define reusable signatures +signatures: + research: + description: Research a topic using web search + inputs: + - name: topic + type: string + description: The topic to research + - name: depth + type: string + description: How deep to research (brief, moderate, comprehensive) + outputs: + - name: summary + type: string + description: Summary of research findings + - name: key_facts + type: string + description: Key facts discovered + - name: sources + type: string + description: Sources consulted + + fact_check: + description: Verify facts using web search + inputs: + - name: claim + type: string + description: The claim to fact-check + outputs: + - name: verdict + type: class + classes: ["true", "false", partially_true, unverifiable] + description: Verdict on the claim + - name: evidence + type: string + description: 
Evidence supporting the verdict + - name: confidence + type: float + description: Confidence in the verdict (0-1) + +# Define modules that use the signatures +modules: + researcher: + type: ReAct + signature: research + options: + temperature: 0.5 + max_iterations: 10 + mcp: + tavily: + tools: + - "*" # Use all tools from tavily + + fact_checker: + type: ReAct + signature: fact_check + options: + temperature: 0.3 + max_iterations: 6 + mcp: + tavily: + tools: + - tavily_search # Use only specific tool + +# Define the pipeline as a sequence of modules +pipeline: + - module: researcher + +# Pipeline inputs - these are passed to the first module +inputs: + topic: "DSGo framework architecture and capabilities" + depth: "comprehensive" diff --git a/examples/yaml_program/pipeline_react.yaml b/examples/yaml_program/pipeline_react.yaml new file mode 100644 index 0000000..e8b88ae --- /dev/null +++ b/examples/yaml_program/pipeline_react.yaml @@ -0,0 +1,87 @@ +# DSGo YAML Pipeline Definition with ReAct and Tools +# This file demonstrates the ReAct module with filesystem MCP tools +# for code analysis tasks. 
+# +# Requires: Node.js (npx) or Bun (bunx) for filesystem MCP server + +name: code_analysis_pipeline +description: A pipeline that analyzes codebase using ReAct with filesystem MCP + +# Global settings for the pipeline +settings: + temperature: 0.7 + max_tokens: 10240 + +# Define MCP clients (global registry) +mcp: + filesystem: + type: filesystem + +# Define reusable signatures +signatures: + code_exploration: + description: Explore and understand a codebase structure + inputs: + - name: question + type: string + description: The question about the codebase to answer + outputs: + - name: answer + type: string + description: Detailed answer based on codebase exploration + - name: files_examined + type: string + description: List of files that were examined + + code_review: + description: Review code quality and suggest improvements + inputs: + - name: filepath + type: string + description: Path to the file to review + - name: focus_areas + type: string + description: Specific areas to focus the review on + outputs: + - name: issues + type: string + description: List of identified issues + - name: suggestions + type: string + description: Improvement suggestions + - name: quality_score + type: int + description: Overall quality score from 1-10 + +# Define modules that use the signatures +modules: + code_explorer: + type: ReAct + signature: code_exploration + options: + temperature: 0.5 + max_iterations: 8 + mcp: + filesystem: + tools: + - "*" # Use all filesystem tools + + code_reviewer: + type: ReAct + signature: code_review + options: + temperature: 0.3 + max_iterations: 5 + mcp: + filesystem: + tools: + - read_file + - get_file_info + +# Define the pipeline as a sequence of modules +pipeline: + - module: code_explorer + +# Pipeline inputs +inputs: + question: "What Go files are in this project? List them with their directory structure." 
diff --git a/examples/yaml_program/pipeline_todo.yaml b/examples/yaml_program/pipeline_todo.yaml new file mode 100644 index 0000000..63b37ba --- /dev/null +++ b/examples/yaml_program/pipeline_todo.yaml @@ -0,0 +1,142 @@ +# DSGo YAML Pipeline: Task -> Codebase analysis -> Web research -> TODO list +# +# Input: task +# Flow: +# task -> codebase_analysis -> web_research -> todo +# +# Requires: +# - TAVILY_API_KEY for web search/extraction (MCP) +# - Node.js (npx) or Bun (bunx) for filesystem MCP server +# +# Run: +# export TAVILY_API_KEY=your-key +# go run . pipeline_todo.yaml + +name: pipeline_todo +description: Generates a repo-aware numbered TODO list from a task + +model: + name: openrouter/z-ai/glm-4.6 + temperature: 0.4 + max_tokens: 8192 + +dsgo: + timeouts: + pipeline: 45m + lm_http: 15m + mcp_http: 10m + +# Define MCP clients (global registry) +mcp: + tavily: + type: tavily + filesystem: + type: filesystem + allowed_dirs: + - ~/source/me/dsgo/example-generator/ + +signatures: + codebase_analysis: + description: Analyze the current repository to locate files and implementation approach + inputs: + - name: task + type: string + description: Task to perform in this repository + outputs: + - name: codebase_notes + type: string + description: Repo-specific findings (entry points, constraints, suggested approach) + - name: files_examined + type: string + description: Files/dirs examined (paths) + + web_research: + description: Research the task on the web based on codebase analysis findings + inputs: + - name: task + type: string + description: Task to perform in a codebase + - name: codebase_notes + type: string + description: Codebase analysis notes to guide research + optional: true + - name: files_examined + type: string + description: Files examined in the codebase + optional: true + outputs: + - name: research_notes + type: string + description: Practical notes (best practices, pitfalls, relevant APIs), tailored to the task and codebase + - name: sources + 
type: string + description: 5-12 URLs with a 1-line relevance note each + + write_todo: + description: Produce a flat, actionable TODO list for implementing the task + inputs: + - name: task + type: string + description: The original task + - name: research_notes + type: string + description: Web research notes + optional: true + - name: codebase_notes + type: string + description: Repo analysis notes + outputs: + - name: todo + type: string + description: | + Output ONLY a numbered TODO list. + One item per line, in order, starting with "1.". + No headings, no grouping, no blank lines. + Make items actionable and repo-specific. + Mention file paths and commands when appropriate. + +modules: + researcher: + type: ReAct + signature: web_research + model: openrouter/google/gemini-2.5-flash + options: + temperature: 0.3 + max_tokens: 4096 + max_iterations: 10 + verbose: true + mcp: + tavily: + tools: + - "*" + + analyzer: + type: ReAct + signature: codebase_analysis + model: openrouter/openai/gpt-5.2 + options: + temperature: 0.4 + max_tokens: 4096 + max_iterations: 12 + verbose: true + mcp: + filesystem: + tools: + - "*" + + todo_writer: + type: ChainOfThought + signature: write_todo + model: openrouter/z-ai/glm-4.6 + options: + temperature: 0.2 + max_tokens: 2048 + +pipeline: + - module: analyzer + - module: researcher + - module: todo_writer + +inputs: + task: | + Review cost operations diff --git a/examples/yaml_program/program.go b/examples/yaml_program/program.go new file mode 100644 index 0000000..321b6fd --- /dev/null +++ b/examples/yaml_program/program.go @@ -0,0 +1,84 @@ +package main + +import ( + "context" + "fmt" + + "github.com/assagman/dsgo" +) + +// ProgramBuilder builds a DSGo Program from YAML configuration +type ProgramBuilder struct { + config *PipelineConfig + sigRegistry *SignatureRegistry + toolRegistry *ToolRegistry + moduleRegistry *ModuleRegistry +} + +// NewProgramBuilder creates a new program builder from configuration +func 
NewProgramBuilder(ctx context.Context, config *PipelineConfig, defaultModel string, lm dsgo.LM) (*ProgramBuilder, error) { + // Create signature registry + sigRegistry, err := NewSignatureRegistry(config.Signatures) + if err != nil { + return nil, fmt.Errorf("failed to create signature registry: %w", err) + } + + // Create MCP client registry (if any MCP configs exist) + var mcpRegistry *MCPClientRegistry + if len(config.MCP) > 0 { + mcpRegistry, err = NewMCPClientRegistry(ctx, config.MCP, config.EffectiveTimeouts()) + if err != nil { + return nil, fmt.Errorf("failed to create MCP client registry: %w", err) + } + } + + // Create tool registry + toolRegistry, err := NewToolRegistry(ctx, config.Tools, mcpRegistry) + if err != nil { + return nil, fmt.Errorf("failed to create tool registry: %w", err) + } + + // Create module factory and registry + factory := NewModuleFactory(ctx, defaultModel, lm, sigRegistry, toolRegistry, config.EffectiveModelSettings()) + moduleRegistry, err := NewModuleRegistry(factory, config.Modules) + if err != nil { + return nil, fmt.Errorf("failed to create module registry: %w", err) + } + + return &ProgramBuilder{ + config: config, + sigRegistry: sigRegistry, + toolRegistry: toolRegistry, + moduleRegistry: moduleRegistry, + }, nil +} + +// Build creates the DSGo Program from the pipeline definition +func (pb *ProgramBuilder) Build() (*dsgo.Program, error) { + program := dsgo.NewProgram(pb.config.Name) + + for i, step := range pb.config.Pipeline { + module, err := pb.moduleRegistry.Get(step.Module) + if err != nil { + return nil, fmt.Errorf("pipeline step %d: %w", i+1, err) + } + program.AddModule(module) + } + + return program, nil +} + +// GetConfig returns the pipeline configuration +func (pb *ProgramBuilder) GetConfig() *PipelineConfig { + return pb.config +} + +// GetSignatureRegistry returns the signature registry +func (pb *ProgramBuilder) GetSignatureRegistry() *SignatureRegistry { + return pb.sigRegistry +} + +// GetModuleRegistry 
returns the module registry +func (pb *ProgramBuilder) GetModuleRegistry() *ModuleRegistry { + return pb.moduleRegistry +} diff --git a/examples/yaml_program/tools.go b/examples/yaml_program/tools.go new file mode 100644 index 0000000..75cd4cd --- /dev/null +++ b/examples/yaml_program/tools.go @@ -0,0 +1,706 @@ +package main + +import ( + "context" + "fmt" + "io/fs" + "math/rand" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/assagman/dsgo" +) + +// MCPClientRegistry holds initialized MCP clients +type MCPClientRegistry struct { + clients map[string]*dsgo.MCPClient + specs map[string]MCPSpec +} + +// NewMCPClientRegistry creates MCP clients from YAML specs +func NewMCPClientRegistry(ctx context.Context, specs map[string]MCPSpec, timeouts TimeoutSettings) (*MCPClientRegistry, error) { + registry := &MCPClientRegistry{ + clients: make(map[string]*dsgo.MCPClient), + specs: specs, + } + + for name, spec := range specs { + client, err := createMCPClient(ctx, name, spec, timeouts) + if err != nil { + return nil, fmt.Errorf("failed to create MCP client '%s': %w", name, err) + } + registry.clients[name] = client + } + + return registry, nil +} + +// GetSpec returns the spec for an MCP client +func (r *MCPClientRegistry) GetSpec(name string) (MCPSpec, bool) { + spec, exists := r.specs[name] + return spec, exists +} + +// Clients returns all client names +func (r *MCPClientRegistry) Clients() []string { + names := make([]string, 0, len(r.clients)) + for name := range r.clients { + names = append(names, name) + } + return names +} + +// Get returns an MCP client by name +func (r *MCPClientRegistry) Get(name string) (*dsgo.MCPClient, error) { + client, exists := r.clients[name] + if !exists { + return nil, fmt.Errorf("MCP client not found: %s", name) + } + return client, nil +} + +// createMCPClient creates an MCP client from a spec. 
+// +// We build clients via explicit transports so the YAML runner can override +// timeouts without relying on environment variables. +func createMCPClient(ctx context.Context, name string, spec MCPSpec, timeouts TimeoutSettings) (*dsgo.MCPClient, error) { + apiKey := spec.APIKey + if apiKey == "" { + apiKey = getAPIKeyFromEnv(spec.Type) + } + + httpTimeout := timeouts.MCPHTTP.Duration + postTimeout := timeouts.MCPSSEPost.Duration + waitTimeout := timeouts.MCPSSEWait.Duration + + var transport dsgo.MCPTransport + switch spec.Type { + case "exa": + transport = dsgo.NewMCPHTTPTransportWithTimeout("https://mcp.exa.ai/mcp", apiKey, httpTimeout) + case "jina": + transport = dsgo.NewMCPSSETransportWithTimeouts("https://mcp.jina.ai/sse", apiKey, postTimeout, waitTimeout) + case "tavily": + baseURL, err := url.Parse("https://mcp.tavily.com/mcp") + if err != nil { + return nil, fmt.Errorf("failed to parse Tavily MCP URL: %w", err) + } + q := baseURL.Query() + q.Set("tavilyApiKey", apiKey) + baseURL.RawQuery = q.Encode() + + // Tavily uses the query param for auth, not headers. + transport = dsgo.NewMCPHTTPTransportWithTimeout(baseURL.String(), "", httpTimeout) + case "shell": + projectRoot, err := findProjectRoot() + if err != nil { + return nil, fmt.Errorf("failed to find project root: %w", err) + } + shellServer, err := dsgo.NewMCPShellServer(dsgo.MCPShellServerConfig{RootDir: projectRoot}) + if err != nil { + return nil, fmt.Errorf("failed to create shell MCP server: %w", err) + } + transport, err = dsgo.NewMCPLocalTransport(shellServer) + if err != nil { + return nil, fmt.Errorf("failed to create local transport: %w", err) + } + case "filesystem": + projectRoot, err := findProjectRoot() + if err != nil { + return nil, fmt.Errorf("failed to find project root: %w", err) + } + allowedDirs := []string{projectRoot} + if len(spec.AllowedDirs) > 0 { + allowedDirs = spec.AllowedDirs + } + client, err := dsgo.NewMCPFilesystemClient(allowedDirs...) 
+ if err != nil { + return nil, fmt.Errorf("failed to create filesystem MCP client: %w", err) + } + if err := client.Initialize(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize filesystem MCP client: %w", err) + } + return client, nil + case "custom": + transport = dsgo.NewMCPHTTPTransportWithTimeout(spec.URL, apiKey, httpTimeout) + default: + return nil, fmt.Errorf("unsupported MCP type: %s", spec.Type) + } + + client, err := dsgo.NewMCPClient(dsgo.MCPClientConfig{Transport: transport}) + if err != nil { + return nil, err + } + + if err := client.Initialize(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + return client, nil +} + +// getAPIKeyFromEnv returns the API key from environment variable based on MCP type +func getAPIKeyFromEnv(mcpType string) string { + envVars := map[string]string{ + "exa": "EXA_API_KEY", + "jina": "JINA_API_KEY", + "tavily": "TAVILY_API_KEY", + } + + if envVar, exists := envVars[mcpType]; exists { + return os.Getenv(envVar) + } + return "" +} + +// ToolRegistry holds resolved DSGo tools (custom tools only, not MCP) +type ToolRegistry struct { + tools map[string]dsgo.Tool + mcpRegistry *MCPClientRegistry +} + +// NewToolRegistry creates tools from YAML specs (custom tools only) +func NewToolRegistry(ctx context.Context, specs map[string]ToolSpec, mcpRegistry *MCPClientRegistry) (*ToolRegistry, error) { + registry := &ToolRegistry{ + tools: make(map[string]dsgo.Tool), + mcpRegistry: mcpRegistry, + } + + for name, spec := range specs { + tool, err := createTool(name, spec) + if err != nil { + return nil, fmt.Errorf("failed to create tool '%s': %w", name, err) + } + registry.tools[name] = tool + } + + return registry, nil +} + +// Get returns a tool by name +func (r *ToolRegistry) Get(name string) (dsgo.Tool, error) { + tool, exists := r.tools[name] + if !exists { + return dsgo.Tool{}, fmt.Errorf("tool not found: %s", name) + } + return tool, nil +} + +// GetMultiple returns 
multiple tools by name +func (r *ToolRegistry) GetMultiple(names []string) ([]dsgo.Tool, error) { + result := make([]dsgo.Tool, 0, len(names)) + for _, name := range names { + tool, err := r.Get(name) + if err != nil { + return nil, err + } + result = append(result, tool) + } + return result, nil +} + +// GetMCPTools returns tools from an MCP client, filtered by the tool list. +// If toolFilters contains "*", all tools from that MCP client are returned. +func (r *ToolRegistry) GetMCPTools(mcpName string, toolFilters []string) ([]dsgo.Tool, error) { + if r.mcpRegistry == nil { + return nil, fmt.Errorf("no MCP registry available") + } + + client, err := r.mcpRegistry.Get(mcpName) + if err != nil { + return nil, err + } + + allTools := client.GetTools() + + // Check for wildcard + for _, filter := range toolFilters { + if filter == "*" { + return allTools, nil + } + } + + // Filter to specific tools + filterSet := make(map[string]bool) + for _, name := range toolFilters { + filterSet[name] = true + } + + result := make([]dsgo.Tool, 0, len(toolFilters)) + for _, tool := range allTools { + if filterSet[tool.Name] { + result = append(result, tool) + } + } + + return result, nil +} + +// GetAllMCPToolsForModule resolves all MCP tools configured for a module +func (r *ToolRegistry) GetAllMCPToolsForModule(mcpConfigs map[string]ModuleMCPSpec) ([]dsgo.Tool, error) { + var result []dsgo.Tool + for mcpName, mcpSpec := range mcpConfigs { + tools, err := r.GetMCPTools(mcpName, mcpSpec.Tools) + if err != nil { + return nil, fmt.Errorf("failed to get tools from MCP '%s': %w", mcpName, err) + } + result = append(result, tools...) 
+ } + return result, nil +} + +// createTool creates a DSGo tool from a spec (custom tools only, not MCP) +func createTool(name string, spec ToolSpec) (dsgo.Tool, error) { + switch spec.Type { + case "filesystem": + return createFilesystemTool(name, spec) + case "function": + return createFunctionTool(name, spec) + default: + return dsgo.Tool{}, fmt.Errorf("unsupported tool type: %s (valid: filesystem, function)", spec.Type) + } +} + +// createFilesystemTool creates a filesystem tool +func createFilesystemTool(name string, spec ToolSpec) (dsgo.Tool, error) { + toolName := spec.Name + if toolName == "" { + toolName = name + } + + switch toolName { + case "list_files": + return newListFilesTool(), nil + case "read_file": + return newReadFileTool(), nil + case "search_files": + return newSearchFilesTool(), nil + default: + return dsgo.Tool{}, fmt.Errorf("unknown filesystem tool: %s", toolName) + } +} + +// findProjectRoot finds the project root by looking for go.mod +func findProjectRoot() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + return os.Getwd() + } + dir = parent + } +} + +// validatePath ensures path is within project root +func validatePath(projectRoot, path string) (string, error) { + if !filepath.IsAbs(path) { + path = filepath.Join(projectRoot, path) + } + path = filepath.Clean(path) + + if path != projectRoot && !strings.HasPrefix(path, projectRoot+string(os.PathSeparator)) { + return "", fmt.Errorf("path %s is outside project root", path) + } + return path, nil +} + +// newListFilesTool creates a tool for listing files +func newListFilesTool() dsgo.Tool { + listFiles := func(ctx context.Context, args map[string]any) (any, error) { + projectRoot, err := findProjectRoot() + if err != nil { + return nil, fmt.Errorf("failed to find project root: %w", err) + } + + directory 
:= projectRoot + if dirArg, ok := args["directory"].(string); ok && dirArg != "" { + directory, err = validatePath(projectRoot, dirArg) + if err != nil { + return nil, err + } + } + + depth := 3 + if depthVal, ok := args["depth"].(float64); ok { + depth = int(depthVal) + } + + var files []string + err = filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + relPath, err := filepath.Rel(projectRoot, path) + if err != nil { + return err + } + currentDepth := len(strings.Split(relPath, string(os.PathSeparator))) + + if currentDepth > depth { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + + if d.IsDir() { + files = append(files, relPath+"/") + } else { + files = append(files, relPath) + } + return nil + }) + + if err != nil { + return nil, fmt.Errorf("error walking directory: %w", err) + } + + return map[string]any{ + "files": files, + "directory": directory, + }, nil + } + + return *dsgo.NewTool("list_files", "List files and directories in a given path up to a specified depth", listFiles). + AddParameter("directory", "string", "The directory path to list (relative to project root)", false). 
+ AddParameter("depth", "int", "Maximum depth to traverse (default: 3)", false) +} + +// newReadFileTool creates a tool for reading files +func newReadFileTool() dsgo.Tool { + readFile := func(ctx context.Context, args map[string]any) (any, error) { + projectRoot, err := findProjectRoot() + if err != nil { + return nil, fmt.Errorf("failed to find project root: %w", err) + } + + filepathArg, ok := args["filepath"].(string) + if !ok || filepathArg == "" { + return nil, fmt.Errorf("filepath parameter is required") + } + + filepathArg, err = validatePath(projectRoot, filepathArg) + if err != nil { + return nil, err + } + + content, err := os.ReadFile(filepathArg) + if err != nil { + return nil, fmt.Errorf("error reading file: %w", err) + } + + totalBytes := len(content) + totalLines := strings.Count(string(content), "\n") + if totalBytes > 0 && content[totalBytes-1] != '\n' { + totalLines++ + } + + contentStr := string(content) + truncated := false + if len(contentStr) > 10000 { + contentStr = contentStr[:10000] + truncated = true + truncatedLines := strings.Count(contentStr, "\n") + contentStr += fmt.Sprintf("\n... [truncated: showing %d/%d bytes, ~%d/%d lines]", + 10000, totalBytes, truncatedLines, totalLines) + } + + return map[string]any{ + "content": contentStr, + "filepath": filepathArg, + "size_bytes": totalBytes, + "total_lines": totalLines, + "truncated": truncated, + }, nil + } + + return *dsgo.NewTool("read_file", "Read the content of a specific file", readFile). 
+ AddParameter("filepath", "string", "The path to the file to read (relative to project root)", true) +} + +// newSearchFilesTool creates a tool for searching files by glob pattern +func newSearchFilesTool() dsgo.Tool { + searchFiles := func(ctx context.Context, args map[string]any) (any, error) { + projectRoot, err := findProjectRoot() + if err != nil { + return nil, fmt.Errorf("failed to find project root: %w", err) + } + + directory := projectRoot + if dirArg, ok := args["directory"].(string); ok && dirArg != "" { + directory, err = validatePath(projectRoot, dirArg) + if err != nil { + return nil, err + } + } + + pattern, ok := args["pattern"].(string) + if !ok || pattern == "" { + return nil, fmt.Errorf("pattern parameter is required") + } + + fullPattern := filepath.Join(directory, pattern) + matches, err := filepath.Glob(fullPattern) + if err != nil { + return nil, fmt.Errorf("error searching files: %w", err) + } + + relativeMatches := make([]string, len(matches)) + for i, match := range matches { + relativePath, err := filepath.Rel(projectRoot, match) + if err != nil { + relativeMatches[i] = match + } else { + relativeMatches[i] = relativePath + } + } + + return map[string]any{ + "files": relativeMatches, + "directory": directory, + "pattern": pattern, + }, nil + } + + return *dsgo.NewTool("search_files", "Search for files matching a glob pattern (standard Go match, no ** recursion)", searchFiles). + AddParameter("directory", "string", "The directory to search in (relative to project root)", false). 
+ AddParameter("pattern", "string", "Glob pattern to match (e.g., *.go)", true) +} + +// createFunctionTool creates a native Go function tool +func createFunctionTool(name string, spec ToolSpec) (dsgo.Tool, error) { + toolName := spec.Name + if toolName == "" { + toolName = name + } + + switch toolName { + case "current_datetime": + return newCurrentDateTimeTool(), nil + case "calculate": + return newCalculateTool(), nil + case "random_number": + return newRandomNumberTool(), nil + case "string_length": + return newStringLengthTool(), nil + case "word_count": + return newWordCountTool(), nil + case "environment_info": + return newEnvironmentInfoTool(), nil + default: + return dsgo.Tool{}, fmt.Errorf("unknown function tool: %s", toolName) + } +} + +// newCurrentDateTimeTool creates a tool that returns current date and time +func newCurrentDateTimeTool() dsgo.Tool { + getCurrentDateTime := func(ctx context.Context, args map[string]any) (any, error) { + now := time.Now() + + format := "2006-01-02 15:04:05" + if formatArg, ok := args["format"].(string); ok && formatArg != "" { + format = formatArg + } + + timezone := "Local" + if tzArg, ok := args["timezone"].(string); ok && tzArg != "" { + loc, err := time.LoadLocation(tzArg) + if err != nil { + return nil, fmt.Errorf("invalid timezone: %w", err) + } + now = now.In(loc) + timezone = tzArg + } + + return map[string]any{ + "datetime": now.Format(format), + "unix": now.Unix(), + "timezone": timezone, + "day_of_week": now.Weekday().String(), + "iso8601": now.Format(time.RFC3339), + }, nil + } + + return *dsgo.NewTool("current_datetime", "Get the current date and time", getCurrentDateTime). + AddParameter("format", "string", "Date format (Go style, e.g., '2006-01-02')", false). 
+ AddParameter("timezone", "string", "Timezone (e.g., 'America/New_York', 'UTC')", false) +} + +// newCalculateTool creates a simple calculator tool +func newCalculateTool() dsgo.Tool { + calculate := func(ctx context.Context, args map[string]any) (any, error) { + a, ok := args["a"].(float64) + if !ok { + return nil, fmt.Errorf("parameter 'a' is required and must be a number") + } + + b, ok := args["b"].(float64) + if !ok { + return nil, fmt.Errorf("parameter 'b' is required and must be a number") + } + + op, ok := args["operation"].(string) + if !ok { + op = "add" + } + + var result float64 + switch op { + case "add", "+": + result = a + b + case "subtract", "-": + result = a - b + case "multiply", "*": + result = a * b + case "divide", "/": + if b == 0 { + return nil, fmt.Errorf("division by zero") + } + result = a / b + default: + return nil, fmt.Errorf("unknown operation: %s (valid: add, subtract, multiply, divide)", op) + } + + return map[string]any{ + "result": result, + "operation": op, + "a": a, + "b": b, + }, nil + } + + return *dsgo.NewTool("calculate", "Perform basic arithmetic calculations", calculate). + AddParameter("a", "number", "First operand", true). + AddParameter("b", "number", "Second operand", true). 
+ AddParameter("operation", "string", "Operation: add, subtract, multiply, divide", false) +} + +// newRandomNumberTool creates a random number generator tool +func newRandomNumberTool() dsgo.Tool { + randomNumber := func(ctx context.Context, args map[string]any) (any, error) { + minVal := 1.0 + maxVal := 100.0 + + if min, ok := args["min"].(float64); ok { + minVal = min + } + if max, ok := args["max"].(float64); ok { + maxVal = max + } + + if minVal >= maxVal { + return nil, fmt.Errorf("min must be less than max") + } + + result := minVal + rand.Float64()*(maxVal-minVal) + + return map[string]any{ + "number": int(result), + "min": int(minVal), + "max": int(maxVal), + }, nil + } + + return *dsgo.NewTool("random_number", "Generate a random number within a range", randomNumber). + AddParameter("min", "number", "Minimum value (default: 1)", false). + AddParameter("max", "number", "Maximum value (default: 100)", false) +} + +// newStringLengthTool creates a tool that returns string length +func newStringLengthTool() dsgo.Tool { + stringLength := func(ctx context.Context, args map[string]any) (any, error) { + text, ok := args["text"].(string) + if !ok { + return nil, fmt.Errorf("parameter 'text' is required") + } + + return map[string]any{ + "length": len(text), + "rune_count": len([]rune(text)), + "word_count": len(strings.Fields(text)), + "line_count": strings.Count(text, "\n") + 1, + "is_empty": len(strings.TrimSpace(text)) == 0, + }, nil + } + + return *dsgo.NewTool("string_length", "Get length and character count of a string", stringLength). 
+ AddParameter("text", "string", "The text to analyze", true) +} + +// newWordCountTool creates a tool that counts words +func newWordCountTool() dsgo.Tool { + wordCount := func(ctx context.Context, args map[string]any) (any, error) { + text, ok := args["text"].(string) + if !ok { + return nil, fmt.Errorf("parameter 'text' is required") + } + + words := strings.Fields(text) + wordFreq := make(map[string]int) + for _, word := range words { + word = strings.ToLower(strings.Trim(word, ".,!?;:\"'()[]{}")) + if word != "" { + wordFreq[word]++ + } + } + + return map[string]any{ + "total_words": len(words), + "unique_words": len(wordFreq), + "characters": len(text), + "average_word_length": func() float64 { + if len(words) == 0 { + return 0 + } + total := 0 + for _, w := range words { + total += len(w) + } + return float64(total) / float64(len(words)) + }(), + }, nil + } + + return *dsgo.NewTool("word_count", "Count words and analyze text statistics", wordCount). + AddParameter("text", "string", "The text to analyze", true) +} + +// newEnvironmentInfoTool creates a tool that returns environment information +func newEnvironmentInfoTool() dsgo.Tool { + environmentInfo := func(ctx context.Context, args map[string]any) (any, error) { + cwd, _ := os.Getwd() + hostname, _ := os.Hostname() + + info := map[string]any{ + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "hostname": hostname, + "working_dir": cwd, + "go_version": runtime.Version(), + "num_cpu": runtime.NumCPU(), + "user": os.Getenv("USER"), + "home": os.Getenv("HOME"), + } + + if varName, ok := args["var"].(string); ok && varName != "" { + info["requested_var"] = os.Getenv(varName) + } + + return info, nil + } + + return *dsgo.NewTool("environment_info", "Get information about the current environment", environmentInfo). 
+ AddParameter("var", "string", "Optional: specific environment variable to retrieve", false) +} diff --git a/internal/mcp/client.go b/internal/mcp/client.go index d084f54..126e6f6 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -200,7 +200,21 @@ func (c *Client) CallTool(ctx context.Context, name string, args map[string]any) } if result.IsError { - return "", fmt.Errorf("tool returned error status") + // Many MCP servers return error details in the content array. + // Preserve that information so callers can see the real failure cause. + var errText string + for _, content := range result.Content { + if content.Type == "text" && content.Text != "" { + if errText != "" { + errText += "\n" + } + errText += content.Text + } + } + if errText == "" { + errText = "tool returned error status" + } + return "", fmt.Errorf("tool returned error status: %s", errText) } // Combine text content diff --git a/internal/mcp/local_transport.go b/internal/mcp/local_transport.go new file mode 100644 index 0000000..d05aca8 --- /dev/null +++ b/internal/mcp/local_transport.go @@ -0,0 +1,39 @@ +package mcp + +import ( + "context" + "fmt" +) + +// LocalHandler handles MCP JSON-RPC requests in-process. +// +// It enables building built-in MCP servers without HTTP/SSE/Stdio transports. +// LocalHandler implementations must be thread-safe. +type LocalHandler interface { + Handle(ctx context.Context, req *JSONRPCRequest) (*JSONRPCResponse, error) +} + +// LocalTransport is an MCP transport that routes requests to an in-process handler. +// +// It is useful for built-in tools like a safe shell runner. +type LocalTransport struct { + handler LocalHandler +} + +// NewLocalTransport creates a new LocalTransport. +// +// Returns an error if handler is nil. 
+func NewLocalTransport(handler LocalHandler) (*LocalTransport, error) { + if handler == nil { + return nil, fmt.Errorf("handler cannot be nil") + } + return &LocalTransport{handler: handler}, nil +} + +// Send routes the request to the local handler. +func (t *LocalTransport) Send(ctx context.Context, request *JSONRPCRequest) (*JSONRPCResponse, error) { + return t.handler.Handle(ctx, request) +} + +// Close closes the transport. +func (t *LocalTransport) Close() error { return nil } diff --git a/internal/mcp/local_transport_test.go b/internal/mcp/local_transport_test.go new file mode 100644 index 0000000..41668d9 --- /dev/null +++ b/internal/mcp/local_transport_test.go @@ -0,0 +1,35 @@ +package mcp + +import ( + "context" + "encoding/json" + "testing" +) + +type testHandler struct{} + +func (h testHandler) Handle(ctx context.Context, req *JSONRPCRequest) (*JSONRPCResponse, error) { + payload, _ := json.Marshal(map[string]any{"ok": true, "method": req.Method}) + return &JSONRPCResponse{JSONRPC: "2.0", ID: req.ID, Result: payload}, nil +} + +func TestLocalTransport_Send(t *testing.T) { + tr, err := NewLocalTransport(testHandler{}) + if err != nil { + t.Fatalf("NewLocalTransport() error: %v", err) + } + resp, err := tr.Send(context.Background(), &JSONRPCRequest{JSONRPC: "2.0", ID: "1", Method: "tools/list"}) + if err != nil { + t.Fatalf("Send() error: %v", err) + } + if resp == nil || resp.Result == nil { + t.Fatalf("expected response result") + } +} + +func TestLocalTransport_NilHandler(t *testing.T) { + _, err := NewLocalTransport(nil) + if err == nil { + t.Fatal("expected error for nil handler") + } +} diff --git a/internal/mcp/shell_server.go b/internal/mcp/shell_server.go new file mode 100644 index 0000000..2f64c70 --- /dev/null +++ b/internal/mcp/shell_server.go @@ -0,0 +1,658 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "time" +) + +const ( + 
defaultShellTimeout = 10 * time.Minute + defaultShellMaxOutput = 256 * 1024 + defaultShellMaxPatch = 512 * 1024 +) + +// ShellServerConfig configures the built-in MCP shell server. +// +// RootDir is used to constrain working directories and patch application. +// Note: Symlinks within RootDir may point outside it; this is not a security boundary. +// DefaultTimeout is used when callers do not provide timeout_seconds. +// MaxOutputBytes bounds stdout/stderr capture for shell_run. +// MaxPatchBytes bounds the size of patches accepted by apply_patch. +type ShellServerConfig struct { + RootDir string + DefaultTimeout time.Duration + MaxOutputBytes int + MaxPatchBytes int +} + +// ShellServer is a built-in MCP server that exposes safe tools: +// - shell_run (whitelisted command runner) +// - apply_patch (git apply for unified diffs) +// +// SECURITY NOTE: ShellServer is designed for trusted repositories only. +// It is NOT a security sandbox. Commands like make/go test can run arbitrary +// code defined in the repo, and symlinks in the repo may point outside RootDir. +// +// It implements LocalHandler. +type ShellServer struct { + rootDir string + defaultTimeout time.Duration + maxOutputBytes int + maxPatchBytes int +} + +// NewShellServer creates a new ShellServer. 
+func NewShellServer(cfg ShellServerConfig) (*ShellServer, error) { + root := cfg.RootDir + if root == "" { + var err error + root, err = findRepoRoot() + if err != nil { + return nil, err + } + } + root, err := filepath.Abs(root) + if err != nil { + return nil, fmt.Errorf("failed to resolve root dir: %w", err) + } + + if cfg.DefaultTimeout <= 0 { + cfg.DefaultTimeout = defaultShellTimeout + } + if cfg.MaxOutputBytes <= 0 { + cfg.MaxOutputBytes = defaultShellMaxOutput + } + if cfg.MaxPatchBytes <= 0 { + cfg.MaxPatchBytes = defaultShellMaxPatch + } + + return &ShellServer{ + rootDir: root, + defaultTimeout: cfg.DefaultTimeout, + maxOutputBytes: cfg.MaxOutputBytes, + maxPatchBytes: cfg.MaxPatchBytes, + }, nil +} + +func (s *ShellServer) Handle(ctx context.Context, req *JSONRPCRequest) (*JSONRPCResponse, error) { + switch req.Method { + case "initialize": + return s.handleInitialize(req.ID) + case "tools/list": + return s.handleToolsList(req.ID) + case "tools/call": + return s.handleToolsCall(ctx, req.ID, req.Params) + default: + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Error: &JSONRPCError{ + Code: ErrCodeMethodNotFound, + Message: "method not found", + }, + }, nil + } +} + +func (s *ShellServer) handleInitialize(id any) (*JSONRPCResponse, error) { + result := MCPInitializeResult{ + ProtocolVersion: "2024-11-05", + Capabilities: map[string]any{}, + ServerInfo: map[string]string{ + "name": "dsgo-shell", + "version": "0.1.0", + }, + } + b, err := json.Marshal(result) + if err != nil { + return nil, err + } + return &JSONRPCResponse{JSONRPC: "2.0", ID: id, Result: b}, nil +} + +func (s *ShellServer) handleToolsList(id any) (*JSONRPCResponse, error) { + schemas := []MCPToolSchema{ + { + Name: "shell_run", + Description: "Run a whitelisted shell command in the repository (make/go test/git read-only)", + InputSchema: MCPInputSchema{ + Type: "object", + Properties: map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "Command 
to run (e.g. make, go, git)", + }, + "args": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "string", + }, + "description": "Arguments for the command (e.g. [\"test\", \"./...\"]) ", + }, + "dir": map[string]any{ + "type": "string", + "description": "Working directory relative to repo root", + }, + "timeout_seconds": map[string]any{ + "type": "integer", + "description": "Optional timeout (seconds)", + }, + }, + Required: []string{"command"}, + }, + }, + { + Name: "apply_patch", + Description: "Apply a unified diff patch to the repository using git apply", + InputSchema: MCPInputSchema{ + Type: "object", + Properties: map[string]any{ + "patch": map[string]any{ + "type": "string", + "description": "Unified diff (git apply compatible)", + }, + }, + Required: []string{"patch"}, + }, + }, + } + + payload := MCPListToolsResult{Tools: schemas} + b, err := json.Marshal(payload) + if err != nil { + return nil, err + } + return &JSONRPCResponse{JSONRPC: "2.0", ID: id, Result: b}, nil +} + +func (s *ShellServer) handleToolsCall(ctx context.Context, id any, rawParams json.RawMessage) (*JSONRPCResponse, error) { + var params struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } + if err := json.Unmarshal(rawParams, ¶ms); err != nil { + return s.toolResultError(id, fmt.Sprintf("invalid params: %v", err)) + } + + switch params.Name { + case "shell_run": + return s.handleShellRun(ctx, id, params.Arguments) + case "apply_patch": + return s.handleApplyPatch(ctx, id, params.Arguments) + default: + return s.toolResultError(id, fmt.Sprintf("unknown tool: %s", params.Name)) + } +} + +func (s *ShellServer) handleShellRun(ctx context.Context, id any, args map[string]any) (*JSONRPCResponse, error) { + cmdName, _ := args["command"].(string) + if strings.TrimSpace(cmdName) == "" { + return s.toolResultError(id, "command is required") + } + + cmdArgs, ok := decodeStringArray(args["args"]) + if !ok { + return s.toolResultError(id, "args 
must be an array of strings") + } + if strings.Contains(cmdName, " ") { + parts := strings.Fields(cmdName) + cmdName = parts[0] + cmdArgs = append(parts[1:], cmdArgs...) + } + + dirArg, _ := args["dir"].(string) + workDir, err := s.resolveWorkDir(dirArg) + if err != nil { + return s.toolResultError(id, err.Error()) + } + + timeout := s.defaultTimeout + if raw, ok := args["timeout_seconds"]; ok { + sec, err := toInt(raw) + if err != nil { + return s.toolResultError(id, "timeout_seconds must be an integer") + } + if sec > 0 { + timeout = time.Duration(sec) * time.Second + } + } + + cmdName, cmdArgs, err = normalizeAndValidateCommand(cmdName, cmdArgs) + if err != nil { + return s.toolResultError(id, err.Error()) + } + + start := time.Now() + runCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + cmd := exec.CommandContext(runCtx, cmdName, cmdArgs...) + cmd.Dir = workDir + + var stdoutBuf, stderrBuf limitedBuffer + stdoutBuf.MaxBytes = s.maxOutputBytes + stderrBuf.MaxBytes = s.maxOutputBytes + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + err = cmd.Run() + duration := time.Since(start) + + exitCode := 0 + errorText := "" + timedOut := false + if err != nil { + exitCode = exitCodeFromErr(err) + errorText = err.Error() + timedOut = errors.Is(err, context.DeadlineExceeded) || errors.Is(runCtx.Err(), context.DeadlineExceeded) + } + + result := map[string]any{ + "command": cmdName, + "args": cmdArgs, + "dir": workDir, + "exit_code": exitCode, + "stdout": stdoutBuf.String(), + "stderr": stderrBuf.String(), + "stdout_trunc": stdoutBuf.Truncated, + "stderr_trunc": stderrBuf.Truncated, + "duration_ms": duration.Milliseconds(), + "timeout_seconds": int(timeout.Seconds()), + "timed_out": timedOut, + "error": errorText, + "goos": runtime.GOOS, + "goarch": runtime.GOARCH, + } + b, _ := json.MarshalIndent(result, "", " ") + return s.toolResultText(id, string(b)) +} + +func (s *ShellServer) handleApplyPatch(ctx context.Context, id any, args map[string]any) 
(*JSONRPCResponse, error) { + patch, _ := args["patch"].(string) + if len(patch) > s.maxPatchBytes { + return s.toolResultError(id, "patch too large") + } + if strings.TrimSpace(patch) == "" { + return s.toolResultError(id, "patch is required") + } + + if err := validateUnifiedDiffPaths(patch); err != nil { + return s.toolResultError(id, err.Error()) + } + + cmd := exec.CommandContext(ctx, "git", "apply", "--whitespace=nowarn", "--") + cmd.Dir = s.rootDir + cmd.Stdin = strings.NewReader(patch) + + var stderr limitedBuffer + stderr.MaxBytes = s.maxOutputBytes + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + msg := "git apply failed" + if stderr.String() != "" { + msg += ": " + stderr.String() + } + return s.toolResultError(id, msg) + } + + return s.toolResultText(id, "patch applied") +} + +func (s *ShellServer) toolResultText(id any, text string) (*JSONRPCResponse, error) { + result := MCPCallToolResult{Content: []MCPContent{{Type: "text", Text: text}}} + b, err := json.Marshal(result) + if err != nil { + return nil, err + } + return &JSONRPCResponse{JSONRPC: "2.0", ID: id, Result: b}, nil +} + +func (s *ShellServer) toolResultError(id any, msg string) (*JSONRPCResponse, error) { + result := MCPCallToolResult{Content: []MCPContent{{Type: "text", Text: msg}}, IsError: true} + b, err := json.Marshal(result) + if err != nil { + return nil, err + } + return &JSONRPCResponse{JSONRPC: "2.0", ID: id, Result: b}, nil +} + +func (s *ShellServer) resolveWorkDir(dirArg string) (string, error) { + if strings.TrimSpace(dirArg) == "" { + return s.rootDir, nil + } + candidate := dirArg + if !filepath.IsAbs(candidate) { + candidate = filepath.Join(s.rootDir, candidate) + } + candidate = filepath.Clean(candidate) + candidate, err := filepath.Abs(candidate) + if err != nil { + return "", fmt.Errorf("invalid dir: %w", err) + } + + if !strings.HasPrefix(candidate, s.rootDir+string(filepath.Separator)) && candidate != s.rootDir { + return "", fmt.Errorf("dir %q is outside 
repo root", dirArg) + } + return candidate, nil +} + +func findRepoRoot() (string, error) { + cwd, err := os.Getwd() + if err != nil { + return "", err + } + dir := cwd + for { + if st, err := os.Stat(filepath.Join(dir, ".git")); err == nil && st.IsDir() { + if isValidRepoRoot(dir) { + return dir, nil + } + } + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + if isValidRepoRoot(dir) { + return dir, nil + } + } + parent := filepath.Dir(dir) + if parent == dir { + return "", fmt.Errorf("failed to find repo root from %q", cwd) + } + dir = parent + } +} + +func isValidRepoRoot(dir string) bool { + if dir == "/" || dir == "" { + return false + } + if runtime.GOOS == "windows" { + if len(dir) <= 3 { + return false + } + } else { + parts := strings.Split(strings.Trim(dir, "/"), "/") + if len(parts) < 2 { + return false + } + } + systemDirs := []string{"/bin", "/etc", "/lib", "/sbin", "/usr", "/var", "/tmp", "/dev", "/proc", "/sys"} + for _, sys := range systemDirs { + if dir == sys || strings.HasPrefix(dir, sys+"/") { + return false + } + } + return true +} + +type limitedBuffer struct { + MaxBytes int + buf bytes.Buffer + Truncated bool +} + +func (b *limitedBuffer) Write(p []byte) (int, error) { + if b.MaxBytes <= 0 { + return b.buf.Write(p) + } + remaining := b.MaxBytes - b.buf.Len() + if remaining <= 0 { + b.Truncated = true + return len(p), nil + } + if len(p) > remaining { + _, _ = b.buf.Write(p[:remaining]) + b.Truncated = true + return len(p), nil + } + return b.buf.Write(p) +} + +func (b *limitedBuffer) String() string { return b.buf.String() } + +func decodeStringArray(v any) ([]string, bool) { + if v == nil { + return nil, true + } + switch arr := v.(type) { + case []any: + out := make([]string, 0, len(arr)) + for _, item := range arr { + s, ok := item.(string) + if !ok { + return nil, false + } + out = append(out, s) + } + return out, true + case []string: + return append([]string(nil), arr...), true + default: + return nil, false + } +} + 
+func toInt(v any) (int, error) { + switch t := v.(type) { + case int: + return t, nil + case int64: + return int(t), nil + case float64: + return int(t), nil + case string: + i, err := strconv.Atoi(t) + if err != nil { + return 0, err + } + return i, nil + default: + return 0, errors.New("not an int") + } +} + +func exitCodeFromErr(err error) int { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return exitErr.ExitCode() + } + if errors.Is(err, context.DeadlineExceeded) { + return 124 + } + return 1 +} + +func normalizeAndValidateCommand(command string, args []string) (string, []string, error) { + command = strings.TrimSpace(command) + if command == "" { + return "", nil, fmt.Errorf("command is empty") + } + + switch command { + case "make": + if err := validateMakeArgs(args); err != nil { + return "", nil, err + } + return command, args, nil + case "go": + if err := validateGoArgs(args); err != nil { + return "", nil, err + } + return command, args, nil + case "git": + if err := validateGitArgs(args); err != nil { + return "", nil, err + } + return command, args, nil + default: + return "", nil, fmt.Errorf("command not allowed: %s", command) + } +} + +func validateMakeArgs(args []string) error { + for i := 0; i < len(args); i++ { + arg := args[i] + if arg == "-f" || strings.HasPrefix(arg, "-f") || arg == "--file" || strings.HasPrefix(arg, "--file") { + return fmt.Errorf("make file flags are not allowed") + } + if arg == "-C" || strings.HasPrefix(arg, "-C") || arg == "--directory" || strings.HasPrefix(arg, "--directory") { + return fmt.Errorf("make directory flags are not allowed") + } + if strings.Contains(arg, "=") { + return fmt.Errorf("make variable assignments are not allowed") + } + } + return nil +} + +func validateGoArgs(args []string) error { + if len(args) == 0 { + return fmt.Errorf("go subcommand is required") + } + if args[0] != "test" { + return fmt.Errorf("go only allows 'test'") + } + for i := 1; i < len(args); i++ { + arg := args[i] 
+ switch { + case arg == "-c" || strings.HasPrefix(arg, "-c="): + return fmt.Errorf("go test -c is not allowed") + case arg == "-o" || strings.HasPrefix(arg, "-o="): + return fmt.Errorf("go test output flags are not allowed") + case arg == "-coverprofile" || strings.HasPrefix(arg, "-coverprofile="): + return fmt.Errorf("go test coverprofile is not allowed") + case arg == "-cpuprofile" || strings.HasPrefix(arg, "-cpuprofile="): + return fmt.Errorf("go test cpuprofile is not allowed") + case arg == "-memprofile" || strings.HasPrefix(arg, "-memprofile="): + return fmt.Errorf("go test memprofile is not allowed") + case arg == "-trace" || strings.HasPrefix(arg, "-trace="): + return fmt.Errorf("go test trace is not allowed") + } + } + return nil +} + +func validateGitArgs(args []string) error { + if len(args) == 0 { + return fmt.Errorf("git subcommand is required") + } + + // Allow a small set of safe global options, then require a read-only subcommand. + for i := 0; i < len(args); i++ { + arg := args[i] + + if arg == "--" { + args = args[i+1:] + break + } + + if !strings.HasPrefix(arg, "-") { + args = args[i:] + break + } + + switch { + case arg == "--no-pager": + continue + case arg == "-c" || strings.HasPrefix(arg, "-c") || arg == "--config": + return fmt.Errorf("git config options are not allowed") + case arg == "-C" || strings.HasPrefix(arg, "-C"): + return fmt.Errorf("git -C is not allowed") + case arg == "--git-dir" || strings.HasPrefix(arg, "--git-dir"): + return fmt.Errorf("git-dir overrides are not allowed") + case arg == "--work-tree" || strings.HasPrefix(arg, "--work-tree"): + return fmt.Errorf("work-tree overrides are not allowed") + case arg == "-o" || strings.HasPrefix(arg, "-o") || arg == "--output" || strings.HasPrefix(arg, "--output"): + return fmt.Errorf("git output flags are not allowed") + default: + return fmt.Errorf("git global option not allowed: %s", arg) + } + } + + if len(args) == 0 { + return fmt.Errorf("git subcommand missing") + } + + sub := 
args[0] + allowed := map[string]bool{ + "status": true, + "diff": true, + "log": true, + "show": true, + "ls-files": true, + "rev-parse": true, + } + if !allowed[sub] { + return fmt.Errorf("git subcommand not allowed: %s", sub) + } + return nil +} + +func validateUnifiedDiffPaths(patch string) error { + // Validate common diff headers. + // Allow /dev/null and a/ b/ prefixes. + lines := strings.Split(patch, "\n") + for _, line := range lines { + var path string + switch { + case strings.HasPrefix(line, "diff --git "): + parts := strings.Fields(line) + if len(parts) >= 4 { + // diff --git a/foo b/foo + if err := validateDiffPath(parts[2]); err != nil { + return err + } + if err := validateDiffPath(parts[3]); err != nil { + return err + } + } + continue + case strings.HasPrefix(line, "--- "): + path = strings.TrimSpace(strings.TrimPrefix(line, "--- ")) + case strings.HasPrefix(line, "+++ "): + path = strings.TrimSpace(strings.TrimPrefix(line, "+++ ")) + default: + continue + } + if path == "/dev/null" { + continue + } + if err := validateDiffPath(path); err != nil { + return err + } + } + return nil +} + +func validateDiffPath(p string) error { + if p == "" { + return nil + } + if strings.HasPrefix(p, "a/") || strings.HasPrefix(p, "b/") { + p = p[2:] + } + p = filepath.Clean(p) + if filepath.IsAbs(p) { + return fmt.Errorf("absolute paths are not allowed in patches") + } + if strings.HasPrefix(p, ".."+string(filepath.Separator)) || p == ".." 
{ + return fmt.Errorf("parent traversal paths are not allowed in patches") + } + if strings.Contains(p, ".."+string(filepath.Separator)) { + return fmt.Errorf("parent traversal paths are not allowed in patches") + } + if strings.HasPrefix(p, ".git"+string(filepath.Separator)) || p == ".git" { + return fmt.Errorf("modifying .git directory is not allowed") + } + return nil +} diff --git a/internal/mcp/shell_server_test.go b/internal/mcp/shell_server_test.go new file mode 100644 index 0000000..f945f2b --- /dev/null +++ b/internal/mcp/shell_server_test.go @@ -0,0 +1,137 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestShellServer_ToolsList(t *testing.T) { + server, err := NewShellServer(ShellServerConfig{RootDir: t.TempDir()}) + if err != nil { + t.Fatalf("NewShellServer() error: %v", err) + } + + resp, err := server.Handle(context.Background(), &JSONRPCRequest{JSONRPC: "2.0", ID: "1", Method: "tools/list"}) + if err != nil { + t.Fatalf("Handle() error: %v", err) + } + if resp == nil || resp.Result == nil { + t.Fatalf("expected result") + } + var list MCPListToolsResult + if err := json.Unmarshal(resp.Result, &list); err != nil { + t.Fatalf("unmarshal tools/list: %v", err) + } + if len(list.Tools) < 2 { + t.Fatalf("expected at least 2 tools, got %d", len(list.Tools)) + } +} + +func TestShellServer_ShellRun_DeniesMakeFileFlag(t *testing.T) { + server, err := NewShellServer(ShellServerConfig{RootDir: t.TempDir()}) + if err != nil { + t.Fatalf("NewShellServer() error: %v", err) + } + + params, _ := json.Marshal(map[string]any{ + "name": "shell_run", + "arguments": map[string]any{ + "command": "make", + "args": []any{"-f", "Makefile"}, + }, + }) + resp, err := server.Handle(context.Background(), &JSONRPCRequest{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: params}) + if err != nil { + t.Fatalf("Handle() error: %v", err) + } + var result MCPCallToolResult + if err := 
json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if !result.IsError { + t.Fatalf("expected isError") + } +} + +func TestShellServer_ApplyPatch_RejectsTraversal(t *testing.T) { + server, err := NewShellServer(ShellServerConfig{RootDir: t.TempDir()}) + if err != nil { + t.Fatalf("NewShellServer() error: %v", err) + } + + patch := "diff --git a/../x b/../x\nnew file mode 100644\nindex 0000000..e69de29\n--- /dev/null\n+++ b/../x\n@@ -0,0 +1 @@\n+hi\n" + params, _ := json.Marshal(map[string]any{ + "name": "apply_patch", + "arguments": map[string]any{ + "patch": patch, + }, + }) + resp, err := server.Handle(context.Background(), &JSONRPCRequest{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: params}) + if err != nil { + t.Fatalf("Handle() error: %v", err) + } + var result MCPCallToolResult + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if !result.IsError { + t.Fatalf("expected isError") + } +} + +func TestShellServer_ApplyPatch_AllowsNewFileInGitRepo(t *testing.T) { + root := t.TempDir() + + if err := runGit(t, root, "init"); err != nil { + t.Fatalf("git init: %v", err) + } + + server, err := NewShellServer(ShellServerConfig{RootDir: root}) + if err != nil { + t.Fatalf("NewShellServer() error: %v", err) + } + + patch := "diff --git a/newfile.txt b/newfile.txt\nnew file mode 100644\nindex 0000000..9daeafb\n--- /dev/null\n+++ b/newfile.txt\n@@ -0,0 +1 @@\n+hello\n" + params, _ := json.Marshal(map[string]any{ + "name": "apply_patch", + "arguments": map[string]any{ + "patch": patch, + }, + }) + resp, err := server.Handle(context.Background(), &JSONRPCRequest{JSONRPC: "2.0", ID: "1", Method: "tools/call", Params: params}) + if err != nil { + t.Fatalf("Handle() error: %v", err) + } + var result MCPCallToolResult + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal result: %v", err) + } + if result.IsError { + t.Fatalf("expected 
success, got error: %s", result.Content[0].Text) + } + + if _, err := os.Stat(filepath.Join(root, "newfile.txt")); err != nil { + t.Fatalf("expected newfile.txt to exist: %v", err) + } +} + +func runGit(t *testing.T, dir string, args ...string) error { + t.Helper() + cmd := execCommand("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("git %v failed: %v: %s", args, err, string(out)) + } + return nil +} + +// execCommand exists to avoid linter complaints about shadowing. +func execCommand(name string, args ...string) *exec.Cmd { + return exec.Command(name, args...) +} diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go index a639cb6..093386b 100644 --- a/internal/mcp/transport.go +++ b/internal/mcp/transport.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "net/url" + "os" "os/exec" "strings" "sync" @@ -33,11 +34,22 @@ type HTTPTransport struct { sessionID string } +const defaultHTTPTimeout = 300 * time.Second + // NewHTTPTransport creates a new HTTPTransport. func NewHTTPTransport(url string, apiKey string) *HTTPTransport { + return NewHTTPTransportWithTimeout(url, apiKey, defaultHTTPTimeout) +} + +// NewHTTPTransportWithTimeout creates a new HTTPTransport with a custom client timeout. +// If timeout <= 0, the default timeout is used. 
+func NewHTTPTransportWithTimeout(url string, apiKey string, timeout time.Duration) *HTTPTransport { + if timeout <= 0 { + timeout = defaultHTTPTimeout + } return &HTTPTransport{ url: url, - client: &http.Client{Timeout: 300 * time.Second}, + client: &http.Client{Timeout: timeout}, apiKey: apiKey, sessionID: fmt.Sprintf("sess_%d", time.Now().UnixNano()), } @@ -55,6 +67,11 @@ func (t *HTTPTransport) Send(ctx context.Context, request *JSONRPCRequest) (*JSO var lastErr error for attempt := 1; attempt <= maxAttempts; attempt++ { + // Short-circuit if context is already cancelled + if ctx.Err() != nil { + return nil, ctx.Err() + } + req, err := http.NewRequestWithContext(ctx, "POST", t.url, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -172,10 +189,16 @@ func decodeJSONRPCResponse(contentType string, body io.Reader) (*JSONRPCResponse // SSETransport implements Transport over SSE. type SSETransport struct { - sseURL string - postURL string - client *http.Client - apiKey string + sseURL string + postURL string + client *http.Client + apiKey string + + // Timeouts for the POST side of SSE MCP. + // These are intentionally separate from caller context since SSE streams can be long-lived. + postTimeout time.Duration + responseTimeout time.Duration + running atomic.Bool mu sync.Mutex pending map[any]chan *JSONRPCResponse @@ -186,13 +209,28 @@ type SSETransport struct { // NewSSETransport creates a new SSETransport. func NewSSETransport(url string, apiKey string) *SSETransport { + return NewSSETransportWithTimeouts(url, apiKey, defaultHTTPTimeout, defaultHTTPTimeout) +} + +// NewSSETransportWithTimeouts creates a new SSETransport with custom POST and response timeouts. +// If either timeout <= 0, the default is used. 
+func NewSSETransportWithTimeouts(url string, apiKey string, postTimeout time.Duration, responseTimeout time.Duration) *SSETransport { + if postTimeout <= 0 { + postTimeout = defaultHTTPTimeout + } + if responseTimeout <= 0 { + responseTimeout = defaultHTTPTimeout + } + t := &SSETransport{ - sseURL: url, - client: &http.Client{Timeout: 0}, // No timeout for SSE stream - apiKey: apiKey, - pending: make(map[any]chan *JSONRPCResponse), - stopCh: make(chan struct{}), - initCh: make(chan struct{}), + sseURL: url, + client: &http.Client{Timeout: 0}, // No timeout for SSE stream + apiKey: apiKey, + postTimeout: postTimeout, + responseTimeout: responseTimeout, + pending: make(map[any]chan *JSONRPCResponse), + stopCh: make(chan struct{}), + initCh: make(chan struct{}), } t.running.Store(true) return t @@ -282,6 +320,7 @@ func (t *SSETransport) Send(ctx context.Context, request *JSONRPCRequest) (*JSON t.pending[request.ID] = respCh t.mu.Unlock() + // Ensure cleanup of pending map entry on all return paths (success, error, timeout) defer func() { t.mu.Lock() delete(t.pending, request.ID) @@ -293,9 +332,13 @@ func (t *SSETransport) Send(ctx context.Context, request *JSONRPCRequest) (*JSON return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Use a separate timeout for the POST request (300 seconds) - // Don't use the parent context directly as it may have a shorter timeout - postCtx, postCancel := context.WithTimeout(context.Background(), 300*time.Second) + // Use a separate timeout for the POST request. + // Don't use the parent context directly as it may have a shorter timeout. 
+ postTimeout := t.postTimeout + if postTimeout <= 0 { + postTimeout = defaultHTTPTimeout + } + postCtx, postCancel := context.WithTimeout(context.Background(), postTimeout) defer postCancel() req, err := http.NewRequestWithContext(postCtx, "POST", t.postURL, bytes.NewReader(body)) @@ -331,7 +374,12 @@ func (t *SSETransport) Send(ctx context.Context, request *JSONRPCRequest) (*JSON return nil, fmt.Errorf("received nil response - transport may have closed") } return resp, nil - case <-time.After(300 * time.Second): + case <-time.After(func() time.Duration { + if t.responseTimeout > 0 { + return t.responseTimeout + } + return defaultHTTPTimeout + }()): return nil, fmt.Errorf("timeout waiting for response to request %v", request.ID) } } @@ -382,7 +430,7 @@ func (t *SSETransport) readLoop(body io.ReadCloser) { u, err := url.Parse(t.sseURL) if err != nil { // Fallback to simple string manipulation if parsing fails - fmt.Printf("Warning: failed to parse SSE URL %q: %v\n", t.sseURL, err) + fmt.Fprintf(os.Stderr, "dsgo: warning: failed to parse SSE URL %q: %v\n", t.sseURL, err) } else { // Construct new URL t.postURL = u.Scheme + "://" + u.Host + endpoint diff --git a/internal/retry/retry.go b/internal/retry/retry.go index e67175f..9b570d8 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -58,6 +58,27 @@ func NewOptions(maxRetries int, initialBackoff, maxBackoff time.Duration, jitter } } +// MergeFrom applies non-zero values from the provided overrides to this Options. +// This allows partial configuration where callers only override specific fields. +// +// NOTE: This method mutates the receiver. It is NOT safe for concurrent use on +// a shared *Options. Callers should use it only on freshly created instances +// (e.g., from DefaultOptions() or Copy()) before sharing. 
+func (o *Options) MergeFrom(maxRetries int, initialBackoff, maxBackoff time.Duration, jitterFactor float64) { + if maxRetries > 0 { + o.MaxRetries = maxRetries + } + if initialBackoff > 0 { + o.InitialBackoff = initialBackoff + } + if maxBackoff > 0 { + o.MaxBackoff = maxBackoff + } + if jitterFactor > 0 { + o.JitterFactor = jitterFactor + } +} + func IsRetryable(statusCode int) bool { return statusCode == http.StatusTooManyRequests || // 429 statusCode == http.StatusInternalServerError || // 500 @@ -138,17 +159,16 @@ func WithExponentialBackoffOpts(ctx context.Context, fn HTTPFunc, opts *Options) return resp, nil } -func calculateBackoff(attempt int) time.Duration { - return calculateBackoffWithOpts(attempt, nil) +func calculateBackoffWithOpts(attempt int, opts *Options) time.Duration { + return calculateBackoff(attempt, opts) } -func calculateBackoffWithOpts(attempt int, opts *Options) time.Duration { +func calculateBackoff(attempt int, opts *Options) time.Duration { if opts == nil { opts = DefaultOptions() } backoff := float64(opts.InitialBackoff) * math.Pow(2, float64(attempt)) - if backoff > float64(opts.MaxBackoff) { backoff = float64(opts.MaxBackoff) } diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go index 7b43ce2..20f87b8 100644 --- a/internal/retry/retry_test.go +++ b/internal/retry/retry_test.go @@ -222,7 +222,7 @@ func TestCalculateBackoff(t *testing.T) { tt := tt t.Run("", func(t *testing.T) { t.Parallel() - backoff := calculateBackoff(tt.attempt) + backoff := calculateBackoff(tt.attempt, nil) if backoff < tt.minExpected || backoff > tt.maxExpected { t.Errorf("calculateBackoff(%d) = %v, want between %v and %v", tt.attempt, backoff, tt.minExpected, tt.maxExpected)