From 60eb554f2752637cded098bb4ccb05a8bc089787 Mon Sep 17 00:00:00 2001 From: Kevin Gentile Date: Tue, 23 Sep 2025 16:49:50 -0400 Subject: [PATCH 1/4] add finalizers with command and error ref --- cobra.go | 5 +++ command.go | 16 +++++++ command_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+) diff --git a/cobra.go b/cobra.go index d9cd2414e..4e229bbb5 100644 --- a/cobra.go +++ b/cobra.go @@ -41,6 +41,7 @@ var templateFuncs = template.FuncMap{ var initializers []func() var finalizers []func() +var cmdFinalizers []func(*Command, error) const ( defaultPrefixMatching = false @@ -106,6 +107,10 @@ func OnFinalize(y ...func()) { finalizers = append(finalizers, y...) } +func OnFinalizeCmd(y ...func(cmd *Command, err error)) { + cmdFinalizers = append(cmdFinalizers, y...) +} + // FIXME Gt is unused by cobra and should be removed in a version 2. It exists only for compatibility with users of cobra. // Gt takes two types and checks whether the first type is greater than the second. In case of types Arrays, Chans, diff --git a/command.go b/command.go index 78088db69..fb4ba2e3e 100644 --- a/command.go +++ b/command.go @@ -959,6 +959,8 @@ func (c *Command) execute(a []string) (err error) { c.preRun() defer c.postRun() + var cmdErr error + defer c.postRunCmd(&c, &cmdErr) argWoFlags := c.Flags().Args() if c.DisableFlagParsing { @@ -966,6 +968,7 @@ func (c *Command) execute(a []string) (err error) { } if err := c.ValidateArgs(argWoFlags); err != nil { + cmdErr = err return err } @@ -984,6 +987,7 @@ func (c *Command) execute(a []string) (err error) { for _, p := range parents { if p.PersistentPreRunE != nil { if err := p.PersistentPreRunE(c, argWoFlags); err != nil { + cmdErr = err return err } if !EnableTraverseRunHooks { @@ -998,6 +1002,7 @@ func (c *Command) execute(a []string) (err error) { } if c.PreRunE != nil { if err := c.PreRunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else if c.PreRun != nil { @@ -1005,14 +1010,17 @@ func (c *Command) execute(a []string) (err error) { } if err := c.ValidateRequiredFlags(); err != nil { + cmdErr = err return err } if err := c.ValidateFlagGroups(); err != nil { + cmdErr = err return err } if c.RunE != nil { if err := c.RunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else { @@ -1020,6 +1028,7 @@ func (c *Command) execute(a []string) (err error) { } if c.PostRunE != nil { if err := c.PostRunE(c, argWoFlags); err != nil { + cmdErr = err return err } } else if c.PostRun != nil { @@ -1028,6 +1037,7 @@ func (c *Command) execute(a []string) (err error) { for p := c; p != nil; p = p.Parent() { if p.PersistentPostRunE != nil { if err := p.PersistentPostRunE(c, argWoFlags); err != nil { + cmdErr = err return err } if !EnableTraverseRunHooks { @@ -1056,6 +1066,12 @@ func (c *Command) postRun() { } } +func (c *Command) postRunCmd(cmdRef **Command, errRef *error) { + for _, x := range cmdFinalizers { + x(*cmdRef, *errRef) + } +} + // ExecuteContext is the same as Execute(), but sets the ctx on the command. // Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs // functions. diff --git a/command_test.go b/command_test.go index a86e57f0a..0ccd0a6e9 100644 --- a/command_test.go +++ b/command_test.go @@ -17,6 +17,7 @@ package cobra import ( "bytes" "context" + "errors" "fmt" "io" "os" @@ -2952,3 +2953,111 @@ func TestHelpFuncExecuted(t *testing.T) { checkStringContains(t, output, helpText) } + +func TestOnFinalizeCmdRunE(t *testing.T) { + testErr := fmt.Errorf("Test Error") + finalized := false + var errOut error + var cmdOut *Command + OnFinalizeCmd(func(cmd *Command, err error) { + errOut = err + cmdOut = cmd + finalized = true + }) + rootCmd := &Command{ + Use: "root", + RunE: func(_ *Command, _ []string) error { + return testErr + }, + } + + _, err := executeCommand(rootCmd) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } + if !finalized { + t.Error("OnFinalizedCmd not called") + } + if cmdOut.Name() != rootCmd.Use { + t.Errorf("Unexpected OnFinalizeCmd: %s", cmdOut.Name()) + } + + if !errors.Is(errOut, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } +} + +func TestOnFinalizeCmdPostRunE(t *testing.T) { + testErr := fmt.Errorf("Test Error") + finalized := false + var errOut error + var cmdOut *Command + OnFinalizeCmd(func(cmd *Command, err error) { + errOut = err + cmdOut = cmd + finalized = true + }) + rootCmd := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) {}, + PostRunE: func(_ *Command, _ []string) error { + return testErr + }, + } + + _, err := executeCommand(rootCmd) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } + if !finalized { + t.Error("OnFinalizedCmd not called") + } + if cmdOut.Name() != rootCmd.Use { + t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) + } + + if !errors.Is(errOut, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } +} + +func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { + testErr := fmt.Errorf("Test Error") + finalized := false + var errOut error + var cmdOut *Command + OnFinalizeCmd(func(cmd *Command, err error) { + errOut = err + cmdOut = cmd + finalized = true + }) + rootCmd := &Command{ + Use: "root", + Run: func(cmd *Command, args []string) {}, + PersistentPostRunE: func(_ *Command, _ []string) error { + return testErr + }, + } + + childCmd := &Command{ + Use: "child", + Run: func(_ *Command, args []string) {}, + } + + rootCmd.AddCommand(childCmd) + + _, err := executeCommand(rootCmd, childCmd.Use) + if !errors.Is(err, testErr) { + t.Errorf("Unexpected error: %v", err) + } + if !finalized { + t.Error("OnFinalizedCmd not called") + } + if cmdOut.Name() != childCmd.Use { + t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) + } + + if !errors.Is(errOut, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } +} From e6bb1d2cbd3c892b4a137522e80250018ec04fac Mon Sep 17 00:00:00 2001 From: Kevin Gentile Date: Tue, 23 Sep 2025 17:48:04 -0400 Subject: [PATCH 2/4] test context values --- command_test.go | 75 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/command_test.go b/command_test.go index 0ccd0a6e9..2915f7606 100644 --- a/command_test.go +++ b/command_test.go @@ -2955,18 +2955,27 @@ func TestHelpFuncExecuted(t *testing.T) { } func TestOnFinalizeCmdRunE(t *testing.T) { + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) + type key struct{} + val := "foobar" testErr := fmt.Errorf("Test Error") - finalized := false var errOut error var cmdOut *Command + var got string OnFinalizeCmd(func(cmd *Command, err error) { errOut = err cmdOut = cmd - finalized = true + key := cmd.Context().Value(key{}) + keyOut, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + got = keyOut }) rootCmd := &Command{ Use: "root", - RunE: func(_ *Command, _ []string) error { + RunE: func(cmd *Command, _ []string) error { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) return testErr }, } @@ -2975,9 +2984,7 @@ func TestOnFinalizeCmdRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - if !finalized { - t.Error("OnFinalizedCmd not called") - } + if cmdOut.Name() != rootCmd.Use { t.Errorf("Unexpected OnFinalizeCmd: %s", cmdOut.Name()) } @@ -2985,21 +2992,36 @@ func TestOnFinalizeCmdRunE(t *testing.T) { if !errors.Is(errOut, testErr) { t.Errorf("Unexpected OnFinalizeCmd error: %v", err) } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } } func TestOnFinalizeCmdPostRunE(t *testing.T) { + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) + type key struct{} + val := "foobar" testErr := fmt.Errorf("Test Error") - finalized := false var errOut error var cmdOut *Command + var got string OnFinalizeCmd(func(cmd *Command, err error) { errOut = err cmdOut = cmd - finalized = true + key := cmd.Context().Value(key{}) + keyOut, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + got = keyOut }) + rootCmd := &Command{ Use: "root", - Run: func(cmd *Command, args []string) {}, + Run: func(cmd *Command, args []string) { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) + }, PostRunE: func(_ *Command, _ []string) error { return testErr }, @@ -3009,9 +3031,7 @@ func TestOnFinalizeCmdPostRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - if !finalized { - t.Error("OnFinalizedCmd not called") - } + if cmdOut.Name() != rootCmd.Use { t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) } @@ -3019,21 +3039,34 @@ func TestOnFinalizeCmdPostRunE(t *testing.T) { if !errors.Is(errOut, testErr) { t.Errorf("Unexpected OnFinalizeCmd error: %v", err) } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } } func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { + defer func() { cmdFinalizers = []func(*Command, error){} }() + type key struct{} + val := "foobar" testErr := fmt.Errorf("Test Error") - finalized := false var errOut error var cmdOut *Command + var got string OnFinalizeCmd(func(cmd *Command, err error) { errOut = err cmdOut = cmd - finalized = true + c := cmd.Context() + key := c.Value(key{}) + keyOut, ok := key.(string) + if !ok { + t.Error("key not found in context") + } + got = keyOut }) rootCmd := &Command{ Use: "root", - Run: func(cmd *Command, args []string) {}, + Run: func(_ *Command, args []string) {}, PersistentPostRunE: func(_ *Command, _ []string) error { return testErr }, @@ -3041,7 +3074,9 @@ func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { childCmd := &Command{ Use: "child", - Run: func(_ *Command, args []string) {}, + Run: func(cmd *Command, args []string) { + cmd.SetContext(context.WithValue(cmd.Context(), key{}, val)) + }, } rootCmd.AddCommand(childCmd) @@ -3050,9 +3085,7 @@ func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - if !finalized { - t.Error("OnFinalizedCmd not called") - } + if cmdOut.Name() != childCmd.Use { t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) } @@ -3060,4 +3093,8 @@ func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { if !errors.Is(errOut, testErr) { t.Errorf("Unexpected OnFinalizeCmd error: %v", err) } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } } From e9467ded732ec7be9b21e9e6d31f0f152d647dbf Mon Sep 17 00:00:00 2001 From: Kevin Gentile Date: Tue, 23 Sep 2025 17:55:26 -0400 Subject: [PATCH 3/4] reorganize tests --- command_test.go | 95 +++++++++++++++++++------------------------------ 1 file changed, 37 insertions(+), 58 deletions(-) diff --git a/command_test.go b/command_test.go index 2915f7606..9a7ae02f8 100644 --- a/command_test.go +++ b/command_test.go @@ -2959,18 +2959,23 @@ func TestOnFinalizeCmdRunE(t *testing.T) { type key struct{} val := "foobar" testErr := fmt.Errorf("Test Error") - var errOut error - var cmdOut *Command - var got string OnFinalizeCmd(func(cmd *Command, err error) { - errOut = err - cmdOut = cmd key := cmd.Context().Value(key{}) - keyOut, ok := key.(string) + got, ok := key.(string) if !ok { t.Error("key not found in context") } - got = keyOut + if cmd.Name() != "root" { + t.Errorf("Unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } }) rootCmd := &Command{ Use: "root", @@ -2984,18 +2989,6 @@ func TestOnFinalizeCmdRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - - if cmdOut.Name() != rootCmd.Use { - t.Errorf("Unexpected OnFinalizeCmd: %s", cmdOut.Name()) - } - - if !errors.Is(errOut, testErr) { - t.Errorf("Unexpected OnFinalizeCmd error: %v", err) - } - - if got != val { - t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) - } } func TestOnFinalizeCmdPostRunE(t *testing.T) { @@ -3003,18 +2996,23 @@ func TestOnFinalizeCmdPostRunE(t *testing.T) { type key struct{} val := "foobar" testErr := fmt.Errorf("Test Error") - var errOut error - var cmdOut *Command - var got string OnFinalizeCmd(func(cmd *Command, err error) { - errOut = err - cmdOut = cmd key := cmd.Context().Value(key{}) - keyOut, ok := key.(string) + got, ok := key.(string) if !ok { t.Error("key not found in context") } - got = keyOut + if cmd.Name() != "root" { + t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } }) rootCmd := &Command{ @@ -3031,38 +3029,31 @@ func TestOnFinalizeCmdPostRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - - if cmdOut.Name() != rootCmd.Use { - t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) - } - - if !errors.Is(errOut, testErr) { - t.Errorf("Unexpected OnFinalizeCmd error: %v", err) - } - - if got != val { - t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) - } } func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { - defer func() { cmdFinalizers = []func(*Command, error){} }() + t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) type key struct{} val := "foobar" testErr := fmt.Errorf("Test Error") - var errOut error - var cmdOut *Command - var got string OnFinalizeCmd(func(cmd *Command, err error) { - errOut = err - cmdOut = cmd c := cmd.Context() key := c.Value(key{}) - keyOut, ok := key.(string) + got, ok := key.(string) if !ok { t.Error("key not found in context") } - got = keyOut + if cmd.Name() != "child" { + t.Errorf("unexpected OnFinalizeCmd: %s", cmd.Name()) + } + + if !errors.Is(err, testErr) { + t.Errorf("Unexpected OnFinalizeCmd error: %v", err) + } + + if got != val { + t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) + } }) rootCmd := &Command{ Use: "root", @@ -3085,16 +3076,4 @@ func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { if !errors.Is(err, testErr) { t.Errorf("Unexpected error: %v", err) } - - if cmdOut.Name() != childCmd.Use { - t.Errorf("unexpected OnFinalizeCmd: %s", cmdOut.Name()) - } - - if !errors.Is(errOut, testErr) { - t.Errorf("Unexpected OnFinalizeCmd error: %v", err) - } - - if got != val { - t.Errorf("Expected value: \n %v\nGot:\n %v\n", val, got) - } } From 514a948ee37b98e019d03b5066c82c462e9f1f4c Mon Sep 17 00:00:00 2001 From: Kevin Gentile Date: Tue, 23 Sep 2025 17:58:05 -0400 Subject: [PATCH 4/4] fix linting error --- command_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/command_test.go b/command_test.go index 9a7ae02f8..806bf8a1b 100644 --- a/command_test.go +++ b/command_test.go @@ -2601,9 +2601,10 @@ func TestFParseErrWhitelistSiblingCommand(t *testing.T) { checkStringContains(t, output, "unknown flag: --unknown") } +const val = "foobar" + func TestSetContext(t *testing.T) { type key struct{} - val := "foobar" root := &Command{ Use: "root", Run: func(cmd *Command, args []string) { @@ -2957,7 +2958,6 @@ func TestHelpFuncExecuted(t *testing.T) { func TestOnFinalizeCmdRunE(t *testing.T) { t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) type key struct{} - val := "foobar" testErr := fmt.Errorf("Test Error") OnFinalizeCmd(func(cmd *Command, err error) { key := cmd.Context().Value(key{}) @@ -2994,7 +2994,6 @@ func TestOnFinalizeCmdRunE(t *testing.T) { func TestOnFinalizeCmdPostRunE(t *testing.T) { t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) type key struct{} - val := "foobar" testErr := fmt.Errorf("Test Error") OnFinalizeCmd(func(cmd *Command, err error) { key := cmd.Context().Value(key{}) @@ -3034,7 +3033,6 @@ func TestOnFinalizeCmdPostRunE(t *testing.T) { func TestOnFinalizeCmdPersistentPostRunE(t *testing.T) { t.Cleanup(func() { cmdFinalizers = []func(*Command, error){} }) type key struct{} - val := "foobar" testErr := fmt.Errorf("Test Error") OnFinalizeCmd(func(cmd *Command, err error) { c := cmd.Context()