diff --git a/reexec/reexectest/reexectest.go b/reexec/reexectest/reexectest.go new file mode 100644 index 0000000..8bd32aa --- /dev/null +++ b/reexec/reexectest/reexectest.go @@ -0,0 +1,121 @@ +// Package reexectest provides small helpers for subprocess tests that re-exec +// the current test binary. The child process is selected by setting argv0 to a +// deterministic token derived from (t.Name(), name), while -test.run is used to +// run only the current test/subtest. +// +// Typical usage: +// +// func TestSomething(t *testing.T) { +// if reexectest.Run(t, "child", func(t *testing.T) { +// // child branch +// }) { +// return +// } +// +// // parent branch +// cmd := reexectest.CommandContext(t, t.Context(), "child", "arg1") +// out, err := cmd.CombinedOutput() +// if err != nil { +// t.Fatalf("child failed: %v\n%s", err, out) +// } +// } +package reexectest + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "os" + "os/exec" + "regexp" + "strings" + "testing" +) + +const argv0Prefix = "reexectest-" + +// argv0Token returns a short, portable, deterministic argv0 token. +func argv0Token(t *testing.T, name string) string { + sum := sha256.Sum256([]byte(t.Name() + "\x00" + name)) + // 8 bytes => 16 hex chars; plenty to avoid collisions for test usage. + return argv0Prefix + hex.EncodeToString(sum[:8]) +} + +// Run runs f in the current process iff it is the matching child process for +// (t, name). It returns true if f ran (i.e., we are the child). +// +// When Run returns true, callers should return from the test to avoid running +// the parent branch in the child process. +func Run(t *testing.T, name string, f func(t *testing.T)) bool { + t.Helper() + + if os.Args[0] != argv0Token(t, name) { + return false + } + + // Scrub the "-test.run=" that was injected by CommandContext + origArgs := os.Args + if len(os.Args) > 1 && strings.HasPrefix(os.Args[1], "-test.run=") { + os.Args = append(os.Args[:1], os.Args[2:]...) + defer func() { os.Args = origArgs }() + } + + f(t) + return true +} + +// Command returns an [*exec.Cmd] configured to re-exec the current test binary +// as a subprocess for the given test and name. +// +// The child process is restricted to run only the current test or subtest +// (via -test.run). Its argv[0] is set to a deterministic token derived from +// (t.Name(), name), which is used by [Run] to select the child execution path. +// +// On Linux, the returned command sets [syscall.SysProcAttr.Pdeathsig] to +// SIGTERM, so the child receives SIGTERM if the creating thread dies. +// Callers may modify SysProcAttr before starting the command. +// +// It is analogous to [exec.Command], but targets the current test binary. +func Command(t *testing.T, name string, args ...string) *exec.Cmd { + t.Helper() + + exe, err := os.Executable() + if err != nil { + t.Fatalf("os.Executable(): %v", err) + } + + argv0 := argv0Token(t, name) + pattern := "^" + regexp.QuoteMeta(t.Name()) + "$" + + cmd := exec.Command(exe) + cmd.Path = exe + cmd.Args = append([]string{argv0, "-test.run=" + pattern}, args...) + setPdeathsig(cmd) + return cmd +} + +// CommandContext is like [Command] but includes a context. It uses +// [exec.CommandContext] under the hood. +// +// The provided context controls cancellation of the subprocess in the same +// way as [exec.CommandContext]. +// +// On Linux, the returned command sets [syscall.SysProcAttr.Pdeathsig] to +// SIGTERM. Callers may modify SysProcAttr before starting the command. +func CommandContext(t *testing.T, ctx context.Context, name string, args ...string) *exec.Cmd { + t.Helper() + + exe, err := os.Executable() + if err != nil { + t.Fatalf("os.Executable(): %v", err) + } + + argv0 := argv0Token(t, name) + pattern := "^" + regexp.QuoteMeta(t.Name()) + "$" + + cmd := exec.CommandContext(ctx, exe) + cmd.Path = exe + cmd.Args = append([]string{argv0, "-test.run=" + pattern}, args...) + setPdeathsig(cmd) + return cmd +} diff --git a/reexec/reexectest/reexectest_linux.go b/reexec/reexectest/reexectest_linux.go new file mode 100644 index 0000000..944b088 --- /dev/null +++ b/reexec/reexectest/reexectest_linux.go @@ -0,0 +1,14 @@ +//go:build linux + +package reexectest + +import ( + "os/exec" + "syscall" +) + +func setPdeathsig(cmd *exec.Cmd) { + if cmd.SysProcAttr == nil { + cmd.SysProcAttr = &syscall.SysProcAttr{Pdeathsig: syscall.SIGTERM} + } +} diff --git a/reexec/reexectest/reexectest_other.go b/reexec/reexectest/reexectest_other.go new file mode 100644 index 0000000..6046f89 --- /dev/null +++ b/reexec/reexectest/reexectest_other.go @@ -0,0 +1,7 @@ +//go:build !linux + +package reexectest + +import "os/exec" + +func setPdeathsig(*exec.Cmd) {} diff --git a/reexec/reexectest/reexectest_test.go b/reexec/reexectest/reexectest_test.go new file mode 100644 index 0000000..2206548 --- /dev/null +++ b/reexec/reexectest/reexectest_test.go @@ -0,0 +1,101 @@ +package reexectest_test + +import ( + "errors" + "fmt" + "os" + "os/exec" + "reflect" + "strings" + "testing" + + "github.com/moby/sys/reexec/reexectest" +) + +func TestRun(t *testing.T) { + t.Run("env-and-output", func(t *testing.T) { + const expected = "child-env-and-output-ok" + if reexectest.Run(t, "env-and-output", func(t *testing.T) { + if got := os.Getenv("REEXEC_TEST_HELLO"); got != "world" { + t.Fatalf("env REEXEC_TEST_HELLO: got %q, want %q", got, "world") + } + fmt.Println(expected) + }) { + return + } + + cmd := reexectest.CommandContext(t, t.Context(), "env-and-output") + cmd.Env = append(cmd.Environ(), "REEXEC_TEST_HELLO=world") + + out, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("env-and-output child failed: %v\n%s", err, out) + } + if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected { + t.Errorf("env-and-output output: got %q, want %q", got, expected) + } + }) + + t.Run("exit-code", func(t *testing.T) { + if reexectest.Run(t, "exit-code", func(t *testing.T) { + os.Exit(23) + }) { + return + } + + cmd := reexectest.CommandContext(t, t.Context(), "exit-code") + err := cmd.Run() + if err == nil { + t.Fatalf("expected non-nil error") + } + + var ee *exec.ExitError + if !errors.As(err, &ee) { + t.Fatalf("got %T, want *exec.ExitError", err) + } + if code := ee.ProcessState.ExitCode(); code != 23 { + t.Fatalf("exit code: got %d, want %d", code, 23) + } + }) + + t.Run("args-passthrough", func(t *testing.T) { + const expected = "child-args-passthrough-ok" + if reexectest.Run(t, "args-passthrough", func(t *testing.T) { + want := []string{"hello", "world"} + got := os.Args[1:] + if !reflect.DeepEqual(got, want) { + t.Fatalf("args: got %q, want %q (full os.Args=%q)", got, want, os.Args) + } + fmt.Println(expected) + }) { + return + } + + cmd := reexectest.CommandContext(t, t.Context(), "args-passthrough", "hello", "world") + out, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("args-passthrough child failed: %v\n%s", err, out) + } + if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected { + t.Errorf("args-passthrough output: got %q, want %q", got, expected) + } + }) +} + +func TestRunNonSubtest(t *testing.T) { + const expected = "child-non-sub-test-ok" + if reexectest.Run(t, "non-sub-test", func(t *testing.T) { + fmt.Println(expected) + }) { + return + } + + cmd := reexectest.CommandContext(t, t.Context(), "non-sub-test") + out, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("non-sub-test child failed: %v\n%s", err, out) + } + if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected { + t.Errorf("non-sub-test output: got %q, want %q", got, expected) + } +}