Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/smart-mcp-proxy/mcpproxy-go/internal/secureenv"
"os"
"strings"
"time"

"github.com/smart-mcp-proxy/mcpproxy-go/internal/secureenv"
)

const (
Expand Down Expand Up @@ -942,8 +944,10 @@ func (c *Config) ValidateDetailed() []ValidationError {
// ClientSecret can be a secret reference, so we don't validate it as empty.
}

// Validate DataDir exists (if specified and not empty)
if c.DataDir != "" {
// Validate DataDir exists (if specified and not empty).
// Skip validation if the path still contains unresolved ${...} refs —
// it will be resolved at a later point or the user will fix the env var.
if c.DataDir != "" && !strings.Contains(c.DataDir, "${") {
if _, err := os.Stat(c.DataDir); os.IsNotExist(err) {
errors = append(errors, ValidationError{
Field: "data_dir",
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ func TestLoadConfig_DataDirExpandFailure(t *testing.T) {
cfgFile := filepath.Join(t.TempDir(), "config.json")
// DataDir contains an unresolvable ref; the literal path lives inside tmpBase
// so any directory MkdirAll creates is cleaned up automatically.
cfgData := fmt.Sprintf(`{"data_dir": "%s/${env:%s}"}`, tmpBase, missingVar)
cfgData := fmt.Sprintf(`{"data_dir": "%s/${env:%s}"}`, filepath.ToSlash(tmpBase), missingVar)
require.NoError(t, os.WriteFile(cfgFile, []byte(cfgData), 0600))

// LoadFromFile must succeed even when expansion fails — warn + retain original.
Expand Down
20 changes: 14 additions & 6 deletions internal/config/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ func LoadFromFile(configPath string) (*Config, error) {
// Expand secret/env refs in DataDir before creating it
expandDataDir(cfg)

// Create data directory if it doesn't exist
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
// Create data directory if it doesn't exist.
// Skip if the path still contains unresolved ${...} refs (e.g., missing env var) —
// these are invalid path characters on Windows and the directory can't be created anyway.
if !strings.Contains(cfg.DataDir, "${") {
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
}
}

// Apply environment variable overrides for TLS configuration
Expand Down Expand Up @@ -131,9 +135,13 @@ func Load() (*Config, error) {
// Expand secret/env refs in DataDir before creating it
expandDataDir(cfg)

// Create data directory if it doesn't exist
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
// Create data directory if it doesn't exist.
// Skip if the path still contains unresolved ${...} refs (e.g., missing env var) —
// these are invalid path characters on Windows and the directory can't be created anyway.
if !strings.Contains(cfg.DataDir, "${") {
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
}
}

// Parse upstream servers from CLI
Expand Down
43 changes: 35 additions & 8 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,23 @@ func (p *MCPProxyServer) handleListRegistries(ctx context.Context, _ mcp.CallToo
return mcp.NewToolResultText(string(jsonResult)), nil
}

// handleRetrieveTools implements the retrieve_tools functionality
// handleRetrieveToolsForMode builds a retrieve_tools handler bound to a fixed routing mode.
// The returned closure forwards every call to handleRetrieveToolsWithMode together with the
// captured routingMode, so the same handler logic can be registered once per routing mode
// while still producing mode-specific usage_instructions.
func (p *MCPProxyServer) handleRetrieveToolsForMode(routingMode string) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	// routingMode is captured by the closure; it is fixed at registration time.
	boundHandler := func(callCtx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		return p.handleRetrieveToolsWithMode(callCtx, req, routingMode)
	}
	return boundHandler
}

// handleRetrieveTools is the retrieve_tools handler for callers that have no routing mode
// to supply. It exists for backward compatibility and simply delegates to
// handleRetrieveToolsWithMode with an empty mode.
func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	// An empty mode selects the default usage_instructions (the call_tool_read /
	// call_tool_write / call_tool_destructive guidance).
	const unspecifiedMode = ""
	return p.handleRetrieveToolsWithMode(ctx, request, unspecifiedMode)
}

// handleRetrieveToolsWithMode implements the retrieve_tools functionality with mode-aware instructions.
func (p *MCPProxyServer) handleRetrieveToolsWithMode(ctx context.Context, request mcp.CallToolRequest, routingMode string) (*mcp.CallToolResult, error) {
startTime := time.Now()

// Extract session info for activity logging (Spec 024)
Expand Down Expand Up @@ -980,17 +995,29 @@ func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.Ca
mcpTools = append(mcpTools, mcpTool)
}

response := map[string]interface{}{
"tools": mcpTools,
"query": query,
"total": len(results),
// Add usage instructions for intent-based tool calling (Spec 018)
"usage_instructions": "TOOL SELECTION GUIDE: Check the 'call_with' field for each tool, then use the matching tool variant. " +
// Build mode-aware usage instructions
var usageInstructions string
switch routingMode {
case config.RoutingModeCodeExecution:
usageInstructions = "TOOL CALLING GUIDE: Use the `code_execution` tool to call any discovered tool via JavaScript: " +
"call_tool(serverName, toolName, args). Example: call_tool('github', 'create_issue', {title: 'Bug fix'}). " +
"The 'call_with' field indicates each tool's permission tier (read/write/destructive) for your reference. " +
"Do NOT use call_tool_read, call_tool_write, or call_tool_destructive — they are not available in code execution mode."
default:
// Default instructions for retrieve_tools mode and backward-compatible callers
usageInstructions = "TOOL SELECTION GUIDE: Check the 'call_with' field for each tool, then use the matching tool variant. " +
"DECISION RULES BY TOOL NAME: " +
"(1) READ (call_tool_read): search, query, list, get, fetch, find, check, view, read, show, describe, lookup, retrieve, browse, explore, discover, scan, inspect, analyze, examine, validate, verify. DEFAULT choice when unsure. " +
"(2) WRITE (call_tool_write): create, update, modify, add, set, send, edit, change, write, post, put, patch, insert, upload, submit, assign, configure, enable, register, subscribe, publish, move, copy, rename, merge. " +
"(3) DESTRUCTIVE (call_tool_destructive): delete, remove, drop, revoke, disable, destroy, purge, reset, clear, unsubscribe, cancel, terminate, close, archive, ban, block, disconnect, kill, wipe, truncate, force, hard. " +
"INTENT TRACKING: Always provide intent_reason (why you're calling this tool) and intent_data_sensitivity (public/internal/private/unknown) to enable activity auditing.",
"INTENT TRACKING: Always provide intent_reason (why you're calling this tool) and intent_data_sensitivity (public/internal/private/unknown) to enable activity auditing."
}

response := map[string]interface{}{
"tools": mcpTools,
"query": query,
"total": len(results),
"usage_instructions": usageInstructions,
}

// Add debug information if requested
Expand Down
4 changes: 2 additions & 2 deletions internal/server/mcp_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (p *MCPProxyServer) buildCodeExecModeTools() []mcpserver.ServerTool {
)
tools = append(tools, mcpserver.ServerTool{
Tool: retrieveToolsTool,
Handler: p.handleRetrieveTools,
Handler: p.handleRetrieveToolsForMode(config.RoutingModeCodeExecution),
})

// Add management tools (upstream_servers, quarantine, registries)
Expand Down Expand Up @@ -273,7 +273,7 @@ func (p *MCPProxyServer) buildCallToolModeTools() []mcpserver.ServerTool {
)
tools = append(tools, mcpserver.ServerTool{
Tool: retrieveToolsTool,
Handler: p.handleRetrieveTools,
Handler: p.handleRetrieveToolsForMode(config.RoutingModeRetrieveTools),
})

// call_tool_read
Expand Down
154 changes: 154 additions & 0 deletions internal/server/mcp_routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
"encoding/json"
"testing"

"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -390,3 +391,156 @@ func TestDirectModeHandler_NoAuthContext(t *testing.T) {
handler(context.Background(), request)
}()
}

// TestRetrieveToolsInstructions_CodeExecutionMode verifies that handleRetrieveToolsWithMode
// returns code_execution-specific usage_instructions when called with RoutingModeCodeExecution.
func TestRetrieveToolsInstructions_CodeExecutionMode(t *testing.T) {
	proxy := createTestMCPProxyServer(t)

	request := mcp.CallToolRequest{}
	request.Params.Name = "retrieve_tools"
	request.Params.Arguments = map[string]interface{}{
		"query": "test query",
	}

	result, err := proxy.handleRetrieveToolsWithMode(context.Background(), request, config.RoutingModeCodeExecution)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.False(t, result.IsError)

	// Check the payload shape explicitly: an unchecked index or type assertion would
	// panic and abort the whole test binary instead of failing just this test.
	require.NotEmpty(t, result.Content, "response should carry at least one content item")
	textContent, ok := result.Content[0].(mcp.TextContent)
	require.True(t, ok, "first content item should be mcp.TextContent, got %T", result.Content[0])

	// Parse the response JSON to extract usage_instructions.
	var response map[string]interface{}
	err = json.Unmarshal([]byte(textContent.Text), &response)
	require.NoError(t, err, "response should be valid JSON")

	instructions, ok := response["usage_instructions"].(string)
	require.True(t, ok, "usage_instructions should be a string")

	// Code execution mode: should mention code_execution and call_tool().
	assert.Contains(t, instructions, "code_execution",
		"code_execution mode should mention 'code_execution' tool")
	assert.Contains(t, instructions, "call_tool(",
		"code_execution mode should mention call_tool() JavaScript function")

	// Code execution mode: should NOT recommend call_tool_read/write/destructive as tools to use.
	// The instructions may still name them inside a "Do NOT use" warning, which is acceptable.
	// What must be absent is the retrieve_tools-mode decision rules that tell the LLM
	// to use these as tool variants.
	assert.NotContains(t, instructions, "DECISION RULES BY TOOL NAME",
		"code_execution mode should NOT contain call_tool variant decision rules")
	assert.NotContains(t, instructions, "(1) READ (call_tool_read)",
		"code_execution mode should NOT recommend call_tool_read as a variant")
	assert.NotContains(t, instructions, "(2) WRITE (call_tool_write)",
		"code_execution mode should NOT recommend call_tool_write as a variant")
	assert.NotContains(t, instructions, "(3) DESTRUCTIVE (call_tool_destructive)",
		"code_execution mode should NOT recommend call_tool_destructive as a variant")
}

// TestRetrieveToolsInstructions_RetrieveToolsMode verifies that handleRetrieveToolsWithMode
// returns call_tool_*-specific usage_instructions when called with RoutingModeRetrieveTools.
func TestRetrieveToolsInstructions_RetrieveToolsMode(t *testing.T) {
	proxy := createTestMCPProxyServer(t)

	request := mcp.CallToolRequest{}
	request.Params.Name = "retrieve_tools"
	request.Params.Arguments = map[string]interface{}{
		"query": "test query",
	}

	result, err := proxy.handleRetrieveToolsWithMode(context.Background(), request, config.RoutingModeRetrieveTools)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.False(t, result.IsError)

	// Check the payload shape explicitly: an unchecked index or type assertion would
	// panic and abort the whole test binary instead of failing just this test.
	require.NotEmpty(t, result.Content, "response should carry at least one content item")
	textContent, ok := result.Content[0].(mcp.TextContent)
	require.True(t, ok, "first content item should be mcp.TextContent, got %T", result.Content[0])

	// Parse the response JSON to extract usage_instructions.
	var response map[string]interface{}
	err = json.Unmarshal([]byte(textContent.Text), &response)
	require.NoError(t, err, "response should be valid JSON")

	instructions, ok := response["usage_instructions"].(string)
	require.True(t, ok, "usage_instructions should be a string")

	// Retrieve tools mode: should mention call_tool_read/write/destructive.
	assert.Contains(t, instructions, "call_tool_read",
		"retrieve_tools mode should mention call_tool_read")
	assert.Contains(t, instructions, "call_tool_write",
		"retrieve_tools mode should mention call_tool_write")
	assert.Contains(t, instructions, "call_tool_destructive",
		"retrieve_tools mode should mention call_tool_destructive")
	assert.Contains(t, instructions, "INTENT TRACKING",
		"retrieve_tools mode should mention intent tracking")
}

// TestRetrieveToolsInstructions_DefaultMode verifies that handleRetrieveToolsWithMode
// with an empty routingMode returns the call_tool_* instructions used by retrieve_tools mode.
func TestRetrieveToolsInstructions_DefaultMode(t *testing.T) {
	proxy := createTestMCPProxyServer(t)

	request := mcp.CallToolRequest{}
	request.Params.Name = "retrieve_tools"
	request.Params.Arguments = map[string]interface{}{
		"query": "test query",
	}

	result, err := proxy.handleRetrieveToolsWithMode(context.Background(), request, "")
	require.NoError(t, err)
	require.NotNil(t, result)
	require.False(t, result.IsError)

	// Check the payload shape explicitly: an unchecked index or type assertion would
	// panic and abort the whole test binary instead of failing just this test.
	require.NotEmpty(t, result.Content, "response should carry at least one content item")
	textContent, ok := result.Content[0].(mcp.TextContent)
	require.True(t, ok, "first content item should be mcp.TextContent, got %T", result.Content[0])

	// Parse the response JSON to extract usage_instructions.
	var response map[string]interface{}
	err = json.Unmarshal([]byte(textContent.Text), &response)
	require.NoError(t, err, "response should be valid JSON")

	instructions, ok := response["usage_instructions"].(string)
	require.True(t, ok, "usage_instructions should be a string")

	// Default mode should use the same instructions as retrieve_tools mode.
	assert.Contains(t, instructions, "call_tool_read",
		"default mode should contain call_tool_read instructions")
	assert.Contains(t, instructions, "call_tool_write",
		"default mode should contain call_tool_write instructions")
}

// TestHandleRetrieveToolsForMode_ClosureReturnsDifferentInstructions verifies that
// handleRetrieveToolsForMode creates closures that produce different instructions per mode.
func TestHandleRetrieveToolsForMode_ClosureReturnsDifferentInstructions(t *testing.T) {
	proxy := createTestMCPProxyServer(t)

	request := mcp.CallToolRequest{}
	request.Params.Name = "retrieve_tools"
	request.Params.Arguments = map[string]interface{}{
		"query": "search for tools",
	}

	// instructionsFor obtains the handler for the given routing mode, invokes it with the
	// shared request, and returns the usage_instructions string from the JSON response.
	// Every step is checked so a malformed response fails the test with a clear message
	// rather than panicking on a nil result, an empty Content slice, or a failed type
	// assertion.
	instructionsFor := func(mode string) string {
		t.Helper()

		handler := proxy.handleRetrieveToolsForMode(mode)
		result, err := handler(context.Background(), request)
		require.NoError(t, err)
		require.NotNil(t, result)
		require.False(t, result.IsError)

		require.NotEmpty(t, result.Content, "response should carry at least one content item")
		textContent, ok := result.Content[0].(mcp.TextContent)
		require.True(t, ok, "first content item should be mcp.TextContent, got %T", result.Content[0])

		var response map[string]interface{}
		require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response), "response should be valid JSON")

		instructions, ok := response["usage_instructions"].(string)
		require.True(t, ok, "usage_instructions should be a string")
		return instructions
	}

	codeExecInstructions := instructionsFor(config.RoutingModeCodeExecution)
	retrieveInstructions := instructionsFor(config.RoutingModeRetrieveTools)

	// They should be different.
	assert.NotEqual(t, codeExecInstructions, retrieveInstructions,
		"code_execution and retrieve_tools modes should produce different usage_instructions")

	// Code exec should mention code_execution, retrieve should mention call_tool_read.
	assert.Contains(t, codeExecInstructions, "code_execution")
	assert.Contains(t, retrieveInstructions, "call_tool_read")
}
Loading