Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions reexec/example_multicall_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package reexec_test

import (
"fmt"
"os"

"github.com/moby/sys/reexec"
)

func init() {
reexec.Register("example-foo", func() {
fmt.Println("Hello from entrypoint example-foo")
})
reexec.Register("example-bar", func() {
fmt.Println("Hello from entrypoint example-bar")
})
}

// Example_multicall demonstrates a BusyBox-style multi-call binary.
//
// In a real multi-call binary:
//
// go build -o example .
// ln -s example example-foo
// ln -s example example-bar
//
// ./example-foo # runs entrypoint "example-foo"
// ./example-bar # runs entrypoint "example-bar"
//
// At process startup, main would call [reexec.Init] and return if it
// matches an entrypoint. This example first shows that call, then emulates
// different invocation names by modifying os.Args[0].
func Example_multicall() {
// What main would normally do:
if reexec.Init() {
// Matched a reexec entrypoint; stop normal main execution.
return
}
reset := os.Args[0]

// Emulate running as "example-foo".
os.Args[0] = "example-foo"
_ = reexec.Init()

// Emulate running as "example-bar".
os.Args[0] = "example-bar"
_ = reexec.Init()

// Emulate running under the default binary name (no match).
os.Args[0] = reset
if !reexec.Init() {
fmt.Println("Hello main")
}

// Output:
// Hello from entrypoint example-foo
// Hello from entrypoint example-bar
// Hello main
}
41 changes: 41 additions & 0 deletions reexec/example_programmatic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package reexec_test

import (
"context"
"fmt"
"os"
"time"

"github.com/moby/sys/reexec"
)

func init() {
reexec.Register("example-child", func() {
fmt.Println("Hello from example-child entrypoint")
})
}

// Example_programmatic demonstrates using reexec to programmatically
// re-execute the current binary.
func Example_programmatic() {
if reexec.Init() {
// Matched a reexec entrypoint; stop normal main execution.
return
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

cmd := reexec.CommandContext(ctx, "example-child")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
fmt.Println("reexec error:", err)
return
}

fmt.Println("Back in parent process")
// Output:
// Hello from example-child entrypoint
// Back in parent process
}
45 changes: 45 additions & 0 deletions reexec/internal/reexecoverride/reexecoverride.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Package reexecoverride provides test utilities for overriding argv0 as
// observed by reexec.Self within the current process.

package reexecoverride

import "sync/atomic"

// argv0Override holds an optional override for os.Args[0] used by reexec.Self.
var argv0Override atomic.Pointer[string]

// Argv0 returns the overridden argv0 if set.
func Argv0() (string, bool) {
p := argv0Override.Load()
if p == nil {
return "", false
}
return *p, true
}

// TestingTB is the minimal subset of [testing.TB] used by this package.
type TestingTB interface {
Helper()
Cleanup(func())
}

// OverrideArgv0 overrides the argv0 value observed by reexec.Self for the
// lifetime of the calling test and restores it via [testing.TB.Cleanup].
//
// The override is process-global. Tests using OverrideArgv0 must not run in
// parallel with other tests that call reexec.Self. OverrideArgv0 panics if an
// override is already active.
func OverrideArgv0(t TestingTB, argv0 string) {
t.Helper()

s := argv0
if !argv0Override.CompareAndSwap(nil, &s) {
panic("testing: test using reexecoverride.OverrideArgv0 cannot use t.Parallel")
}

t.Cleanup(func() {
if !argv0Override.CompareAndSwap(&s, nil) {
panic("testing: cleanup for reexecoverride.OverrideArgv0 detected parallel use of reexec.Self")
}
})
}
12 changes: 9 additions & 3 deletions reexec/reexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"os/exec"
"path/filepath"
"runtime"

"github.com/moby/sys/reexec/internal/reexecoverride"
)

