diff --git a/.github/actions/start-services/action.yml b/.github/actions/start-services/action.yml index e2c01cd3b6..5365085344 100644 --- a/.github/actions/start-services/action.yml +++ b/.github/actions/start-services/action.yml @@ -101,7 +101,6 @@ runs: STORAGE_PROVIDER: "Local" ENVIRONMENT: "local" OTEL_COLLECTOR_GRPC_ENDPOINT: "localhost:4317" - MAX_PARALLEL_MEMFILE_SNAPSHOTTING: "2" SHARED_CHUNK_CACHE_PATH: "./.e2b-chunk-cache" EDGE_TOKEN: "abdcdefghijklmnop" run: | diff --git a/packages/orchestrator/Makefile b/packages/orchestrator/Makefile index 78afc911cf..9d76ae49d8 100644 --- a/packages/orchestrator/Makefile +++ b/packages/orchestrator/Makefile @@ -49,7 +49,6 @@ run-debug: GCP_DOCKER_REPOSITORY_NAME=$(GCP_DOCKER_REPOSITORY_NAME) \ GOOGLE_SERVICE_ACCOUNT_BASE64=$(GOOGLE_SERVICE_ACCOUNT_BASE64) \ OTEL_COLLECTOR_GRPC_ENDPOINT=$(OTEL_COLLECTOR_GRPC_ENDPOINT) \ - MAX_PARALLEL_MEMFILE_SNAPSHOTTING=$(MAX_PARALLEL_MEMFILE_SNAPSHOTTING) \ ./bin/orchestrator define setup_local_env diff --git a/packages/orchestrator/benchmark_test.go b/packages/orchestrator/benchmark_test.go index edda005a76..02ff4bc0ba 100644 --- a/packages/orchestrator/benchmark_test.go +++ b/packages/orchestrator/benchmark_test.go @@ -304,9 +304,8 @@ func BenchmarkBaseImageLaunch(b *testing.B) { type testCycle string const ( - onlyStart testCycle = "only-start" - startAndPause testCycle = "start-and-pause" - startPauseResume testCycle = "start-pause-resume" + onlyStart testCycle = "only-start" + startAndPause testCycle = "start-and-pause" ) type testContainer struct { diff --git a/packages/orchestrator/internal/sandbox/block/cache_test.go b/packages/orchestrator/internal/sandbox/block/cache_test.go index 911a8a2cf4..c187f25a04 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_test.go @@ -1,9 +1,7 @@ package block import ( - "bytes" "crypto/rand" - "fmt" "math" "os" "syscall" @@ -60,7 +58,7 @@ func 
TestCopyFromProcess_FullRange(t *testing.T) { require.NoError(t, err) require.Equal(t, int(size), n) - require.NoError(t, compareData(data[:n], mem[:n])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[:n], data[:n])) } func TestCopyFromProcess_LargeRanges(t *testing.T) { @@ -94,19 +92,19 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { n, err := cache.ReadAt(data1, 0) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data1[:n], mem[0:pageSize])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[0:pageSize], data1[:n])) data2 := make([]byte, pageSize) n, err = cache.ReadAt(data2, int64(pageSize)) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data2[:n], mem[pageSize*3:pageSize*4])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize*3:pageSize*4], data2[:n])) data3 := make([]byte, pageSize) n, err = cache.ReadAt(data3, int64(pageSize*2)) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data3[:n], mem[pageSize:pageSize*2])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize:pageSize*2], data3[:n])) } func TestCopyFromProcess_MultipleRanges(t *testing.T) { @@ -152,7 +150,7 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data[:n], mem[alignedOffset:alignedOffset+int64(pageSize)])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[alignedOffset:alignedOffset+int64(pageSize)], data[:n])) } } @@ -192,13 +190,13 @@ func TestCopyFromProcess_HugepageToRegularPage(t *testing.T) { n, err = cache.ReadAt(data, 0) require.NoError(t, err) require.Equal(t, int(pageSize*2), n) - require.NoError(t, compareData(data[:n], mem[0:pageSize*2])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[0:pageSize*2], data[:n])) data = make([]byte, 
pageSize*4) n, err = cache.ReadAt(data, pageSize*2) require.NoError(t, err) require.Equal(t, int(pageSize*4), n) - require.NoError(t, compareData(data[:n], mem[pageSize*4:pageSize*8])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize*4:pageSize*8], data[:n])) } func TestEmptyRanges(t *testing.T) { @@ -218,17 +216,6 @@ func TestEmptyRanges(t *testing.T) { }) } -func compareData(readBytes []byte, expectedBytes []byte) error { - // The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here. - if !bytes.Equal(readBytes, expectedBytes) { - idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) - - return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) - } - - return nil -} - func TestSplitOversizedRanges(t *testing.T) { t.Parallel() @@ -395,7 +382,7 @@ func TestCopyFromProcess_Exceed_MAX_RW_COUNT(t *testing.T) { n, err = cache.ReadAt(data, 0) require.NoError(t, err) require.Equal(t, int(size), n) - require.NoError(t, compareData(data[:n], mem[0:size])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[0:size], data[:n])) } func BenchmarkCopyFromHugepagesFile(b *testing.B) { diff --git a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go index 688cff5331..f4b20ad5fe 100644 --- a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go +++ b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go @@ -7,6 +7,8 @@ import ( "sync" ) +var ErrFdExit = errors.New("fd exit signal") + // FdExit is a wrapper around a pipe that allows to signal the exit of the uffd. 
type FdExit struct { r *os.File diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go index 68298ea6ea..1acc10a8ed 100644 --- a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go @@ -1,20 +1,31 @@ package testutils +import ( + "errors" + "fmt" +) + // FirstDifferentByte returns the first byte index where a and b differ. // It also returns the differing byte values (want, got). // If slices are identical, it returns idx -1. -func FirstDifferentByte(a, b []byte) (idx int, want, got byte) { - smallerSize := min(len(a), len(b)) +func ErrorFromByteSlicesDifference(expected, actual []byte) error { + var errs []error - for i := range smallerSize { - if a[i] != b[i] { - return i, b[i], a[i] - } + if len(expected) > len(actual) { + errs = append(errs, fmt.Errorf("expected slice (%d bytes) is longer than actual slice (%d bytes)", len(expected), len(actual))) + } else if len(expected) < len(actual) { + errs = append(errs, fmt.Errorf("actual slice (%d bytes) is longer than expected slice (%d bytes)", len(actual), len(expected))) } - if len(a) != len(b) { - return smallerSize, 0, 0 + smallerSize := min(len(expected), len(actual)) + + for i := range smallerSize { + if expected[i] != actual[i] { + errs = append(errs, fmt.Errorf("first different byte: want '%x', got '%x' at index %d", expected[i], actual[i], i)) + + break + } } - return -1, 0, 0 + return errors.Join(errs...) 
} diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index c44c9795e3..9383d1abe0 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -140,7 +140,7 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { m := memory.NewMapping(regions) uffd, err := userfaultfd.NewUserfaultfdFromFd( - uintptr(fds[0]), + userfaultfd.Fd(fds[0]), u.memfile, m, logger.L().With(logger.WithSandboxID(sandboxId)), @@ -164,6 +164,9 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { ctx, u.fdExit, ) + if errors.Is(err, fdexit.ErrFdExit) { + return nil + } if err != nil { return fmt.Errorf("failed handling uffd: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index 60c2315a69..15f2dee4bd 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -17,7 +17,6 @@ import ( "os/exec" "os/signal" "strconv" - "strings" "syscall" "testing" @@ -99,10 +98,12 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) - err = register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) + err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) - cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess") + // We don't use t.Context() here, because we want to be able to kill the process manually and listen to the correct exit code, + // while also handling the cleanup of the uffd. The t.Context seems to trigger before the test cleanup is started. 
+ cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperServingProcess") cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) @@ -127,11 +128,18 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error assert.NoError(t, closeErr) }() - offsetsReader, offsetsWriter, err := os.Pipe() + accessedOffsetsReader, accessedOffsetsWriter, err := os.Pipe() require.NoError(t, err) t.Cleanup(func() { - offsetsReader.Close() + accessedOffsetsReader.Close() + }) + + dirtyOffsetsReader, dirtyOffsetsWriter, err := os.Pipe() + require.NoError(t, err) + + t.Cleanup(func() { + dirtyOffsetsReader.Close() }) readyReader, readyWriter, err := os.Pipe() @@ -152,8 +160,9 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error cmd.ExtraFiles = []*os.File{ uffdFile, contentReader, - offsetsWriter, + accessedOffsetsWriter, readyWriter, + dirtyOffsetsWriter, } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -162,36 +171,59 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error require.NoError(t, err) contentReader.Close() - offsetsWriter.Close() + accessedOffsetsWriter.Close() readyWriter.Close() uffdFile.Close() + dirtyOffsetsWriter.Close() + + go func() { + waitErr := cmd.Wait() + assert.NoError(t, waitErr) + + assert.NotEqual(t, -1, cmd.ProcessState.ExitCode(), "process was not terminated gracefully") + assert.NotEqual(t, 2, cmd.ProcessState.ExitCode(), "fd exit prematurely terminated the serve loop") + assert.NotEqual(t, 1, cmd.ProcessState.ExitCode(), "process exited with unexpected exit code") + + assert.Equal(t, 0, cmd.ProcessState.ExitCode()) + }() t.Cleanup(func() { - signalErr := cmd.Process.Signal(syscall.SIGUSR1) + // We are using SIGTERM to actually get exit code, not -1. 
+ signalErr := cmd.Process.Signal(syscall.SIGTERM) assert.NoError(t, signalErr) - - waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) }) - offsetsOnce := func() ([]uint, error) { + accessedOffsetsOnce := func() ([]uint, error) { err := cmd.Process.Signal(syscall.SIGUSR2) if err != nil { return nil, err } - offsetsBytes, err := io.ReadAll(offsetsReader) + offsetsBytes, err := io.ReadAll(accessedOffsetsReader) + if err != nil { + return nil, err + } + + var offsetList []uint + + if len(offsetsBytes)%8 != 0 { + return nil, fmt.Errorf("invalid offsets bytes length: %d", len(offsetsBytes)) + } + + for i := 0; i < len(offsetsBytes); i += 8 { + offsetList = append(offsetList, uint(binary.LittleEndian.Uint64(offsetsBytes[i:i+8]))) + } + + return offsetList, nil + } + + dirtyOffsetsOnce := func() ([]uint, error) { + err := cmd.Process.Signal(syscall.SIGUSR1) + if err != nil { + return nil, err + } + + offsetsBytes, err := io.ReadAll(dirtyOffsetsReader) if err != nil { return nil, err } @@ -215,23 +247,38 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error case <-readySignal: } + mapping := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: 0, + PageSize: uintptr(tt.pagesize), + }, + }) + return &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - offsetsOnce: offsetsOnce, + memoryArea: &memoryArea, + pagesize: tt.pagesize, + data: data, + accessedOffsetsOnce: accessedOffsetsOnce, + mapping: mapping, + dirtyOffsetsOnce: dirtyOffsetsOnce, }, nil } -// Secondary process, orchestrator in our 
case +// Secondary process, orchestrator in our case. func TestHelperServingProcess(t *testing.T) { if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { t.Skip("this is a helper process, skipping direct execution") } err := crossProcessServe() + if errors.Is(err, fdexit.ErrFdExit) { + os.Exit(2) + } + if err != nil { - fmt.Println("exit serving process", err) + fmt.Fprintf(os.Stderr, "error serving: %v", err) os.Exit(1) } @@ -286,29 +333,61 @@ func crossProcessServe() error { return fmt.Errorf("exit creating logger: %w", err) } - uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l) + uffd, err := NewUserfaultfdFromFd(Fd(uffdFd), data, m, l) if err != nil { return fmt.Errorf("exit creating uffd: %w", err) } - offsetsFile := os.NewFile(uintptr(5), "offsets") + accessedOffsetsFile := os.NewFile(uintptr(5), "accessed-offsets") - offsetsSignal := make(chan os.Signal, 1) - signal.Notify(offsetsSignal, syscall.SIGUSR2) - defer signal.Stop(offsetsSignal) + accessedOffsestsSignal := make(chan os.Signal, 1) + signal.Notify(accessedOffsestsSignal, syscall.SIGUSR2) + defer signal.Stop(accessedOffsestsSignal) go func() { - defer offsetsFile.Close() + defer accessedOffsetsFile.Close() for { select { case <-ctx.Done(): return - case <-offsetsSignal: + case <-accessedOffsestsSignal: + for offset := range accessed(uffd).Offsets() { + writeErr := binary.Write(accessedOffsetsFile, binary.LittleEndian, uint64(offset)) + if writeErr != nil { + msg := fmt.Errorf("error writing accessed offsets to file: %w", writeErr) + + fmt.Fprint(os.Stderr, msg.Error()) + + cancel(msg) + + return + } + } + + return + } + } + }() + + dirtyOffsetsFile := os.NewFile(uintptr(7), "dirty-offsets") + + dirtyOffsetsSignal := make(chan os.Signal, 1) + signal.Notify(dirtyOffsetsSignal, syscall.SIGUSR1) + defer signal.Stop(dirtyOffsetsSignal) + + go func() { + defer dirtyOffsetsFile.Close() + + for { + select { + case <-ctx.Done(): + return + case <-dirtyOffsetsSignal: for offset := range uffd.Dirty().Offsets() { - 
writeErr := binary.Write(offsetsFile, binary.LittleEndian, uint64(offset)) + writeErr := binary.Write(dirtyOffsetsFile, binary.LittleEndian, uint64(offset)) if writeErr != nil { - msg := fmt.Errorf("error writing offsets to file: %w", writeErr) + msg := fmt.Errorf("error writing dirty offsets to file: %w", writeErr) fmt.Fprint(os.Stderr, msg.Error()) @@ -335,6 +414,14 @@ func crossProcessServe() error { }() serverErr := uffd.Serve(ctx, fdExit) + if errors.Is(serverErr, fdexit.ErrFdExit) { + err := fmt.Errorf("serving finished via fd exit: %w", serverErr) + + cancel(err) + + return + } + if serverErr != nil { msg := fmt.Errorf("error serving: %w", serverErr) @@ -344,6 +431,8 @@ func crossProcessServe() error { return } + + fmt.Fprint(os.Stderr, "serving finished") }() cleanup := func() { @@ -364,7 +453,7 @@ func crossProcessServe() error { defer cleanup() exitSignal := make(chan os.Signal, 1) - signal.Notify(exitSignal, syscall.SIGUSR1) + signal.Notify(exitSignal, syscall.SIGTERM) defer signal.Stop(exitSignal) readyFile := os.NewFile(uintptr(6), "ready") @@ -376,7 +465,7 @@ func crossProcessServe() error { select { case <-ctx.Done(): - return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx)) + return context.Cause(ctx) case <-exitSignal: return nil } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 1a122d01d5..e5145ad060 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -32,12 +32,19 @@ const ( UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING + UFFDIO_REGISTER_MODE_WP = C.UFFDIO_REGISTER_MODE_WP - UFFDIO_API = C.UFFDIO_API - UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT_MODE_WP = C.UFFDIO_WRITEPROTECT_MODE_WP + UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP + + 
UFFDIO_API = C.UFFDIO_API + UFFDIO_REGISTER = C.UFFDIO_REGISTER + UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER + UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE + UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS ) @@ -78,6 +85,13 @@ func newUffdioRegister(start, length, mode CULong) UffdioRegister { } } +func newUffdioWriteProtect(start, length, mode CULong) UffdioWriteProtect { + return UffdioWriteProtect{ + _range: newUffdioRange(start, length), + mode: mode, + } +} + func newUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { return UffdioCopy{ src: CULong(uintptr(unsafe.Pointer(&b[0]))), @@ -120,6 +134,37 @@ func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { return nil } +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (f Fd) register(addr uintptr, size uint64, mode CULong) error { + register := newUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// mode: UFFDIO_WRITEPROTECT_MODE_WP +// Passing 0 as the mode will remove the write protection. 
+func (f Fd) writeProtect(addr, size uintptr, mode CULong) error { + register := newUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + func (f Fd) close() error { return syscall.Close(int(f)) } + +func (f Fd) fd() int32 { + return int32(f) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go index f9f8d953cb..cfe7619399 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -8,6 +8,67 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// mockFd is a mock implementation of the Fd interface. +// It allows us to test the handling methods separately from the actual uffd serve loop. +type mockFd struct { + // The channels send back the info about the uffd handled operations + // and also allows us to block the methods to test the flow. + copyCh chan *blockedEvent[UffdioCopy] + writeProtectCh chan *blockedEvent[UffdioWriteProtect] +} + +func newMockFd() *mockFd { + return &mockFd{ + copyCh: make(chan *blockedEvent[UffdioCopy]), + writeProtectCh: make(chan *blockedEvent[UffdioWriteProtect]), + } +} + +func (m *mockFd) register(_ uintptr, _ uint64, _ CULong) error { + return nil +} + +func (m *mockFd) copy(addr, pagesize uintptr, _ []byte, mode CULong) error { + // Don't use the uffdioCopy constructor as it unsafely checks slice address and fails for arbitrary pointer. 
+ e := newBlockedEvent(UffdioCopy{ + src: 0, + dst: CULong(addr), + len: CULong(pagesize), + mode: mode, + copy: 0, + }) + + m.copyCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) writeProtect(addr, size uintptr, mode CULong) error { + e := newBlockedEvent(UffdioWriteProtect{ + _range: newUffdioRange( + CULong(addr), + CULong(size), + ), + mode: mode, + }) + + m.writeProtectCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) close() error { + return nil +} + +func (m *mockFd) fd() int32 { + return 0 +} + // Used for testing. // flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK func newFd(flags uintptr) (Fd, error) { @@ -39,16 +100,19 @@ func configureApi(f Fd, pagesize uint64) error { return nil } -// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING -// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING -// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func register(f Fd, addr uintptr, size uint64, mode CULong) error { - register := newUffdioRegister(CULong(addr), CULong(size), mode) +// This wrapped event allows us to simulate the finish of processing of events by FC on FC API Pause. 
+type blockedEvent[T UffdioCopy | UffdioWriteProtect] struct { + event T + resolved chan struct{} +} - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) - if errno != 0 { - return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) +func newBlockedEvent[T UffdioCopy | UffdioWriteProtect](event T) *blockedEvent[T] { + return &blockedEvent[T]{ + event: event, + resolved: make(chan struct{}), } +} - return nil +func (e *blockedEvent[T]) resolve() { + close(e.resolved) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go index 467ab423b4..95504bdd51 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go @@ -9,6 +9,8 @@ import ( "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" ) @@ -41,8 +43,26 @@ type testHandler struct { data *MemorySlicer // Returns offsets of the pages that were faulted. // It can only be called once. - offsetsOnce func() ([]uint, error) - mutex sync.Mutex + // Sorted in ascending order. + accessedOffsetsOnce func() ([]uint, error) + // Returns offsets of the pages that were dirtied. + // It can only be called once. + // Sorted in ascending order. 
+ dirtyOffsetsOnce func() ([]uint, error) + + mutex sync.Mutex + mapping *memory.Mapping +} + +func (h *testHandler) executeOperation(ctx context.Context, op operation) error { + switch op.mode { + case operationModeRead: + return h.executeRead(ctx, op) + case operationModeWrite: + return h.executeWrite(ctx, op) + default: + return fmt.Errorf("invalid operation mode: %d", op.mode) + } } func (h *testHandler) executeRead(ctx context.Context, op operation) error { @@ -55,9 +75,7 @@ func (h *testHandler) executeRead(ctx context.Context, op operation) error { // The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here. if !bytes.Equal(readBytes, expectedBytes) { - idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) - - return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) + return fmt.Errorf("content mismatch: %w", testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes)) } return nil @@ -82,6 +100,7 @@ func (h *testHandler) executeWrite(ctx context.Context, op operation) error { } // Get a bitset of the offsets of the operations for the given mode. +// Sorted in ascending order. 
func getOperationsOffsets(ops []operation, m operationMode) []uint { b := bitset.New(0) @@ -93,3 +112,10 @@ func getOperationsOffsets(ops []operation, m operationMode) []uint { return slices.Collect(b.EachSet()) } + +func accessed(u *Userfaultfd) *block.Tracker { + u.settleRequests.Lock() + defer u.settleRequests.Unlock() + + return u.missingRequests.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go index 20c8ddeeb3..cf24e55d51 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go @@ -62,6 +62,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "standard 4k page, reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + }, + }, { name: "hugepage, operation at start", pagesize: header.HugepageSize, @@ -110,6 +137,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "hugepage, reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,15 +174,13 @@ func TestMissing(t *testing.T) { require.NoError(t, err) for _, 
operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -158,7 +210,7 @@ func TestParallelMissing(t *testing.T) { for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -167,7 +219,7 @@ func TestParallelMissing(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -191,14 +243,14 @@ func TestParallelMissingWithPrefault(t *testing.T) { mode: operationModeRead, } - err = h.executeRead(t.Context(), readOp) + err = h.executeOperation(t.Context(), readOp) require.NoError(t, err) var verr errgroup.Group for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -207,7 +259,7 @@ func TestParallelMissingWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -232,13 +284,13 @@ func 
TestSerialMissing(t *testing.T) { } for range serialOperations { - err := h.executeRead(t.Context(), readOp) + err := h.executeOperation(t.Context(), readOp) require.NoError(t, err) } expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go index 52a73e5fa2..5b987fd95c 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go @@ -10,7 +10,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -func TestMissingWrite(t *testing.T) { +func TestWriteProtection(t *testing.T) { t.Parallel() tests := []testConfig{ @@ -19,6 +19,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, { offset: 0, mode: operationModeWrite, @@ -30,6 +34,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, { offset: 15 * header.PageSize, mode: operationModeWrite, @@ -41,6 +49,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, { offset: 31 * header.PageSize, mode: operationModeWrite, @@ -48,18 +60,50 @@ func TestMissingWrite(t *testing.T) { }, }, { - name: "standard 4k page, read after write", + name: "standard 4k page, writes after reads, varying offsets", pagesize: header.PageSize, 
numberOfPages: 32, operations: []operation{ { - offset: 0, - mode: operationModeWrite, + offset: 4 * header.PageSize, + mode: operationModeRead, }, { - offset: 0, + offset: 5 * header.PageSize, mode: operationModeRead, }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, }, }, { @@ -67,6 +111,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, { offset: 0, mode: operationModeWrite, @@ -78,6 +126,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, { offset: 3 * header.HugepageSize, mode: operationModeWrite, @@ -89,6 +141,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, { offset: 7 * header.HugepageSize, mode: operationModeWrite, @@ -96,18 +152,50 @@ func TestMissingWrite(t *testing.T) { }, }, { - name: "hugepage, read after write", + name: "hugepage, writes after reads, varying offsets", pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ { - offset: 0, - mode: operationModeWrite, + offset: 4 * header.HugepageSize, + mode: operationModeRead, }, { - offset: 0, + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: 
operationModeRead, + }, + { + offset: 0 * header.HugepageSize, mode: operationModeRead, }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, }, }, } @@ -120,28 +208,28 @@ func TestMissingWrite(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } - - if operation.mode == operationModeWrite { - err := h.executeWrite(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") }) } } -func TestParallelMissingWrite(t *testing.T) { +func TestParallelWriteProtection(t *testing.T) { t.Parallel() parallelOperations := 1_000_000 @@ -154,41 +242,12 @@ func TestParallelMissingWrite(t *testing.T) { h, err := configureCrossProcessTest(t, tt) require.NoError(t, err) - writeOp := operation{ + readOp := operation{ offset: 0, - 
mode: operationModeWrite, + mode: operationModeRead, } - var verr errgroup.Group - - for range parallelOperations { - verr.Go(func() error { - return h.executeWrite(t.Context(), writeOp) - }) - } - - err = verr.Wait() - require.NoError(t, err) - - expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - - accessedOffsets, err := h.offsetsOnce() - require.NoError(t, err) - - assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") -} - -func TestParallelMissingWriteWithPrefault(t *testing.T) { - t.Parallel() - - parallelOperations := 10_000 - - tt := testConfig{ - pagesize: header.PageSize, - numberOfPages: 2, - } - - h, err := configureCrossProcessTest(t, tt) + err = h.executeOperation(t.Context(), readOp) require.NoError(t, err) writeOp := operation{ @@ -196,9 +255,6 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { mode: operationModeWrite, } - err = h.executeWrite(t.Context(), writeOp) - require.NoError(t, err) - var verr errgroup.Group for range parallelOperations { @@ -212,13 +268,20 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } -func TestSerialMissingWrite(t *testing.T) { +func TestSerialWriteProtection(t *testing.T) { t.Parallel() serialOperations := 10_000 @@ -243,8 +306,15 @@ func TestSerialMissingWrite(t *testing.T) { expectedAccessedOffsets := 
getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index e248cfe581..d3b129aa6f 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -1,5 +1,10 @@ package userfaultfd +// flowchart TD +// A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page] +// A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page] +// E -- write (WP+[WRITE] flag) --> F(remove MODE_WP) --> C + import ( "context" "errors" @@ -22,15 +27,24 @@ const maxRequestsInProgress = 4096 var ErrUnexpectedEventType = errors.New("unexpected event type") +type uffdio interface { + copy(addr, pagesize uintptr, data []byte, mode CULong) error + register(addr uintptr, size uint64, mode CULong) error + writeProtect(addr, size uintptr, mode CULong) error + close() error + fd() int32 +} + type Userfaultfd struct { - fd Fd + uffd uffdio src block.Slicer - ma *memory.Mapping + m *memory.Mapping // We don't skip the already mapped pages, because if the memory is swappable the page *might* under some conditions be mapped out. 
// For hugepages this should not be a problem, but might theoretically happen to normal pages with swap missingRequests *block.Tracker + writeRequests *block.Tracker // We use the settleRequests to guard the missingRequests so we can access a consistent state of the missingRequests after the requests are finished. settleRequests sync.RWMutex @@ -40,20 +54,30 @@ type Userfaultfd struct { } // NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. -func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { +func NewUserfaultfdFromFd(uffd uffdio, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { blockSize := src.BlockSize() for _, region := range m.Regions { if region.PageSize != uintptr(blockSize) { return nil, fmt.Errorf("block size mismatch: %d != %d for region %d", region.PageSize, blockSize, region.BaseHostVirtAddr) } + + // Register the WP for the regions. + // The memory region is already registered (with missing pages in FC), but registering it again with bigger flag subset should merge these registration flags. 
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := uffd.register(region.BaseHostVirtAddr, uint64(region.Size), UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING) + if err != nil { + return nil, fmt.Errorf("failed to reregister memory region with write protection %d-%d: %w", region.Offset, region.Offset+region.Size, err) + } } u := &Userfaultfd{ - fd: Fd(fd), + uffd: uffd, src: src, missingRequests: block.NewTracker(blockSize), - ma: m, + writeRequests: block.NewTracker(blockSize), + m: m, logger: logger, } @@ -66,15 +90,17 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge } func (u *Userfaultfd) Close() error { - return u.fd.close() + return u.uffd.close() } func (u *Userfaultfd) Serve( ctx context.Context, fdExit *fdexit.FdExit, ) error { + uffd := u.uffd.fd() + pollFds := []unix.PollFd{ - {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: uffd, Events: unix.POLLIN}, {Fd: fdExit.Reader(), Events: unix.POLLIN}, } @@ -113,7 +139,7 @@ outerLoop: return fmt.Errorf("failed to handle uffd: %w", errMsg) } - return nil + return fdexit.ErrFdExit } uffdFd := pollFds[0] @@ -135,7 +161,7 @@ outerLoop: buf := make([]byte, unsafe.Sizeof(UffdMsg{})) for { - _, err := syscall.Read(int(u.fd), buf) + _, err := syscall.Read(int(uffd), buf) if err == syscall.EINTR { u.logger.Debug(ctx, "uffd: interrupted read, reading again") @@ -176,21 +202,27 @@ outerLoop: addr := getPagefaultAddress(&pagefault) - offset, pagesize, err := u.ma.GetOffset(addr) + offset, pagesize, err := u.m.GetOffset(addr) if err != nil { u.logger.Error(ctx, "UFFD serve get mapping error", zap.Error(err)) return fmt.Errorf("failed to map: %w", err) } + // Handle write to write protected page (WP flag) + // The documentation does not clearly mention if the WRITE flag must be present with the WP flag, even though we 
saw it being present in the events. + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + u.handleWriteProtected(ctx, fdExit.SignalExit, addr, pagesize, offset) + + continue + } + // Handle write to missing page (WRITE flag) // If the event has WRITE flag, it was a write to a missing page. // For the write to be executed, we first need to copy the page from the source to the guest memory. if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) - if err != nil { - return fmt.Errorf("failed to handle missing write: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, true) continue } @@ -198,10 +230,7 @@ outerLoop: // Handle read to missing page ("MISSING" flag) // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. if flags == 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) - if err != nil { - return fmt.Errorf("failed to handle missing: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, false) continue } @@ -217,18 +246,24 @@ func (u *Userfaultfd) handleMissing( addr, pagesize uintptr, offset int64, -) error { + write bool, +) { u.wg.Go(func() error { // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, // even if the errgroup is cancelled or the goroutine returns early. // This check protects us against race condition between marking the request as missing and accessing the missingRequests tracker. - // The Firecracker pause should return only after the requested memory is faulted in, so we don't need to guard the pagefault from the moment it is created. + // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. 
u.settleRequests.RLock() defer u.settleRequests.RUnlock() defer func() { if r := recover(); r != nil { u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle missing failure error", zap.Error(signalErr)) + } } }() @@ -245,7 +280,12 @@ func (u *Userfaultfd) handleMissing( var copyMode CULong - copyErr := u.fd.copy(addr, pagesize, b, copyMode) + // If the event is not WRITE, we need to add WP to the page, so we can catch the next WRITE+WP and mark the page as dirty. + if !write { + copyMode |= UFFDIO_COPY_MODE_WP + } + + copyErr := u.uffd.copy(addr, pagesize, b, copyMode) if errors.Is(copyErr, unix.EEXIST) { // Page is already mapped @@ -259,18 +299,64 @@ u.logger.Error(ctx, "UFFD serve uffdio copy error", zap.Error(joinedErr)) - return fmt.Errorf("failed uffdio copy %w", joinedErr) + return fmt.Errorf("failed to copy page %d-%d %w", offset, offset+int64(pagesize), joinedErr) } // Add the offset to the missing requests tracker. u.missingRequests.Add(offset) + if write { + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + } + return nil }) +} - return nil +// Userfaultfd write-protect mode currently behaves differently on none ptes (when e.g. page is missing) over different types of memories (hugepages file backed, etc.). +// - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications - "there will be a userfaultfd write fault message generated when writing to a missing page" +// This should not affect the handling we have in place as all events are being handled. 
+func (u *Userfaultfd) handleWriteProtected(ctx context.Context, onFailure func() error, addr, pagesize uintptr, offset int64) { + u.wg.Go(func() error { + // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, + // even if the errgroup is cancelled or the goroutine returns early. + // This check protects us against race condition between marking the request as dirty and accessing the writeRequests tracker. + // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. + u.settleRequests.RLock() + defer u.settleRequests.RUnlock() + + defer func() { + if r := recover(); r != nil { + u.logger.Error(ctx, "UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle write protected failure error", zap.Error(signalErr)) + } + } + }() + + // Passing 0 as the mode removes the write protection. + wpErr := u.uffd.writeProtect(addr, pagesize, 0) + if wpErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(wpErr, signalErr) + + u.logger.Error(ctx, "UFFD serve write protect error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to remove write protection from page %d-%d %w", offset, offset+int64(pagesize), joinedErr) + } + + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + + return nil + }) } +// Dirty returns the dirty pages. func (u *Userfaultfd) Dirty() *block.Tracker { // This will be at worst cancelled when the uffd is closed. u.settleRequests.Lock() @@ -278,5 +364,5 @@ func (u *Userfaultfd) Dirty() *block.Tracker { // so it is consistent even if there is a another uffd call after. 
defer u.settleRequests.Unlock() - return u.missingRequests.Clone() + return u.writeRequests.Clone() } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go new file mode 100644 index 0000000000..16333261d3 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go @@ -0,0 +1,771 @@ +package userfaultfd + +import ( + "context" + "fmt" + "maps" + "math/rand" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +func TestNoOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(512) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(newMockFd(), h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, accessedOffsets, "checking which pages were faulted") + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + dirty := uffd.Dirty() + assert.Empty(t, slices.Collect(dirty.Offsets()), "checking dirty pages") +} + +func TestRandomOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + 
numberOfOperations := 2048 + repetitions := 8 + + for i := range repetitions { + t.Run(fmt.Sprintf("Run_%d_of_%d", i+1, repetitions), func(t *testing.T) { + t.Parallel() + + // Use time-based seed for each run to ensure different random sequences + // This increases the chance of catching bugs that only manifest with specific sequences + seed := time.Now().UnixNano() + int64(i) + rng := rand.New(rand.NewSource(seed)) + + t.Logf("Using random seed: %d", seed) + + // Randomly generate operations on the data + operations := make([]operation, 0, numberOfOperations) + for range numberOfOperations { + operations = append(operations, operation{ + offset: int64(rng.Intn(int(numberOfPages-1)) * int(pagesize)), + mode: operationMode(rng.Intn(2) + 1), + }) + } + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + operations: operations, + }) + require.NoError(t, err) + + for _, operation := range operations { + err := h.executeOperation(t.Context(), operation) + require.NoError(t, err, "for operation %+v", operation) + } + + expectedAccessedOffsets := getOperationsOffsets(operations, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted (seed: %d)", seed) + + expectedDirtyOffsets := getOperationsOffsets(operations, operationModeWrite) + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty (seed: %d)", seed) + }) + } +} + +// Badly configured uffd panic recovery caused silent close of uffd—first operation (or parallel first operations) were always handled ok, but subsequent operations would freeze. +// This behavior was flaky with the rest of the tests, because it was racy. 
+func TestUffdNotClosingAfterOperation(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + // We need different offset, because kernel would cache the same page for read. + write2 := operation{ + offset: int64(2 * header.PageSize), + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, dirtyOffsets, "checking which pages were dirty") + }) + + t.Run("missing read", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + // We need different offset, because kernel would cache the same page for read. + read2 := operation{ + offset: 2 * header.PageSize, + mode: operationModeRead, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. 
+ + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, accessedOffsets, "checking which pages were faulted") + }) + + t.Run("write protected", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + read2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeRead, + } + + // We need at least 2 wp events to check the wp handler, so we need to write to 2 different pages. + + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + write2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. 
+ + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, dirtyOffsets, "checking which pages were dirty") + }) +} + +func TestUffdEvents(t *testing.T) { + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + events := []event{ + // Same operation and offset, repeated (copies at 0), with both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + + // WriteProtect at same offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + + // Copy at next offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: 
&UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + + // WriteProtect at next offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + + // Copy at another offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + + // WriteProtect at another offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * 
header.PageSize), + }, + } + + for _, event := range events { + err := event.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", event) + } + + receivedEvents := make([]event, 0, len(events)) + + for range events { + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + copyEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioCopy: ©Event.event}) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + writeProtectEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioWriteProtect: &writeProtectEvent.event}) + case <-t.Context().Done(): + t.FailNow() + } + } + + assert.Len(t, receivedEvents, len(events), "checking received events") + assert.ElementsMatch(t, zeroOffsets(events), receivedEvents, "checking received events") + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + default: + } + + dirty := uffd.Dirty() + + expectedDirtyOffsets := make(map[int64]struct{}) + expectedAccessedOffsets := make(map[int64]struct{}) + + for _, event := range events { + if event.UffdioWriteProtect != nil { + expectedDirtyOffsets[event.offset] = struct{}{} + } + if event.UffdioCopy != nil { + if event.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + expectedDirtyOffsets[event.offset] = struct{}{} + } + + expectedAccessedOffsets[event.offset] = struct{}{} + } + } + + assert.ElementsMatch(t, 
slices.Collect(maps.Keys(expectedDirtyOffsets)), slices.Collect(dirty.Offsets()), "checking dirty pages") + + accessed := accessed(uffd) + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedAccessedOffsets)), slices.Collect(accessed.Offsets()), "checking accessed pages") +} + +func TestUffdSettleRequests(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + testEventsSettle := func(t *testing.T, events []event) { + t.Helper() + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + for _, e := range events { + err = e.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", e) + } + + var blockedCopyEvents []*blockedEvent[UffdioCopy] + var blockedWriteProtectEvents []*blockedEvent[UffdioWriteProtect] + + for range events { + // Wait until the event is blocked + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, copyEvent.event, "copy event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioCopy: ©Event.event}, "checking copy event") + + blockedCopyEvents = append(blockedCopyEvents, copyEvent) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, writeProtectEvent.event, "write protect event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioWriteProtect: &writeProtectEvent.event}, "checking write protect event") + + blockedWriteProtectEvents = append(blockedWriteProtectEvents, writeProtectEvent) + case <-t.Context().Done(): + t.FailNow() + } + } + + require.Len(t, events, 
len(blockedCopyEvents)+len(blockedWriteProtectEvents), "checking blocked events") + + simulatedFCPause := make(chan struct{}) + + d := make(chan *block.Tracker) + + go func() { + acquired := uffd.settleRequests.TryLock() + assert.False(t, acquired, "settleRequests write lock should not be acquired") + + simulatedFCPause <- struct{}{} + + // This should block, until the events are resolved. + dirty := uffd.Dirty() + + select { + case d <- dirty: + case <-t.Context().Done(): + return + } + }() + + // This would be the place where the FC API Pause would return. + <-simulatedFCPause + + // Resolve the events to unblock getting the dirty pages in the goroutine. + for _, e := range blockedCopyEvents { + e.resolve() + } + + for _, e := range blockedWriteProtectEvents { + e.resolve() + } + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + case dirty, ok := <-d: + if !ok { + t.FailNow() + } + + assert.ElementsMatch(t, dirtyOffsets(events), slices.Collect(dirty.Offsets()), "checking dirty pages") + } + } + + t.Run("missing", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("write protect", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("event mix", func(t 
*testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 0, + }, + } + + testEventsSettle(t, events) + }) +} + +type event struct { + *UffdioCopy + *UffdioWriteProtect + + offset int64 +} + +func (e event) trigger(ctx context.Context, uffd *Userfaultfd) error { + switch { + case e.UffdioCopy != nil: + triggerMissing(ctx, uffd, *e.UffdioCopy, e.offset) + case e.UffdioWriteProtect != nil: + triggerWriteProtected(ctx, uffd, *e.UffdioWriteProtect, e.offset) + default: + return fmt.Errorf("invalid event: %+v", e) + } + + return nil +} + +// Return the event copy without the offset, because the offset is propagated only through the "accessed" and "dirty" sets, so direct comparisons would fail. +func (e event) withoutOffset() event { + return event{ + UffdioCopy: e.UffdioCopy, + UffdioWriteProtect: e.UffdioWriteProtect, + } +} + +// Creates a new slice of events with the offset set to 0, so we can compare the events without the offset. 
+func zeroOffsets(events []event) []event { + return utils.Map(events, func(e event) event { + return e.withoutOffset() + }) +} + +func triggerMissing(ctx context.Context, uffd *Userfaultfd, c UffdioCopy, offset int64) { + var write bool + + if c.mode != UFFDIO_COPY_MODE_WP { + write = true + } + + uffd.handleMissing( + ctx, + func() error { return nil }, + uintptr(c.dst), + uintptr(uffd.src.BlockSize()), + offset, + write, + ) +} + +func triggerWriteProtected(ctx context.Context, uffd *Userfaultfd, c UffdioWriteProtect, offset int64) { + uffd.handleWriteProtected( + ctx, + func() error { return nil }, + uintptr(c._range.start), + uintptr(uffd.src.BlockSize()), + offset, + ) +} + +func dirtyOffsets(events []event) []int64 { + offsets := make(map[int64]struct{}) + + for _, e := range events { + if e.UffdioWriteProtect != nil { + offsets[e.offset] = struct{}{} + } + + if e.UffdioCopy != nil { + if e.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + offsets[e.offset] = struct{}{} + } + } + } + + return slices.Collect(maps.Keys(offsets)) +}