Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -1060,3 +1061,42 @@ func TestServerConfig_SkipQuarantine_JSONSerialization(t *testing.T) {
func boolPtr(b bool) *bool {
return &b
}

// --- T011: DataDir secret-ref expansion in LoadFromFile ---

// TestLoadConfig_ExpandsDataDir verifies that ${env:...} refs in data_dir are resolved
// before MkdirAll / Validate() run, so the database opens at the resolved path (US3).
func TestLoadConfig_ExpandsDataDir(t *testing.T) {
resolvedDir := t.TempDir()
t.Setenv("TEST_MCPPROXY_EXPAND_DATA_DIR", resolvedDir)

cfgFile := filepath.Join(t.TempDir(), "config.json")
cfgData := `{"data_dir": "${env:TEST_MCPPROXY_EXPAND_DATA_DIR}"}`
require.NoError(t, os.WriteFile(cfgFile, []byte(cfgData), 0600))

cfg, err := LoadFromFile(cfgFile)
require.NoError(t, err)
assert.Equal(t, resolvedDir, cfg.DataDir)
}

// TestLoadConfig_DataDirExpandFailure verifies that when the env var in data_dir is
// missing, LoadFromFile warns and retains the original unresolved reference rather
// than returning an error (US3 robustness requirement).
func TestLoadConfig_DataDirExpandFailure(t *testing.T) {
// Use a unique name that is almost certainly not set in any environment.
const missingVar = "TEST_MCPPROXY_MISSING_DATA_DIR_XYZ_9876"
os.Unsetenv(missingVar) //nolint:errcheck

tmpBase := t.TempDir()
cfgFile := filepath.Join(t.TempDir(), "config.json")
// DataDir contains an unresolvable ref; the literal path lives inside tmpBase
// so any directory MkdirAll creates is cleaned up automatically.
cfgData := fmt.Sprintf(`{"data_dir": "%s/${env:%s}"}`, tmpBase, missingVar)
require.NoError(t, os.WriteFile(cfgFile, []byte(cfgData), 0600))

// LoadFromFile must succeed even when expansion fails — warn + retain original.
cfg, err := LoadFromFile(cfgFile)
require.NoError(t, err)
assert.Contains(t, cfg.DataDir, fmt.Sprintf("${env:%s}", missingVar),
"original unresolved ref should be retained when expansion fails")
}
23 changes: 23 additions & 0 deletions internal/config/loader.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
Expand All @@ -10,6 +11,7 @@ import (
"strings"
"time"

"github.com/smart-mcp-proxy/mcpproxy-go/internal/secret"
"github.com/spf13/viper"
)

Expand Down Expand Up @@ -38,6 +40,9 @@ func LoadFromFile(configPath string) (*Config, error) {
cfg.DataDir = filepath.Join(homeDir, DefaultDataDir)
}

// Expand secret/env refs in DataDir before creating it
expandDataDir(cfg)

// Create data directory if it doesn't exist
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
Expand Down Expand Up @@ -123,6 +128,9 @@ func Load() (*Config, error) {
cfg.DataDir = filepath.Join(homeDir, DefaultDataDir)
}

// Expand secret/env refs in DataDir before creating it
expandDataDir(cfg)

// Create data directory if it doesn't exist
if err := os.MkdirAll(cfg.DataDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create data directory %s: %w", cfg.DataDir, err)
Expand Down Expand Up @@ -460,6 +468,21 @@ func SetRegistriesInitCallback(callback func(*Config)) {
registriesInitCallback = callback
}

// expandDataDir expands secret/env refs in cfg.DataDir in place.
// Failures are logged to stderr and the original value is kept.
func expandDataDir(cfg *Config) {
if cfg.DataDir == "" {
return
}
resolver := secret.NewResolver()
resolved, err := resolver.ExpandSecretRefs(context.Background(), cfg.DataDir)
if err != nil {
fmt.Fprintf(os.Stderr, "WARN: Failed to resolve secret ref in data_dir, using original value: reference=%s err=%v\n", cfg.DataDir, err)
return
}
cfg.DataDir = resolved
}

// applyTLSEnvOverrides applies environment variable overrides for TLS configuration
func applyTLSEnvOverrides(cfg *Config) {
// Ensure TLS config is initialized
Expand Down
8 changes: 4 additions & 4 deletions internal/config/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ func MergeServerConfig(base, patch *ServerConfig, opts MergeOptions) (*ServerCon
return nil, nil, fmt.Errorf("%w: both base and patch are nil", ErrInvalidConfig)
}
// Return a copy of patch
merged := copyServerConfig(patch)
merged := CopyServerConfig(patch)
merged.Updated = time.Now()
return merged, NewConfigDiff(), nil
}

if patch == nil {
// If no patch, return a copy of base
merged := copyServerConfig(base)
merged := CopyServerConfig(base)
return merged, nil, nil
}

Expand All @@ -189,7 +189,7 @@ func MergeServerConfig(base, patch *ServerConfig, opts MergeOptions) (*ServerCon
}

// Start with a copy of base
merged := copyServerConfig(base)
merged := CopyServerConfig(base)

// Track changes if requested
var diff *ConfigDiff
Expand Down Expand Up @@ -522,7 +522,7 @@ func MergeOAuthConfig(base, patch *OAuthConfig, removeIfNil bool) *OAuthConfig {

// Helper functions to copy configs (avoiding pointer aliasing)

func copyServerConfig(src *ServerConfig) *ServerConfig {
func CopyServerConfig(src *ServerConfig) *ServerConfig {
if src == nil {
return nil
}
Expand Down
114 changes: 114 additions & 0 deletions internal/secret/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,120 @@ func (r *Resolver) expandValue(ctx context.Context, v reflect.Value) error {
return nil
}

// SecretExpansionError records a failure to resolve a single secret reference during struct expansion.
type SecretExpansionError struct {
FieldPath string // e.g. "WorkingDir", "Isolation.WorkingDir", "Args[0]", "Env[MY_VAR]"
Reference string // the original unresolved reference pattern, e.g. "${env:HOME}"
Err error
}

// ExpandStructSecretsCollectErrors expands secret references in all string fields of v.
// Unlike ExpandStructSecrets, it does not fail fast: it collects all expansion errors and
// continues processing remaining fields. On error, the field retains its original value.
// v must be a non-nil pointer to a struct.
func (r *Resolver) ExpandStructSecretsCollectErrors(ctx context.Context, v interface{}) []SecretExpansionError {
var errs []SecretExpansionError
r.expandValueCollectErrors(ctx, reflect.ValueOf(v), "", &errs)
return errs
}

// expandValueCollectErrors mirrors expandValue but tracks field paths and collects errors
// instead of returning on the first failure. On resolution error the field is left unchanged.
func (r *Resolver) expandValueCollectErrors(ctx context.Context, v reflect.Value, path string, errs *[]SecretExpansionError) {
if !v.IsValid() {
return
}

// Handle pointers
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
r.expandValueCollectErrors(ctx, v.Elem(), path, errs)
return
}

switch v.Kind() {
case reflect.String:
if v.CanSet() {
original := v.String()
if IsSecretRef(original) {
expanded, err := r.ExpandSecretRefs(ctx, original)
if err != nil {
*errs = append(*errs, SecretExpansionError{
FieldPath: path,
Reference: original,
Err: err,
})
// retain original value on failure — do not call SetString
} else {
v.SetString(expanded)
}
}
}

case reflect.Struct:
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanInterface() {
continue
}
fieldType := t.Field(i)
if !fieldType.IsExported() {
continue
}
fieldName := fieldType.Name
newPath := fieldName
if path != "" {
newPath = path + "." + fieldName
}
r.expandValueCollectErrors(ctx, field, newPath, errs)
}

case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
newPath := fmt.Sprintf("%s[%d]", path, i)
r.expandValueCollectErrors(ctx, v.Index(i), newPath, errs)
}

case reflect.Map:
for _, key := range v.MapKeys() {
keyStr := fmt.Sprintf("%v", key.Interface())
newPath := fmt.Sprintf("%s[%s]", path, keyStr)
mapValue := v.MapIndex(key)
if mapValue.Kind() == reflect.String && IsSecretRef(mapValue.String()) {
original := mapValue.String()
expanded, err := r.ExpandSecretRefs(ctx, original)
if err != nil {
*errs = append(*errs, SecretExpansionError{
FieldPath: newPath,
Reference: original,
Err: err,
})
} else {
v.SetMapIndex(key, reflect.ValueOf(expanded))
}
} else if mapValue.Kind() == reflect.Interface {
actualValue := mapValue.Elem()
if actualValue.Kind() == reflect.String && IsSecretRef(actualValue.String()) {
original := actualValue.String()
expanded, err := r.ExpandSecretRefs(ctx, original)
if err != nil {
*errs = append(*errs, SecretExpansionError{
FieldPath: newPath,
Reference: original,
Err: err,
})
} else {
v.SetMapIndex(key, reflect.ValueOf(expanded))
}
}
}
}
}
}

// ExtractConfigSecrets extracts all secret and environment references from a config structure
func (r *Resolver) ExtractConfigSecrets(ctx context.Context, v interface{}) (*ConfigSecretsResponse, error) {
allRefs := []Ref{}
Expand Down
Loading
Loading