Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var templateFuncs = template.FuncMap{

var initializers []func()
var finalizers []func()
var cmdFinalizers []func(*Command, error)

const (
defaultPrefixMatching = false
Expand Down Expand Up @@ -106,6 +107,10 @@ func OnFinalize(y ...func()) {
finalizers = append(finalizers, y...)
}

func OnFinalizeCmd(y ...func(cmd *Command, err error)) {
cmdFinalizers = append(cmdFinalizers, y...)
}

// FIXME Gt is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra.

// Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans,
Expand Down
16 changes: 16 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,13 +959,16 @@ func (c *Command) execute(a []string) (err error) {
c.preRun()

defer c.postRun()
var cmdErr error
defer c.postRunCmd(&c, &cmdErr)

argWoFlags := c.Flags().Args()
if c.DisableFlagParsing {
argWoFlags = a
}

if err := c.ValidateArgs(argWoFlags); err != nil {
cmdErr = err
return err
}

Expand All @@ -984,6 +987,7 @@ func (c *Command) execute(a []string) (err error) {
for _, p := range parents {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
cmdErr = err
return err
}
if !EnableTraverseRunHooks {
Expand All @@ -998,28 +1002,33 @@ func (c *Command) execute(a []string) (err error) {
}
if c.PreRunE != nil {
if err := c.PreRunE(c, argWoFlags); err != nil {
cmdErr = err
return err
}
} else if c.PreRun != nil {
c.PreRun(c, argWoFlags)
}

if err := c.ValidateRequiredFlags(); err != nil {
cmdErr = err
return err
}
if err := c.ValidateFlagGroups(); err != nil {
cmdErr = err
return err
}

if c.RunE != nil {
if err := c.RunE(c, argWoFlags); err != nil {
cmdErr = err
return err
}
} else {
c.Run(c, argWoFlags)
}
if c.PostRunE != nil {
if err := c.PostRunE(c, argWoFlags); err != nil {
cmdErr = err
return err
}
} else if c.PostRun != nil {
Expand All @@ -1028,6 +1037,7 @@ func (c *Command) execute(a []string) (err error) {
for p := c; p != nil; p = p.Parent() {
if p.PersistentPostRunE != nil {
if err := p.PersistentPostRunE(c, argWoFlags); err != nil {
cmdErr = err
return err
}
if !EnableTraverseRunHooks {
Expand Down Expand Up @@ -1056,6 +1066,12 @@ func (c *Command) postRun() {
}
}

func (c *Command) postRunCmd(cmdRef **Command, errRef *error) {
for _, x := range cmdFinalizers {
x(*cmdRef, *errRef)
}
}

// ExecuteContext is the same as Execute(), but sets the ctx on the command.
// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs
// functions.
Expand Down
125 changes: 124 additions & 1 deletion command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -2600,9 +2601,10 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) {
checkStringContains(t, output, "unknown flag: --unknown")
}

const val = "foobar"

func TestSetContext(t *testing.T) {
type key struct{}
val := "foobar"
root := &Command{
Use: "root",
Run: func(cmd *Command, args []string) {
Expand Down Expand Up @@ -2952,3 +2954,124 @@ func TestHelpFuncExecuted(t *testing.T) {

checkStringContains(t, output, helpText)
}

func TestOnFinalizeCmdRunE(t *testing.T) {
t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} })
type key struct{}
testErr := fmt.Errorf("Test Error")
OnFinalizeCmd(func(cmd *Command, err error) {
key := cmd.Context().Value(key{})
got, ok := key.(string)
if !ok {
t.Error("key not found in context")
}
if cmd.Name() != "root" {
t.Errorf("Unexpected OnFinalizeCmd: %s", cmd.Name())
}

if !errors.Is(err, testErr) {
t.Errorf("Unexpected OnFinalizeCmd error: %v", err)
}

if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
})
rootCmd := &Command{
Use: "root",
RunE: func(cmd *Command, _ []string) error {
cmd.SetContext(context.WithValue(cmd.Context(), key{}, val))
return testErr
},
}

_, err := executeCommand(rootCmd)
if !errors.Is(err, testErr) {
t.Errorf("Unexpected error: %v", err)
}
}

func TestOnFinalizeCmdPostRunE(t *testing.T) {
t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} })
type key struct{}
testErr := fmt.Errorf("Test Error")
OnFinalizeCmd(func(cmd *Command, err error) {
key := cmd.Context().Value(key{})
got, ok := key.(string)
if !ok {
t.Error("key not found in context")
}
if cmd.Name() != "root" {
t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name())
}

if !errors.Is(err, testErr) {
t.Errorf("Unexpected OnFinalizeCmd error: %v", err)
}

if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
})

rootCmd := &Command{
Use: "root",
Run: func(cmd *Command, args []string) {
cmd.SetContext(context.WithValue(cmd.Context(), key{}, val))
},
PostRunE: func(_ *Command, _ []string) error {
return testErr
},
}

_, err := executeCommand(rootCmd)
if !errors.Is(err, testErr) {
t.Errorf("Unexpected error: %v", err)
}
}

func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) {
t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} })
type key struct{}
testErr := fmt.Errorf("Test Error")
OnFinalizeCmd(func(cmd *Command, err error) {
c := cmd.Context()
key := c.Value(key{})
got, ok := key.(string)
if !ok {
t.Error("key not found in context")
}
if cmd.Name() != "child" {
t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name())
}

if !errors.Is(err, testErr) {
t.Errorf("Unexpected OnFinalizeCmd error: %v", err)
}

if got != val {
t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got)
}
})
rootCmd := &Command{
Use: "root",
Run: func(_ *Command, args []string) {},
PersistentPostRunE: func(_ *Command, _ []string) error {
return testErr
},
}

childCmd := &Command{
Use: "child",
Run: func(cmd *Command, args []string) {
cmd.SetContext(context.WithValue(cmd.Context(), key{}, val))
},
}

rootCmd.AddCommand(childCmd)

_, err := executeCommand(rootCmd, childCmd.Use)
if !errors.Is(err, testErr) {
t.Errorf("Unexpected error: %v", err)
}
}