diff --git a/cobra.go b/cobra.go index d9cd2414e..4e229bbb5 100644 --- a/cobra.go +++ b/cobra.go @@ -41,6 +41,7 @@ var templateFuncs = template.FuncMap{ var initializers []func() var finalizers []func() +var cmdFinalizers []func(*Command, error) const ( defaultPrefixMatching = false @@ -106,6 +107,10 @@ func OnFinalize(y ...func()) { finalizers = append(finalizers, y...) } +func OnFinalizeCmd(y ...func(cmd *Command, err error)) { + cmdFinalizers = append(cmdFinalizers, y...) +} + // FIXME Gt is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra. // Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans, diff --git a/command.go b/command.go index 78088db69..fb4ba2e3e 100644 --- a/command.go +++ b/command.go @@ -959,6 +959,8 @@ func (c *Command) execute(a []string) (err error) { c.preRun() defer c.postRun() + var cmdErr error + defer c.postRunCmd(&c, &cmdErr) argWoFlags := c.Flags().Args() if c.DisableFlagParsing { @@ -966,6 +968,7 @@ func (c *Command) execute(a []string) (err error) { } if err := c.ValidateArgs(argWoFlags); err != nil { + cmdErr = err return err } @@ -984,6 +987,7 @@ func (c *Command) execute(a []string) (err error) { for _, p := range parents { if p.PersistentPreRunE != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil { + cmdErr = err return err } if !EnableTraverseRunHooks { @@ -998,6 +1002,7 @@ func (c *Command) execute(a []string) (err error) { } if c.PreRunE != nil { if err := c.PreRunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else if c.PreRun != nil { @@ -1005,14 +1010,17 @@ func (c *Command) execute(a []string) (err error) { } if err := c.ValidateRequiredFlags(); err != nil { + cmdErr = err return err } if err := c.ValidateFlagGroups(); err != nil { + cmdErr = err return err } if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else { @@ -1020,6 +1028,7 @@ func (c *Command) execute(a []string) (err error) { } if c.PostRunE != nil { if err := c.PostRunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else if c.PostRun != nil { @@ -1028,6 +1037,7 @@ func (c *Command) execute(a []string) (err error) { for p := c; p != nil; p = p.Parent() { if p.PersistentPostRunE != nil { if err := p.PersistentPostRunE(c, argWoFlags); err != nil { + cmdErr = err return err } if !EnableTraverseRunHooks { @@ -1056,6 +1066,12 @@ func (c *Command) postRun() { } } +func (c *Command) postRunCmd(cmdRef **Command, errRef *error) { + for _, x := range cmdFinalizers { + x(*cmdRef, *errRef) + } +} + // ExecuteContext is the same as Execute(), but sets the ctx on the command. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs // functions. diff --git a/command_test.go b/command_test.go index a86e57f0a..806bf8a1b 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io" "os" @@ -2600,9 +2601,10 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { checkStringContains(t, output, "unknown flag: --unknown") } +const val = "foobar" + func TestSetContext(t *testing.T) { type key struct{} - val := "foobar" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { @@ -2952,3 +2954,124 @@ func TestHelpFuncExecuted(t *testing.T) { checkStringContains(t, output, helpText) } + +func TestOnFinalizeCmdRunE(t *testing.T) { + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) + type key struct{} + testErr := fmt.Errorf("Test Error") + OnFinalizeCmd(func(cmd *Command, err error) { + key := cmd.Context().Value(key{}) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if cmd.Name() != "root" { + t.Errorf("Unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }) + rootCmd := &Command{ + Use: "root", + RunE: func(cmd *Command, _ []string) error { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) + return testErr + }, + } + + _, err := executeCommand(rootCmd) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestOnFinalizeCmdPostRunE(t *testing.T) { + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) + type key struct{} + testErr := fmt.Errorf("Test Error") + OnFinalizeCmd(func(cmd *Command, err error) { + key := cmd.Context().Value(key{}) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if cmd.Name() != "root" { + t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }) + + rootCmd := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) + }, + PostRunE: func(_ *Command, _ []string) error { + return testErr + }, + } + + _, err := executeCommand(rootCmd) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) + type key struct{} + testErr := fmt.Errorf("Test Error") + OnFinalizeCmd(func(cmd *Command, err error) { + c := cmd.Context() + key := c.Value(key{}) + got, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + if cmd.Name() != "child" { + t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } + }) + rootCmd := &Command{ + Use: "root", + Run: func(_ *Command, args []string) {}, + PersistentPostRunE: func(_ *Command, _ []string) error { + return testErr + }, + } + + childCmd := &Command{ + Use: "child", + Run: func(cmd *Command, args []string) { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) + }, + } + + rootCmd.AddCommand(childCmd) + + _, err := executeCommand(rootCmd, childCmd.Use) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } +}