Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ on:
types: [opened, synchronize, reopened]

jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: stable
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v7

build:
name: Build and test
runs-on: ubuntu-latest
Expand Down
11 changes: 11 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
version: "2"

linters:
default: standard
exclusions:
paths:
- examples

formatters:
enable:
- goimports
17 changes: 17 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.PHONY: build test lint format ci-test

build:
go build -v .

test:
go test $$(go list ./... | grep -v 'examples') -count=1 -v

lint:
golangci-lint run ./...

format:
goimports -w $$(find . -name '*.go' -not -path './examples/*')

ci-test:
go test $$(go list ./... | grep -v 'examples') -count=1 -v -json -cover \
| tparse -all -follow -sort=elapsed -trimpath=auto
12 changes: 6 additions & 6 deletions graceful/graceful.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Error("function error", slog.Any("error", err))
} else {
fmt.Fprintln(cfg.stderr, err)
_, _ = fmt.Fprintln(cfg.stderr, err)
}
exit(1)
}
Expand All @@ -113,7 +113,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Warn(msg)
} else {
fmt.Fprintln(cfg.stderr, msg)
_, _ = fmt.Fprintln(cfg.stderr, msg)
}
exit(130)
}
Expand All @@ -127,7 +127,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Info(msg)
} else {
fmt.Fprintln(cfg.stderr, msg)
_, _ = fmt.Fprintln(cfg.stderr, msg)
}

// Set up shutdown timeout if configured
Expand All @@ -145,7 +145,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Error("function error", "error", err)
} else {
fmt.Fprintln(cfg.stderr, err)
_, _ = fmt.Fprintln(cfg.stderr, err)
}
exit(1)
}
Expand All @@ -157,7 +157,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Warn(msg)
} else {
fmt.Fprintln(cfg.stderr, msg)
_, _ = fmt.Fprintln(cfg.stderr, msg)
}
exit(130)

Expand All @@ -167,7 +167,7 @@ func Run(run func(context.Context) error, opts ...Option) {
if cfg.logger != nil {
cfg.logger.Error(msg)
} else {
fmt.Fprintln(cfg.stderr, msg)
_, _ = fmt.Fprintln(cfg.stderr, msg)
}
exit(124)
}
Expand Down
2 changes: 1 addition & 1 deletion graceful/graceful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func sendSignal(trigger <-chan struct{}, delay time.Duration) {
if delay > 0 {
time.Sleep(delay)
}
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
}

func TestRun_Success(t *testing.T) {
Expand Down
207 changes: 112 additions & 95 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,62 @@ func Parse(root *Command, args []string) error {
// Reset command path but preserve other state
root.state.path = []*Command{root}
}
// First split args at the -- delimiter if present
var argsToParse []string
var remainingArgs []string

argsToParse, remainingArgs := splitAtDelimiter(args)

current, err := resolveCommandPath(root, argsToParse)
if err != nil {
return err
}
current.Flags.Usage = func() { /* suppress default usage */ }

// Check for help flags after resolving the correct command
for _, arg := range argsToParse {
if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" {
// Combine flags first so the help message includes all inherited flags
combineFlags(root.state.path)
return flag.ErrHelp
}
}

combinedFlags := combineFlags(root.state.path)

// Let ParseToEnd handle the flag parsing
if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil {
return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err)
}

if err := checkRequiredFlags(root.state.path, combinedFlags); err != nil {
return err
}

root.state.Args = collectArgs(root.state.path, combinedFlags.Args(), remainingArgs)

if current.Exec == nil {
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path))
}
return nil
}

// splitAtDelimiter splits args at the first "--" delimiter. Returns the args before the delimiter
// and any args after it.
func splitAtDelimiter(args []string) (argsToParse, remaining []string) {
for i, arg := range args {
if arg == "--" {
argsToParse = args[:i]
remainingArgs = args[i+1:]
break
return args[:i], args[i+1:]
}
}
if argsToParse == nil {
argsToParse = args
}
return args, nil
}

// resolveCommandPath walks argsToParse to resolve the subcommand chain, building root.state.path
// and initializing flag sets along the way. Returns the terminal (deepest) command.
func resolveCommandPath(root *Command, argsToParse []string) (*Command, error) {
current := root
if current.Flags == nil {
current.Flags = flag.NewFlagSet(root.Name, flag.ContinueOnError)
}
var commandChain []*Command
commandChain = append(commandChain, root)

// Create combined flags with all parent flags
combinedFlags := flag.NewFlagSet(root.Name, flag.ContinueOnError)
combinedFlags.SetOutput(io.Discard)

// First pass: process commands and build the flag set
i := 0
for i < len(argsToParse) {
arg := argsToParse[i]
Expand All @@ -74,15 +104,24 @@ func Parse(root *Command, args []string) error {
continue
}

// Check if this flag expects a value
// Check if this flag expects a value across all commands in the chain (not just the
// current command), since flags from ancestor commands are inherited and can appear
// anywhere.
name := strings.TrimLeft(arg, "-")
if f := current.Flags.Lookup(name); f != nil {
if _, isBool := f.Value.(interface{ IsBoolFlag() bool }); !isBool {
// Skip both flag and its value
i += 2
continue
skipValue := false
for _, cmd := range root.state.path {
if f := cmd.Flags.Lookup(name); f != nil {
if _, isBool := f.Value.(interface{ IsBoolFlag() bool }); !isBool {
skipValue = true
}
break
}
}
if skipValue {
// Skip both flag and its value
i += 2
continue
}
i++
continue
}
Expand All @@ -95,73 +134,55 @@ func Parse(root *Command, args []string) error {
sub.Flags = flag.NewFlagSet(sub.Name, flag.ContinueOnError)
}
current = sub
commandChain = append(commandChain, sub)
i++
continue
}
return current.formatUnknownCommandError(arg)
return nil, current.formatUnknownCommandError(arg)
}
break
}
current.Flags.Usage = func() { /* suppress default usage */ }

