Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions reexec/reexectest/reexectest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Package reexectest provides small helpers for subprocess tests that re-exec
// the current test binary. The child process is selected by setting argv0 to a
// deterministic token derived from (t.Name(), name), while -test.run is used to
// run only the current test/subtest.
//
// Typical usage:
//
// func TestSomething(t *testing.T) {
// if reexectest.Run(t, "child", func(t *testing.T) {
// // child branch
// }) {
// return
// }
//
// // parent branch
// cmd := reexectest.CommandContext(t, t.Context(), "child", "arg1")
// out, err := cmd.CombinedOutput()
// if err != nil {
// t.Fatalf("child failed: %v\n%s", err, out)
// }
// }
package reexectest

import (
"context"
"crypto/sha256"
"encoding/hex"
"os"
"os/exec"
"regexp"
"strings"
"testing"
)

const argv0Prefix = "reexectest-"

// argv0Token returns a short, portable, deterministic argv0 token.
func argv0Token(t *testing.T, name string) string {
sum := sha256.Sum256([]byte(t.Name() + "\x00" + name))
// 8 bytes => 16 hex chars; plenty to avoid collisions for test usage.
return argv0Prefix + hex.EncodeToString(sum[:8])
}

// Run runs f in the current process iff it is the matching child process for
// (t, name). It returns true if f ran (i.e., we are the child).
//
// When Run returns true, callers should return from the test to avoid running
// the parent branch in the child process.
func Run(t *testing.T, name string, f func(t *testing.T)) bool {
t.Helper()

if os.Args[0] != argv0Token(t, name) {
return false
}

// Scrub the "-test.run=<pattern>" that was injected by CommandContext
origArgs := os.Args
if len(os.Args) > 1 && strings.HasPrefix(os.Args[1], "-test.run=") {
os.Args = append(os.Args[:1], os.Args[2:]...)
defer func() { os.Args = origArgs }()
}

f(t)
return true
}

// Command returns an [*exec.Cmd] configured to re-exec the current test binary
// as a subprocess for the given test and name.
//
// The child process is restricted to run only the current test or subtest
// (via -test.run). Its argv[0] is set to a deterministic token derived from
// (t.Name(), name), which is used by [Run] to select the child execution path.
//
// On Linux, the returned command sets [syscall.SysProcAttr.Pdeathsig] to
// SIGTERM, so the child receives SIGTERM if the creating thread dies.
// Callers may modify SysProcAttr before starting the command.
//
// It is analogous to [exec.Command], but targets the current test binary.
func Command(t *testing.T, name string, args ...string) *exec.Cmd {
t.Helper()

exe, err := os.Executable()
if err != nil {
t.Fatalf("os.Executable(): %v", err)
}

argv0 := argv0Token(t, name)
pattern := "^" + regexp.QuoteMeta(t.Name()) + "$"

cmd := exec.Command(exe)
cmd.Path = exe
cmd.Args = append([]string{argv0, "-test.run=" + pattern}, args...)
setPdeathsig(cmd)
return cmd
}

// CommandContext is like [Command] but includes a context. It uses
// [exec.CommandContext] under the hood.
//
// The provided context controls cancellation of the subprocess in the same
// way as [exec.CommandContext].
//
// On Linux, the returned command sets [syscall.SysProcAttr.Pdeathsig] to
// SIGTERM. Callers may modify SysProcAttr before starting the command.
func CommandContext(t *testing.T, ctx context.Context, name string, args ...string) *exec.Cmd {
t.Helper()

exe, err := os.Executable()
if err != nil {
t.Fatalf("os.Executable(): %v", err)
}

argv0 := argv0Token(t, name)
pattern := "^" + regexp.QuoteMeta(t.Name()) + "$"

cmd := exec.CommandContext(ctx, exe)
cmd.Path = exe
cmd.Args = append([]string{argv0, "-test.run=" + pattern}, args...)
setPdeathsig(cmd)
return cmd
}
14 changes: 14 additions & 0 deletions reexec/reexectest/reexectest_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build linux

package reexectest

import (
"os/exec"
"syscall"
)

func setPdeathsig(cmd *exec.Cmd) {
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{Pdeathsig: syscall.SIGTERM}
}
}
7 changes: 7 additions & 0 deletions reexec/reexectest/reexectest_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !linux

package reexectest

import "os/exec"

func setPdeathsig(*exec.Cmd) {}
101 changes: 101 additions & 0 deletions reexec/reexectest/reexectest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package reexectest_test

import (
"errors"
"fmt"
"os"
"os/exec"
"reflect"
"strings"
"testing"

"github.com/moby/sys/reexec/reexectest"
)

func TestRun(t *testing.T) {
t.Run("env-and-output", func(t *testing.T) {
const expected = "child-env-and-output-ok"
if reexectest.Run(t, "env-and-output", func(t *testing.T) {
if got := os.Getenv("REEXEC_TEST_HELLO"); got != "world" {
t.Fatalf("env REEXEC_TEST_HELLO: got %q, want %q", got, "world")
}
fmt.Println(expected)
}) {
return
}

cmd := reexectest.CommandContext(t, t.Context(), "env-and-output")
cmd.Env = append(cmd.Environ(), "REEXEC_TEST_HELLO=world")

out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("env-and-output child failed: %v\n%s", err, out)
}
if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected {
t.Errorf("env-and-output output: got %q, want %q", got, expected)
}
})

t.Run("exit-code", func(t *testing.T) {
if reexectest.Run(t, "exit-code", func(t *testing.T) {
os.Exit(23)
}) {
return
}

cmd := reexectest.CommandContext(t, t.Context(), "exit-code")
err := cmd.Run()
if err == nil {
t.Fatalf("expected non-nil error")
}

var ee *exec.ExitError
if !errors.As(err, &ee) {
t.Fatalf("got %T, want *exec.ExitError", err)
}
if code := ee.ProcessState.ExitCode(); code != 23 {
t.Fatalf("exit code: got %d, want %d", code, 23)
}
})

t.Run("args-passthrough", func(t *testing.T) {
const expected = "child-args-passthrough-ok"
if reexectest.Run(t, "args-passthrough", func(t *testing.T) {
want := []string{"hello", "world"}
got := os.Args[1:]
if !reflect.DeepEqual(got, want) {
t.Fatalf("args: got %q, want %q (full os.Args=%q)", got, want, os.Args)
}
fmt.Println(expected)
}) {
return
}

cmd := reexectest.CommandContext(t, t.Context(), "args-passthrough", "hello", "world")
out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("args-passthrough child failed: %v\n%s", err, out)
}
if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected {
t.Errorf("args-passthrough output: got %q, want %q", got, expected)
}
})
}

func TestRunNonSubtest(t *testing.T) {
const expected = "child-non-sub-test-ok"
if reexectest.Run(t, "non-sub-test", func(t *testing.T) {
fmt.Println(expected)
}) {
return
}

cmd := reexectest.CommandContext(t, t.Context(), "non-sub-test")
out, err := cmd.CombinedOutput()
if err != nil {
t.Errorf("non-sub-test child failed: %v\n%s", err, out)
}
if got := strings.TrimSpace(strings.TrimSuffix(string(out), "PASS\n")); got != expected {
t.Errorf("non-sub-test output: got %q, want %q", got, expected)
}
}
Loading