diff --git a/reexec/example_multicall_test.go b/reexec/example_multicall_test.go new file mode 100644 index 0000000..a049391 --- /dev/null +++ b/reexec/example_multicall_test.go @@ -0,0 +1,59 @@ +package reexec_test + +import ( + "fmt" + "os" + + "github.com/moby/sys/reexec" +) + +func init() { + reexec.Register("example-foo", func() { + fmt.Println("Hello from entrypoint example-foo") + }) + reexec.Register("example-bar", func() { + fmt.Println("Hello from entrypoint example-bar") + }) +} + +// Example_multicall demonstrates a BusyBox-style multi-call binary. +// +// In a real multi-call binary: +// +// go build -o example . +// ln -s example example-foo +// ln -s example example-bar +// +// ./example-foo # runs entrypoint "example-foo" +// ./example-bar # runs entrypoint "example-bar" +// +// At process startup, main would call [reexec.Init] and return if it +// matches an entrypoint. This example first shows that call, then emulates +// different invocation names by modifying os.Args[0]. +func Example_multicall() { + // What main would normally do: + if reexec.Init() { + // Matched a reexec entrypoint; stop normal main execution. + return + } + reset := os.Args[0] + + // Emulate running as "example-foo". + os.Args[0] = "example-foo" + _ = reexec.Init() + + // Emulate running as "example-bar". + os.Args[0] = "example-bar" + _ = reexec.Init() + + // Emulate running under the default binary name (no match). + os.Args[0] = reset + if !reexec.Init() { + fmt.Println("Hello main") + } + + // Output: + // Hello from entrypoint example-foo + // Hello from entrypoint example-bar + // Hello main +} diff --git a/reexec/example_programmatic_test.go b/reexec/example_programmatic_test.go new file mode 100644 index 0000000..54bf3c4 --- /dev/null +++ b/reexec/example_programmatic_test.go @@ -0,0 +1,41 @@ +package reexec_test + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/moby/sys/reexec" +) + +func init() { + reexec.Register("example-child", func() { + fmt.Println("Hello from example-child entrypoint") + }) +} + +// Example_programmatic demonstrates using reexec to programmatically +// re-execute the current binary. +func Example_programmatic() { + if reexec.Init() { + // Matched a reexec entrypoint; stop normal main execution. + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := reexec.CommandContext(ctx, "example-child") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + fmt.Println("reexec error:", err) + return + } + + fmt.Println("Back in parent process") + // Output: + // Hello from example-child entrypoint + // Back in parent process +} diff --git a/reexec/internal/reexecoverride/reexecoverride.go b/reexec/internal/reexecoverride/reexecoverride.go new file mode 100644 index 0000000..bd37a2b --- /dev/null +++ b/reexec/internal/reexecoverride/reexecoverride.go @@ -0,0 +1,45 @@ +// Package reexecoverride provides test utilities for overriding argv0 as +// observed by reexec.Self within the current process. + +package reexecoverride + +import "sync/atomic" + +// argv0Override holds an optional override for os.Args[0] used by reexec.Self. +var argv0Override atomic.Pointer[string] + +// Argv0 returns the overridden argv0 if set. +func Argv0() (string, bool) { + p := argv0Override.Load() + if p == nil { + return "", false + } + return *p, true +} + +// TestingTB is the minimal subset of [testing.TB] used by this package. +type TestingTB interface { + Helper() + Cleanup(func()) +} + +// OverrideArgv0 overrides the argv0 value observed by reexec.Self for the +// lifetime of the calling test and restores it via [testing.TB.Cleanup]. +// +// The override is process-global. Tests using OverrideArgv0 must not run in +// parallel with other tests that call reexec.Self. OverrideArgv0 panics if an +// override is already active. +func OverrideArgv0(t TestingTB, argv0 string) { + t.Helper() + + s := argv0 + if !argv0Override.CompareAndSwap(nil, &s) { + panic("testing: test using reexecoverride.OverrideArgv0 cannot use t.Parallel") + } + + t.Cleanup(func() { + if !argv0Override.CompareAndSwap(&s, nil) { + panic("testing: cleanup for reexecoverride.OverrideArgv0 detected parallel use of reexec.Self") + } + }) +} diff --git a/reexec/reexec.go b/reexec/reexec.go index b97a0aa..4b2e093 100644 --- a/reexec/reexec.go +++ b/reexec/reexec.go @@ -13,6 +13,8 @@ import ( "os/exec" "path/filepath" "runtime" + + "github.com/moby/sys/reexec/internal/reexecoverride" ) var registeredInitializers = make(map[string]func()) @@ -78,14 +80,18 @@ func CommandContext(ctx context.Context, args ...string) *exec.Cmd { // "my-binary" at "/usr/bin/" (or "my-binary.exe" at "C:\" on Windows), // then it returns "/usr/bin/my-binary" and "C:\my-binary.exe" respectively. func Self() string { + if argv0, ok := reexecoverride.Argv0(); ok { + return naiveSelf(argv0) + } if runtime.GOOS == "linux" { return "/proc/self/exe" } - return naiveSelf() + return naiveSelf(os.Args[0]) } -func naiveSelf() string { - name := os.Args[0] +// naiveSelf is a separate function to allow testing in isolation on Linux. +func naiveSelf(argv0 string) string { + name := argv0 if filepath.Base(name) == name { if lp, err := exec.LookPath(name); err == nil { return lp diff --git a/reexec/reexec_test.go b/reexec/reexec_test.go index db3067f..36c9ee4 100644 --- a/reexec/reexec_test.go +++ b/reexec/reexec_test.go @@ -1,16 +1,19 @@ -package reexec +package reexec_test import ( "context" "errors" "fmt" "os" - "os/exec" "path/filepath" "reflect" + "runtime" "strings" "testing" "time" + + "github.com/moby/sys/reexec" + "github.com/moby/sys/reexec/internal/reexecoverride" ) const ( @@ -20,10 +23,10 @@ const ( ) func init() { - Register(testReExec, func() { + reexec.Register(testReExec, func() { panic("Return Error") }) - Register(testReExec2, func() { + reexec.Register(testReExec2, func() { var args string if len(os.Args) > 1 { args = fmt.Sprintf("(args: %#v)", os.Args[1:]) @@ -31,12 +34,15 @@ func init() { fmt.Println("Hello", testReExec2, args) os.Exit(0) }) - Register(testReExec3, func() { + reexec.Register(testReExec3, func() { fmt.Println("Hello " + testReExec3) time.Sleep(1 * time.Second) os.Exit(0) }) - Init() + if reexec.Init() { + // Make sure we exit in case re-exec didn't os.Exit on its own. + os.Exit(0) + } } func TestRegister(t *testing.T) { @@ -69,7 +75,7 @@ func TestRegister(t *testing.T) { t.Errorf("got %q, want %q", r, tc.expectedErr) } }() - Register(tc.name, func() {}) + reexec.Register(tc.name, func() {}) }) } } @@ -98,7 +104,7 @@ func TestCommand(t *testing.T) { } for _, tc := range tests { t.Run(tc.doc, func(t *testing.T) { - cmd := Command(tc.cmdAndArgs...) + cmd := reexec.Command(tc.cmdAndArgs...) if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) { t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs) } @@ -165,7 +171,7 @@ func TestCommandContext(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - cmd := CommandContext(ctx, tc.cmdAndArgs...) + cmd := reexec.CommandContext(ctx, tc.cmdAndArgs...) if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) { t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs) } @@ -194,30 +200,82 @@ func TestCommandContext(t *testing.T) { } } -func TestNaiveSelf(t *testing.T) { - if os.Getenv("TEST_CHECK") == "1" { - os.Exit(2) - } - cmd := exec.Command(naiveSelf(), "-test.run=TestNaiveSelf") - cmd.Env = append(os.Environ(), "TEST_CHECK=1") - err := cmd.Start() +// TestRunNaiveSelf verifies that reexec.Self() (and thus CommandContext) +// can resolve a path that can be used to re-execute the current test binary +// when it falls back to the argv[0]-based implementation. +// +// It forces Self() to bypass the Linux /proc/self/exe fast-path via +// [reexecoverride.OverrideArgv0] so that the fallback logic is exercised +// consistently across platforms. +func TestRunNaiveSelf(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Force Self() to use naiveSelf(os.Args[0]), instead of "/proc/self/exe" on Linux. + reexecoverride.OverrideArgv0(t, os.Args[0]) + + cmd := reexec.CommandContext(ctx, testReExec2) + out, err := cmd.CombinedOutput() if err != nil { t.Fatalf("Unable to start command: %v", err) } - err = cmd.Wait() - var expError *exec.ExitError - if !errors.As(err, &expError) { - t.Fatalf("got %T, want %T", err, expError) - } - - const expected = "exit status 2" - if err.Error() != expected { - t.Fatalf("got %v, want %v", err, expected) + expOut := "Hello test-reexec2" + actual := strings.TrimSpace(string(out)) + if actual != expOut { + t.Errorf("got %v, want %v", actual, expOut) } +} - os.Args[0] = "mkdir" - if naiveSelf() == os.Args[0] { - t.Fatalf("Expected naiveSelf to resolve the location of mkdir") - } +func TestNaiveSelfResolve(t *testing.T) { + t.Run("fast path on Linux", func(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("only supported on Linux") + } + resolved := reexec.Self() + expected := "/proc/self/exe" + if resolved != expected { + t.Errorf("got %v, want %v", resolved, expected) + } + }) + t.Run("resolve in PATH", func(t *testing.T) { + executable := "sh" + if runtime.GOOS == "windows" { + executable = "cmd" + } + reexecoverride.OverrideArgv0(t, executable) + resolved := reexec.Self() + if resolved == executable { + t.Errorf("did not resolve via PATH; got %q", resolved) + } + if !filepath.IsAbs(resolved) { + t.Errorf("expected absolute path; got %q", resolved) + } + }) + t.Run("not in PATH", func(t *testing.T) { + const executable = "some-nonexistent-executable" + reexecoverride.OverrideArgv0(t, executable) + resolved := reexec.Self() + want, _ := filepath.Abs(executable) + if resolved != want { + t.Errorf("expected absolute path; got %q, want %q", resolved, want) + } + }) + t.Run("relative path", func(t *testing.T) { + executable := filepath.Join(".", "some-executable") + reexecoverride.OverrideArgv0(t, executable) + resolved := reexec.Self() + want, _ := filepath.Abs(executable) + if resolved != want { + t.Errorf("expected absolute path; got %q, want %q", resolved, want) + } + }) + t.Run("absolute path unchanged", func(t *testing.T) { + executable := filepath.Join(os.TempDir(), "some-executable") + reexecoverride.OverrideArgv0(t, executable) + resolved := reexec.Self() + if resolved != executable { + t.Errorf("should not modify absolute paths; got %q, want %q", resolved, executable) + } + }) }