diff --git a/cmd/proxsave/upgrade.go b/cmd/proxsave/upgrade.go index 6a35f7b..0cc886f 100644 --- a/cmd/proxsave/upgrade.go +++ b/cmd/proxsave/upgrade.go @@ -159,6 +159,9 @@ func runUpgrade(ctx context.Context, args *cli.Args, bootstrap *logging.Bootstra bootstrap.Warning("Upgrade: configuration upgrade failed: %v", cfgUpgradeErr) } } + if sessionLogger != nil && cfgUpgradeResult != nil && len(cfgUpgradeResult.MissingKeys) > 0 { + sessionLogger.Info("Upgrade: configuration updated with %d missing key(s): %s", len(cfgUpgradeResult.MissingKeys), strings.Join(cfgUpgradeResult.MissingKeys, ", ")) + } // Refresh docs/symlinks/cron/identity (configuration upgrade is handled separately) logging.DebugStepBootstrap(bootstrap, "upgrade workflow", "refreshing docs and symlinks") diff --git a/internal/config/upgrade.go b/internal/config/upgrade.go index c2779c8..1b31963 100644 --- a/internal/config/upgrade.go +++ b/internal/config/upgrade.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "sort" "strings" "time" @@ -23,6 +24,11 @@ type envValue struct { comment string } +type keyRange struct { + start int + end int +} + // UpgradeResult describes the outcome of a configuration upgrade. type UpgradeResult struct { // BackupPath is the path of the backup created from the previous config. @@ -158,37 +164,38 @@ func computeConfigUpgrade(configPath string) (*UpgradeResult, string, []byte, er originalLines := strings.Split(normalizedOriginal, "\n") // 1. Collect user values: for each KEY we store all VALUE entries in order. - userValues, userKeyOrder, caseMap, caseConflicts, warnings, err := parseEnvValues(originalLines) + userValues, userKeyOrder, caseMap, caseConflicts, warnings, userRanges, err := parseEnvValues(originalLines) if err != nil { return result, "", originalContent, fmt.Errorf("failed to parse config %s: %w", configPath, err) } - // 2. Walk the template line-by-line, merging values. + // 2. Walk the template line-by-line and collect template entries. template := DefaultEnvTemplate() normalizedTemplate := strings.ReplaceAll(template, "\r\n", "\n") templateLines := strings.Split(normalizedTemplate, "\n") - templateKeys := make(map[string]bool) + type templateEntry struct { + key string + upper string + lines []string + index int + } + + templateEntries := make([]templateEntry, 0) templateKeyByUpper := make(map[string]string) - missingKeys := make([]string, 0) - newLines := make([]string, 0, len(templateLines)+len(userValues)) - processedUserKeys := make(map[string]bool) // Track which user keys (original case) have been used for i := 0; i < len(templateLines); i++ { line := templateLines[i] trimmed := strings.TrimSpace(line) if utils.IsComment(trimmed) { - newLines = append(newLines, line) continue } key, _, _, ok := splitKeyValueRaw(line) if !ok || key == "" { - newLines = append(newLines, line) continue } - templateKeys[key] = true upperKey := strings.ToUpper(key) if existing, ok := templateKeyByUpper[upperKey]; ok { if existing != key { @@ -198,120 +205,186 @@ func computeConfigUpgrade(configPath string) (*UpgradeResult, string, []byte, er templateKeyByUpper[upperKey] = key } - // Logic to find the user's values for this key. - // 1. Try exact match - targetUserKey := key - if _, ok := userValues[key]; !ok { - // 2. Try case-insensitive match - if mappedKey, ok := caseMap[strings.ToUpper(key)]; ok { - targetUserKey = mappedKey - } - } - - // Handle block values if blockValueKeys[upperKey] && trimmed == fmt.Sprintf("%s=\"", key) { blockEnd, err := findClosingQuoteLine(templateLines, i+1) if err != nil { return result, "", originalContent, fmt.Errorf("template %s block invalid: %w", key, err) } - - if values, ok := userValues[targetUserKey]; ok && len(values) > 0 { - processedUserKeys[targetUserKey] = true - for _, v := range values { - // Use TEMPLATE Key casing to enforce consistency - newLines = append(newLines, renderEnvValue(key, v)...) - } - } else { - missingKeys = append(missingKeys, key) - newLines = append(newLines, templateLines[i:blockEnd+1]...) - } - + templateEntries = append(templateEntries, templateEntry{ + key: key, + upper: upperKey, + lines: templateLines[i : blockEnd+1], + index: len(templateEntries), + }) i = blockEnd continue } + templateEntries = append(templateEntries, templateEntry{ + key: key, + upper: upperKey, + lines: []string{line}, + index: len(templateEntries), + }) + } - if values, ok := userValues[targetUserKey]; ok && len(values) > 0 { - processedUserKeys[targetUserKey] = true - for _, v := range values { - // Use TEMPLATE Key casing to enforce consistency - newLines = append(newLines, renderEnvValue(key, v)...) + // 3. Compute missing and extra keys. + missingKeys := make([]string, 0) + missingEntries := make([]templateEntry, 0) + for _, entry := range templateEntries { + targetUserKey := entry.key + if _, ok := userValues[entry.key]; !ok { + if mappedKey, ok := caseMap[entry.upper]; ok { + targetUserKey = mappedKey } - } else { - // Key missing in user config: keep template default and record it. - missingKeys = append(missingKeys, key) - newLines = append(newLines, line) } + if values, ok := userValues[targetUserKey]; ok && len(values) > 0 { + continue + } + missingKeys = append(missingKeys, entry.key) + missingEntries = append(missingEntries, entry) } - // 3. Append extra keys (present only in user config) in a dedicated section. extraKeys := make([]string, 0) - extraLines := make([]string, 0) - for _, key := range userKeyOrder { - if processedUserKeys[key] { + upperKey := strings.ToUpper(key) + if _, ok := templateKeyByUpper[upperKey]; !ok { + extraKeys = append(extraKeys, key) continue } - // If exact match was in template keys (should have been processed above), skip - if templateKeys[key] { + if caseConflicts[upperKey] && caseMap[upperKey] != key { + extraKeys = append(extraKeys, key) continue } - // If case-insensitive match was in template keys (should have been processed), skip - // Check by upper casing the user key and seeing if it exists in templateKeys? - // But wait, templateKeys stores exact keys. - // If user has "Backup_Enabled", and template has "BACKUP_ENABLED". - // We processed "BACKUP_ENABLED", found "Backup_Enabled" via caseMap, and marked "Backup_Enabled" as processed. - // So `processedUserKeys["Backup_Enabled"]` is true. We skip. - // Correct. + if templateKey, ok := templateKeyByUpper[upperKey]; ok && templateKey != key && !caseConflicts[upperKey] { + warnings = append(warnings, fmt.Sprintf("Key %q differs only by case from template key %q; preserved with original casing", key, templateKey)) + } + } - upperKey := strings.ToUpper(key) - if templateKey, ok := templateKeyByUpper[upperKey]; ok && templateKey != key { - if !caseConflicts[upperKey] { - warnings = append(warnings, fmt.Sprintf("Key %q differs only by case from template key %q; preserved as custom entry", key, templateKey)) - } + // Count preserved values + preserved := 0 + for key, values := range userValues { + if _, ok := templateKeyByUpper[strings.ToUpper(key)]; ok { + preserved += len(values) } + } - values := userValues[key] - if len(values) == 0 { - continue + // If nothing is missing, do not rewrite the file. + if len(missingKeys) == 0 { + result.Changed = false + result.Warnings = warnings + result.ExtraKeys = extraKeys + result.PreservedValues = preserved + return result, "", originalContent, nil + } + + type insertOp struct { + index int + lines []string + order int + } + + hasTrailingNewline := strings.HasSuffix(normalizedOriginal, "\n") + appendIndex := len(originalLines) + if hasTrailingNewline && len(originalLines) > 0 && originalLines[len(originalLines)-1] == "" { + appendIndex = len(originalLines) - 1 + } + normalizeInsertIndex := func(idx int) int { + if idx < 0 { + return 0 + } + if idx > appendIndex { + return appendIndex } - extraKeys = append(extraKeys, key) - for _, v := range values { - // Preserve USER's original key casing for extras - extraLines = append(extraLines, renderEnvValue(key, v)...) + if hasTrailingNewline && idx == len(originalLines) { + return appendIndex } + return idx } - if len(extraLines) > 0 { - newLines = append(newLines, - "", - "# ----------------------------------------------------------------------", - "# Custom keys preserved from previous configuration (not present in template)", - "# ----------------------------------------------------------------------", - ) - newLines = append(newLines, extraLines...) + resolveUserKey := func(entry templateEntry) (string, bool) { + if values, ok := userValues[entry.key]; ok && len(values) > 0 { + return entry.key, true + } + if mappedKey, ok := caseMap[entry.upper]; ok { + if values, ok := userValues[mappedKey]; ok && len(values) > 0 { + return mappedKey, true + } + } + return "", false } - // Count preserved values - preserved := 0 - for key := range processedUserKeys { - preserved += len(userValues[key]) + findPrevAnchor := func(entryIndex int) (int, bool) { + for i := entryIndex - 1; i >= 0; i-- { + if userKey, ok := resolveUserKey(templateEntries[i]); ok { + ranges := userRanges[userKey] + if len(ranges) == 0 { + continue + } + return ranges[len(ranges)-1].end + 1, true + } + } + return 0, false } - // If nothing changed (no missing keys and no extras), we can return early. - // BUT checking "nothing changed" is harder now because we might have renamed keys. - // If we renamed a key, the content CHANGED. - // So we should compare normalized content? - // Or just assume if we parsed everything and re-rendered, and it matches original string... + ops := make([]insertOp, 0, len(missingEntries)) + unanchored := make([]templateEntry, 0) + for _, entry := range missingEntries { + insertIndex := appendIndex + if prev, ok := findPrevAnchor(entry.index); ok { + insertIndex = prev + } else { + unanchored = append(unanchored, entry) + continue + } + insertIndex = normalizeInsertIndex(insertIndex) + ops = append(ops, insertOp{ + index: insertIndex, + lines: entry.lines, + order: entry.index, + }) + } - newContent := strings.Join(newLines, lineEnding) - // Preserve trailing newline if template had one. - if strings.HasSuffix(normalizedTemplate, "\n") && !strings.HasSuffix(newContent, lineEnding) { - newContent += lineEnding + if len(unanchored) > 0 { + section := []string{"# Added by upgrade"} + if appendIndex > 0 && strings.TrimSpace(originalLines[appendIndex-1]) != "" { + section = append([]string{""}, section...) + } + for _, entry := range unanchored { + section = append(section, entry.lines...) + } + ops = append(ops, insertOp{ + index: normalizeInsertIndex(appendIndex), + lines: section, + order: len(templateEntries), + }) } + sort.SliceStable(ops, func(i, j int) bool { + if ops[i].index != ops[j].index { + return ops[i].index < ops[j].index + } + return ops[i].order < ops[j].order + }) + + newLines := make([]string, 0, len(originalLines)+len(ops)) + opIdx := 0 + for i := 0; i < len(originalLines); i++ { + for opIdx < len(ops) && ops[opIdx].index == i { + newLines = append(newLines, ops[opIdx].lines...) + opIdx++ + } + newLines = append(newLines, originalLines[i]) + } + for opIdx < len(ops) { + newLines = append(newLines, ops[opIdx].lines...) + opIdx++ + } + + newContent := strings.Join(newLines, lineEnding) if newContent == string(originalContent) { result.Changed = false result.Warnings = warnings + result.ExtraKeys = extraKeys result.PreservedValues = preserved return result, "", originalContent, nil } @@ -324,12 +397,13 @@ func computeConfigUpgrade(configPath string) (*UpgradeResult, string, []byte, er return result, newContent, originalContent, nil } -func parseEnvValues(lines []string) (map[string][]envValue, []string, map[string]string, map[string]bool, []string, error) { +func parseEnvValues(lines []string) (map[string][]envValue, []string, map[string]string, map[string]bool, []string, map[string][]keyRange, error) { userValues := make(map[string][]envValue) userKeyOrder := make([]string, 0) caseMap := make(map[string]string) // UPPER -> original caseConflicts := make(map[string]bool) warnings := make([]string, 0) + userRanges := make(map[string][]keyRange) for i := 0; i < len(lines); i++ { line := lines[i] @@ -358,7 +432,7 @@ func parseEnvValues(lines []string) (map[string][]envValue, []string, map[string blockLines := make([]string, 0) blockEnd, err := findClosingQuoteLine(lines, i+1) if err != nil { - return nil, nil, nil, nil, nil, fmt.Errorf("unterminated multi-line value for %s starting at line %d", key, i+1) + return nil, nil, nil, nil, nil, nil, fmt.Errorf("unterminated multi-line value for %s starting at line %d", key, i+1) } blockLines = append(blockLines, lines[i+1:blockEnd]...) @@ -366,6 +440,7 @@ func parseEnvValues(lines []string) (map[string][]envValue, []string, map[string userKeyOrder = append(userKeyOrder, key) } userValues[key] = append(userValues[key], envValue{kind: envValueKindBlock, blockLines: blockLines}) + userRanges[key] = append(userRanges[key], keyRange{start: i, end: blockEnd}) i = blockEnd continue @@ -375,9 +450,10 @@ func parseEnvValues(lines []string) (map[string][]envValue, []string, map[string userKeyOrder = append(userKeyOrder, key) } userValues[key] = append(userValues[key], envValue{kind: envValueKindLine, rawValue: rawValue, comment: comment}) + userRanges[key] = append(userRanges[key], keyRange{start: i, end: i}) } - return userValues, userKeyOrder, caseMap, caseConflicts, warnings, nil + return userValues, userKeyOrder, caseMap, caseConflicts, warnings, userRanges, nil } func splitKeyValueRaw(line string) (string, string, string, bool) { diff --git a/internal/config/upgrade_test.go b/internal/config/upgrade_test.go index eb9bfa9..1045e66 100644 --- a/internal/config/upgrade_test.go +++ b/internal/config/upgrade_test.go @@ -79,7 +79,7 @@ func TestPlanUpgradeTracksExtraKeys(t *testing.T) { }) } -func TestUpgradeConfigCreatesBackupAndCustomSection(t *testing.T) { +func TestUpgradeConfigCreatesBackupAndPreservesExtraKeys(t *testing.T) { withTemplate(t, upgradeTemplate, func() { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "backup.env") @@ -108,9 +108,6 @@ func TestUpgradeConfigCreatesBackupAndCustomSection(t *testing.T) { t.Fatalf("failed to read upgraded config: %v", err) } content := string(updated) - if !strings.Contains(content, "Custom keys preserved") { - t.Fatalf("expected custom section header, got: %s", content) - } if !strings.Contains(content, "EXTRA_KEY=value") { t.Fatalf("expected EXTRA_KEY preserved, got: %s", content) } @@ -415,9 +412,41 @@ Custom_Backup_Paths=" } content := strings.ReplaceAll(string(data), "\r\n", "\n") - expectedBlock := "CUSTOM_BACKUP_PATHS=\"\n/etc/custom.conf\n\"\n" + expectedBlock := "Custom_Backup_Paths=\"\n/etc/custom.conf\n\"\n" if !strings.Contains(content, expectedBlock) { - t.Fatalf("upgraded config missing preserved block with fixed casing:\nGot:\n%s\nWant contains:\n%s", content, expectedBlock) + t.Fatalf("upgraded config missing preserved block with original casing:\nGot:\n%s\nWant contains:\n%s", content, expectedBlock) + } + }) +} + +func TestUpgradeConfigAddsMissingKeysUnderUpgradeSectionWhenNoAnchor(t *testing.T) { + template := "KEY1=default\nKEY2=default\n" + withTemplate(t, template, func() { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "backup.env") + legacy := "EXTRA=value\n" + if err := os.WriteFile(configPath, []byte(legacy), 0600); err != nil { + t.Fatalf("failed to write legacy config: %v", err) + } + + result, err := UpgradeConfigFile(configPath) + if err != nil { + t.Fatalf("UpgradeConfigFile returned error: %v", err) + } + if !result.Changed { + t.Fatal("expected result.Changed=true when keys are missing") + } + + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read upgraded config: %v", err) + } + content := strings.ReplaceAll(string(data), "\r\n", "\n") + if !strings.Contains(content, "EXTRA=value") { + t.Fatalf("expected EXTRA to remain, got:\n%s", content) + } + if !strings.Contains(content, "# Added by upgrade\nKEY1=default\nKEY2=default\n") { + t.Fatalf("expected missing keys under upgrade section, got:\n%s", content) } }) }