var registeredInitializers = make(map[string]func())
Expand Down Expand Up @@ -78,14 +80,18 @@ func CommandContext(ctx context.Context, args ...string) *exec.Cmd {
// "my-binary" at "/usr/bin/" (or "my-binary.exe" at "C:\" on Windows),
// then it returns "/usr/bin/my-binary" and "C:\my-binary.exe" respectively.
func Self() string {
if argv0, ok := reexecoverride.Argv0(); ok {
return naiveSelf(argv0)
}
if runtime.GOOS == "linux" {
return "/proc/self/exe"
}
return naiveSelf()
return naiveSelf(os.Args[0])
}

func naiveSelf() string {
name := os.Args[0]
// naiveSelf is a separate function to allow testing in isolation on Linux.
func naiveSelf(argv0 string) string {
name := argv0
if filepath.Base(name) == name {
if lp, err := exec.LookPath(name); err == nil {
return lp
Expand Down
116 changes: 87 additions & 29 deletions reexec/reexec_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package reexec
package reexec_test

import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"
"time"

"github.com/moby/sys/reexec"
"github.com/moby/sys/reexec/internal/reexecoverride"
)

const (
Expand All @@ -20,23 +23,26 @@ const (
)

func init() {
Register(testReExec, func() {
reexec.Register(testReExec, func() {
panic("Return Error")
})
Register(testReExec2, func() {
reexec.Register(testReExec2, func() {
var args string
if len(os.Args) > 1 {
args = fmt.Sprintf("(args: %#v)", os.Args[1:])
}
fmt.Println("Hello", testReExec2, args)
os.Exit(0)
})
Register(testReExec3, func() {
reexec.Register(testReExec3, func() {
fmt.Println("Hello " + testReExec3)
time.Sleep(1 * time.Second)
os.Exit(0)
})
Init()
if reexec.Init() {
// Make sure we exit in case re-exec didn't os.Exit on its own.
os.Exit(0)
}
}

func TestRegister(t *testing.T) {
Expand Down Expand Up @@ -69,7 +75,7 @@ func TestRegister(t *testing.T) {
t.Errorf("got %q, want %q", r, tc.expectedErr)
}
}()
Register(tc.name, func() {})
reexec.Register(tc.name, func() {})
})
}
}
Expand Down Expand Up @@ -98,7 +104,7 @@ func TestCommand(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.doc, func(t *testing.T) {
cmd := Command(tc.cmdAndArgs...)
cmd := reexec.Command(tc.cmdAndArgs...)
if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) {
t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs)
}
Expand Down Expand Up @@ -165,7 +171,7 @@ func TestCommandContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

cmd := CommandContext(ctx, tc.cmdAndArgs...)
cmd := reexec.CommandContext(ctx, tc.cmdAndArgs...)
if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) {
t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs)
}
Expand Down Expand Up @@ -194,30 +200,82 @@ func TestCommandContext(t *testing.T) {
}
}

func TestNaiveSelf(t *testing.T) {
if os.Getenv("TEST_CHECK") == "1" {
os.Exit(2)
}
cmd := exec.Command(naiveSelf(), "-test.run=TestNaiveSelf")
cmd.Env = append(os.Environ(), "TEST_CHECK=1")
err := cmd.Start()
// TestRunNaiveSelf verifies that reexec.Self() (and thus CommandContext)
// can resolve a path that can be used to re-execute the current test binary
// when it falls back to the argv[0]-based implementation.
//
// It forces Self() to bypass the Linux /proc/self/exe fast-path via
// [reexecoverride.OverrideArgv0] so that the fallback logic is exercised
// consistently across platforms.
func TestRunNaiveSelf(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// Force Self() to use naiveSelf(os.Args[0]), instead of "/proc/self/exe" on Linux.
reexecoverride.OverrideArgv0(t, os.Args[0])

cmd := reexec.CommandContext(ctx, testReExec2)
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("Unable to start command: %v", err)
}
err = cmd.Wait()

var expError *exec.ExitError
if !errors.As(err, &expError) {
t.Fatalf("got %T, want %T", err, expError)
}

const expected = "exit status 2"
if err.Error() != expected {
t.Fatalf("got %v, want %v", err, expected)
expOut := "Hello test-reexec2"
actual := strings.TrimSpace(string(out))
if actual != expOut {
t.Errorf("got %v, want %v", actual, expOut)
}
}

os.Args[0] = "mkdir"
if naiveSelf() == os.Args[0] {
t.Fatalf("Expected naiveSelf to resolve the location of mkdir")
}
func TestNaiveSelfResolve(t *testing.T) {
t.Run("fast path on Linux", func(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("only supported on Linux")
}
resolved := reexec.Self()
expected := "/proc/self/exe"
if resolved != expected {
t.Errorf("got %v, want %v", resolved, expected)
}
})
t.Run("resolve in PATH", func(t *testing.T) {
executable := "sh"
if runtime.GOOS == "windows" {
executable = "cmd"
}
reexecoverride.OverrideArgv0(t, executable)
resolved := reexec.Self()
if resolved == executable {
t.Errorf("did not resolve via PATH; got %q", resolved)
}
if !filepath.IsAbs(resolved) {
t.Errorf("expected absolute path; got %q", resolved)
}
})
t.Run("not in PATH", func(t *testing.T) {
const executable = "some-nonexistent-executable"
reexecoverride.OverrideArgv0(t, executable)
resolved := reexec.Self()
want, _ := filepath.Abs(executable)
if resolved != want {
t.Errorf("expected absolute path; got %q, want %q", resolved, want)
}
})
t.Run("relative path", func(t *testing.T) {
executable := filepath.Join(".", "some-executable")
reexecoverride.OverrideArgv0(t, executable)
resolved := reexec.Self()
want, _ := filepath.Abs(executable)
if resolved != want {
t.Errorf("expected absolute path; got %q, want %q", resolved, want)
}
})
t.Run("absolute path unchanged", func(t *testing.T) {
executable := filepath.Join(os.TempDir(), "some-executable")
reexecoverride.OverrideArgv0(t, executable)
resolved := reexec.Self()
if resolved != executable {
t.Errorf("should not modify absolute paths; got %q, want %q", resolved, executable)
}
})
}
Loading