// Add the help check here, after we've found the correct command
hasHelp := false
for _, arg := range argsToParse {
if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" {
hasHelp = true
break
}
}
return current, nil
}

// Add flags in reverse order for proper precedence
for i := len(commandChain) - 1; i >= 0; i-- {
cmd := commandChain[i]
// combineFlags merges flags from the command path into a single FlagSet. Flags are added in reverse
// order (deepest command first) so that child flags take precedence over parent flags.
func combineFlags(path []*Command) *flag.FlagSet {
combined := flag.NewFlagSet(path[0].Name, flag.ContinueOnError)
combined.SetOutput(io.Discard)
for i := len(path) - 1; i >= 0; i-- {
cmd := path[i]
if cmd.Flags != nil {
cmd.Flags.VisitAll(func(f *flag.Flag) {
if combinedFlags.Lookup(f.Name) == nil {
combinedFlags.Var(f.Value, f.Name, f.Usage)
if combined.Lookup(f.Name) == nil {
combined.Var(f.Value, f.Name, f.Usage)
}
})
}
}
// Make sure to return help only after combining all flags, this way we get the full list of
// flags in the help message!
if hasHelp {
return flag.ErrHelp
}
return combined
}

// Let ParseToEnd handle the flag parsing
if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil {
return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err)
}
// checkRequiredFlags verifies that all flags marked as required in FlagsMetadata were explicitly
// set during parsing.
func checkRequiredFlags(path []*Command, combined *flag.FlagSet) error {
// Build a set of flags that were explicitly set during parsing. Visit (unlike VisitAll) only
// iterates over flags that were actually provided by the user, regardless of their value.
setFlags := make(map[string]struct{})
combined.Visit(func(f *flag.Flag) {
setFlags[f.Name] = struct{}{}
})

// Check required flags
var missingFlags []string
for _, cmd := range commandChain {
if len(cmd.FlagsMetadata) > 0 {
for _, flagMetadata := range cmd.FlagsMetadata {
if !flagMetadata.Required {
continue
}
flag := combinedFlags.Lookup(flagMetadata.Name)
if flag == nil {
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(root.state.path), formatFlagName(flagMetadata.Name))
}
if _, isBool := flag.Value.(interface{ IsBoolFlag() bool }); isBool {
isSet := false
for _, arg := range argsToParse {
if strings.HasPrefix(arg, "-"+flagMetadata.Name) || strings.HasPrefix(arg, "--"+flagMetadata.Name) {
isSet = true
break
}
}
if !isSet {
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
}
} else if flag.Value.String() == flag.DefValue {
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
}
for _, cmd := range path {
for _, flagMetadata := range cmd.FlagsMetadata {
if !flagMetadata.Required {
continue
}
if combined.Lookup(flagMetadata.Name) == nil {
return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(path), formatFlagName(flagMetadata.Name))
}
if _, ok := setFlags[flagMetadata.Name]; !ok {
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
}
}
}
Expand All @@ -170,40 +191,36 @@ func Parse(root *Command, args []string) error {
if len(missingFlags) > 1 {
msg += "s"
}
return fmt.Errorf("command %q: %s %q not set", getCommandPath(root.state.path), msg, strings.Join(missingFlags, ", "))
return fmt.Errorf("command %q: %s %q not set", getCommandPath(path), msg, strings.Join(missingFlags, ", "))
}
return nil
}

// Skip past command names in remaining args
parsed := combinedFlags.Args()
// collectArgs strips resolved command names from the parsed positional args and appends any args
// that appeared after the "--" delimiter.
func collectArgs(path []*Command, parsed, remaining []string) []string {
// Skip past command names in remaining args. Only strip the exact command names that were
// resolved during traversal (path[1:], since root never appears in user args), in order and
// only once each.
startIdx := 0
for _, arg := range parsed {
isCommand := false
for _, cmd := range commandChain {
if arg == cmd.Name {
startIdx++
isCommand = true
break
}
}
if !isCommand {
chainIdx := 1 // Skip root
for startIdx < len(parsed) && chainIdx < len(path) {
if strings.EqualFold(parsed[startIdx], path[chainIdx].Name) {
startIdx++
chainIdx++
} else {
break
}
}

// Combine remaining parsed args and everything after delimiter
var finalArgs []string
if startIdx < len(parsed) {
finalArgs = append(finalArgs, parsed[startIdx:]...)
}
if len(remainingArgs) > 0 {
finalArgs = append(finalArgs, remainingArgs...)
if len(remaining) > 0 {
finalArgs = append(finalArgs, remaining...)
}
root.state.Args = finalArgs

if current.Exec == nil {
return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path))
}
return nil
return finalArgs
}

var validNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`)
Expand Down
Loading