From 24f2e71071e33893c5e6a3735443d7bcb8b343c5 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 17:31:49 -0800 Subject: [PATCH 01/40] WIP --- packages/api/internal/cfg/model.go | 2 +- .../internal/sandbox/block/cache.go | 12 + .../internal/sandbox/block/range.go | 79 ++++++ .../internal/sandbox/block/tracker.go | 18 +- .../internal/sandbox/build/diff.go | 11 + .../internal/sandbox/build/local_diff.go | 42 ++- .../internal/sandbox/build/storage_diff.go | 8 +- .../internal/sandbox/diffcreator.go | 36 --- .../internal/sandbox/fc/client.go | 35 ++- .../internal/sandbox/fc/memory.go | 147 +++++++++++ .../internal/sandbox/fc/process.go | 4 +- .../orchestrator/internal/sandbox/sandbox.go | 152 ++++++----- .../internal/sandbox/uffd/memory/mapping.go | 83 +++++- .../uffd/memory/mapping_host_virt_test.go | 230 ++++++++++++++++ ...mapping_test.go => mapping_offset_test.go} | 186 +++++++------ .../internal/sandbox/uffd/memory/region.go | 23 ++ .../internal/sandbox/uffd/memory_backend.go | 8 +- .../internal/sandbox/uffd/noop.go | 18 +- .../internal/sandbox/uffd/uffd.go | 29 +-- .../internal/sandbox/uffd/userfaultfd/fd.go | 18 +- .../sandbox/uffd/userfaultfd/userfaultfd.go | 10 - .../get_memory_mappings_parameters.go | 128 +++++++++ .../get_memory_mappings_responses.go | 187 +++++++++++++ .../operations/get_memory_parameters.go | 128 +++++++++ .../client/operations/get_memory_responses.go | 187 +++++++++++++ .../fc/client/operations/operations_client.go | 80 ++++++ packages/shared/pkg/fc/firecracker.yml | 245 ++++++++++++++---- packages/shared/pkg/fc/models/cpu_config.go | 47 +++- .../pkg/fc/models/full_vm_configuration.go | 102 ++++++++ .../fc/models/guest_memory_region_mapping.go | 122 +++++++++ .../pkg/fc/models/memory_mappings_response.go | 124 +++++++++ .../shared/pkg/fc/models/memory_response.go | 88 +++++++ .../shared/pkg/fc/models/network_override.go | 88 +++++++ .../pkg/fc/models/snapshot_create_params.go | 16 +- .../pkg/fc/models/snapshot_load_params.go 
| 63 +++++ .../shared/pkg/storage/temporary_memfile.go | 61 ----- 36 files changed, 2375 insertions(+), 442 deletions(-) create mode 100644 packages/orchestrator/internal/sandbox/block/range.go create mode 100644 packages/orchestrator/internal/sandbox/fc/memory.go create mode 100644 packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go rename packages/orchestrator/internal/sandbox/uffd/memory/{mapping_test.go => mapping_offset_test.go} (50%) create mode 100644 packages/shared/pkg/fc/client/operations/get_memory_mappings_parameters.go create mode 100644 packages/shared/pkg/fc/client/operations/get_memory_mappings_responses.go create mode 100644 packages/shared/pkg/fc/client/operations/get_memory_parameters.go create mode 100644 packages/shared/pkg/fc/client/operations/get_memory_responses.go create mode 100644 packages/shared/pkg/fc/models/guest_memory_region_mapping.go create mode 100644 packages/shared/pkg/fc/models/memory_mappings_response.go create mode 100644 packages/shared/pkg/fc/models/memory_response.go create mode 100644 packages/shared/pkg/fc/models/network_override.go delete mode 100644 packages/shared/pkg/storage/temporary_memfile.go diff --git a/packages/api/internal/cfg/model.go b/packages/api/internal/cfg/model.go index e5f1b5dfb3..59745406fd 100644 --- a/packages/api/internal/cfg/model.go +++ b/packages/api/internal/cfg/model.go @@ -6,7 +6,7 @@ const ( DefaultKernelVersion = "vmlinux-6.1.158" // The Firecracker version the last tag + the short SHA (so we can build our dev previews) // TODO: The short tag here has only 7 characters — the one from our build pipeline will likely have exactly 8 so this will break. 
- DefaultFirecrackerVersion = "v1.12.1_d990331" + DefaultFirecrackerVersion = "v1.12.2_g1133bd6cd" ) type Config struct { diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 666e9de883..1c8ae893a0 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -271,3 +271,15 @@ func (m *Cache) FileSize() (int64, error) { return stat.Blocks * fsStat.Bsize, nil } + +func (m *Cache) Address(off uint64) *byte { + return &(*m.mmap)[off] +} + +func (m *Cache) BlockSize() int64 { + return m.blockSize +} + +func (m *Cache) Path() string { + return m.filePath +} diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go new file mode 100644 index 0000000000..b871341ad0 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -0,0 +1,79 @@ +package block + +import ( + "iter" + + "github.com/bits-and-blooms/bitset" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +type Range struct { + // Start is the start address of the range in bytes. + // Start is inclusive. + Start int64 + // Size is the size of the range in bytes. + Size uint64 +} + +func (r *Range) End() int64 { + return r.Start + int64(r.Size) +} + +// Offsets returns the block offsets contained in the range. +// This assumes the Range.Start is a multiple of the blockSize. +func (r *Range) Offsets(blockSize int64) iter.Seq[int64] { + return func(yield func(offset int64) bool) { + for i := r.Start; i < r.End(); i += blockSize { + if !yield(i) { + return + } + } + } +} + +// NewRange creates a new range from a start address and size in bytes. +func NewRange(start int64, size uint64) Range { + return Range{ + Start: start, + Size: size, + } +} + +// NewRangeFromBlocks creates a new range from a start index and number of blocks. 
+func NewRangeFromBlocks(startIdx, numberOfBlocks, blockSize int64) Range { + return Range{ + Start: header.BlockOffset(startIdx, blockSize), + Size: uint64(header.BlockOffset(numberOfBlocks, blockSize)), + } +} + +// bitsetRanges returns a sequence of the ranges of the set bits of the bitset. +func BitsetRanges(b *bitset.BitSet) iter.Seq[Range] { + return func(yield func(Range) bool) { + start, ok := b.NextSet(0) + + for ok { + end, endOk := b.NextClear(start) + if !endOk { + yield(NewRange(int64(start), uint64(b.Len()-start))) + + return + } + + if !yield(NewRange(int64(start), uint64(end-start))) { + return + } + + start, ok = b.NextSet(end + 1) + } + } +} + +func GetSize(rs []Range) (size uint64) { + for _, r := range rs { + size += r.Size + } + + return size +} diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go index b9daaed02a..dc0c74e853 100644 --- a/packages/orchestrator/internal/sandbox/block/tracker.go +++ b/packages/orchestrator/internal/sandbox/block/tracker.go @@ -53,12 +53,10 @@ func (t *Tracker) Reset() { t.b.ClearAll() } -// BitSet returns a clone of the bitset and the block size. +// BitSet returns the bitset. +// This is not safe to use concurrently. 
func (t *Tracker) BitSet() *bitset.BitSet { - t.mu.RLock() - defer t.mu.RUnlock() - - return t.b.Clone() + return t.b } func (t *Tracker) BlockSize() int64 { @@ -66,14 +64,20 @@ func (t *Tracker) BlockSize() int64 { } func (t *Tracker) Clone() *Tracker { + t.mu.RLock() + defer t.mu.RUnlock() + return &Tracker{ - b: t.BitSet(), + b: t.b.Clone(), blockSize: t.BlockSize(), } } func (t *Tracker) Offsets() iter.Seq[int64] { - return bitsetOffsets(t.BitSet(), t.BlockSize()) + t.mu.RLock() + defer t.mu.RUnlock() + + return bitsetOffsets(t.b.Clone(), t.BlockSize()) } func bitsetOffsets(b *bitset.BitSet, blockSize int64) iter.Seq[int64] { diff --git a/packages/orchestrator/internal/sandbox/build/diff.go b/packages/orchestrator/internal/sandbox/build/diff.go index 87f63a2194..84ca020b99 100644 --- a/packages/orchestrator/internal/sandbox/build/diff.go +++ b/packages/orchestrator/internal/sandbox/build/diff.go @@ -2,9 +2,12 @@ package build import ( "context" + "fmt" "io" + "path/filepath" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/id" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) @@ -66,3 +69,11 @@ func (n *NoDiff) Init(context.Context) error { func (n *NoDiff) BlockSize() int64 { return 0 } + +func GenerateDiffCachePath(basePath string, buildId string, diffType DiffType) string { + cachePathSuffix := id.Generate() + + cacheFile := fmt.Sprintf("%s-%s-%s", buildId, diffType, cachePathSuffix) + + return filepath.Join(basePath, cacheFile) +} diff --git a/packages/orchestrator/internal/sandbox/build/local_diff.go b/packages/orchestrator/internal/sandbox/build/local_diff.go index c192e52778..91757f024c 100644 --- a/packages/orchestrator/internal/sandbox/build/local_diff.go +++ b/packages/orchestrator/internal/sandbox/build/local_diff.go @@ -4,10 +4,8 @@ import ( "context" "fmt" "os" - "path/filepath" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" - 
"github.com/e2b-dev/infra/packages/shared/pkg/id" ) type LocalDiffFile struct { @@ -22,10 +20,7 @@ func NewLocalDiffFile( buildId string, diffType DiffType, ) (*LocalDiffFile, error) { - cachePathSuffix := id.Generate() - - cacheFile := fmt.Sprintf("%s-%s-%s", buildId, diffType, cachePathSuffix) - cachePath := filepath.Join(basePath, cacheFile) + cachePath := GenerateDiffCachePath(basePath, buildId, diffType) f, err := os.OpenFile(cachePath, os.O_RDWR|os.O_CREATE, 0o644) if err != nil { @@ -77,41 +72,44 @@ func (f *LocalDiffFile) CloseToDiff( f.cachePath, size.Size(), blockSize, + true, ) } type localDiff struct { - size int64 - blockSize int64 - cacheKey DiffStoreKey - cachePath string - cache *block.Cache + cacheKey DiffStoreKey + cache *block.Cache } var _ Diff = (*localDiff)(nil) +func NewLocalDiffFromCache( + cacheKey DiffStoreKey, + cache *block.Cache, +) (*localDiff, error) { + return &localDiff{ + cache: cache, + cacheKey: cacheKey, + }, nil +} + func newLocalDiff( cacheKey DiffStoreKey, cachePath string, - size int64, + size, blockSize int64, + dirty bool, ) (*localDiff, error) { - cache, err := block.NewCache(size, blockSize, cachePath, true) + cache, err := block.NewCache(size, blockSize, cachePath, dirty) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) } - return &localDiff{ - size: size, - blockSize: blockSize, - cacheKey: cacheKey, - cachePath: cachePath, - cache: cache, - }, nil + return NewLocalDiffFromCache(cacheKey, cache) } func (b *localDiff) CachePath() (string, error) { - return b.cachePath, nil + return b.cache.Path(), nil } func (b *localDiff) Close() error { @@ -139,5 +137,5 @@ func (b *localDiff) Init(context.Context) error { } func (b *localDiff) BlockSize() int64 { - return b.blockSize + return b.cache.BlockSize() } diff --git a/packages/orchestrator/internal/sandbox/build/storage_diff.go b/packages/orchestrator/internal/sandbox/build/storage_diff.go index 02723401d5..ee6e77fdc6 100644 --- 
a/packages/orchestrator/internal/sandbox/build/storage_diff.go +++ b/packages/orchestrator/internal/sandbox/build/storage_diff.go @@ -4,11 +4,9 @@ import ( "context" "fmt" "io" - "path/filepath" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block/metrics" - "github.com/e2b-dev/infra/packages/shared/pkg/id" "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -47,15 +45,13 @@ func newStorageDiff( metrics blockmetrics.Metrics, persistence storage.StorageProvider, ) (*StorageDiff, error) { - cachePathSuffix := id.Generate() - storagePath := storagePath(buildId, diffType) storageObjectType, ok := storageObjectType(diffType) if !ok { return nil, UnknownDiffTypeError{diffType} } - cacheFile := fmt.Sprintf("%s-%s-%s", buildId, diffType, cachePathSuffix) - cachePath := filepath.Join(basePath, cacheFile) + + cachePath := GenerateDiffCachePath(basePath, buildId, diffType) return &StorageDiff{ storagePath: storagePath, diff --git a/packages/orchestrator/internal/sandbox/diffcreator.go b/packages/orchestrator/internal/sandbox/diffcreator.go index 161ac1ac31..fa65b1fd60 100644 --- a/packages/orchestrator/internal/sandbox/diffcreator.go +++ b/packages/orchestrator/internal/sandbox/diffcreator.go @@ -2,15 +2,9 @@ package sandbox import ( "context" - "errors" - "fmt" "io" - "os" - - "github.com/bits-and-blooms/bitset" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/rootfs" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -26,33 +20,3 @@ type RootfsDiffCreator struct { func (r *RootfsDiffCreator) process(ctx context.Context, out io.Writer) (*header.DiffMetadata, error) { return r.rootfs.ExportDiff(ctx, out, r.closeHook) } - -type MemoryDiffCreator struct { - memfile *storage.TemporaryMemfile - dirtyPages *bitset.BitSet - blockSize 
int64 - doneHook func(context.Context) error -} - -func (r *MemoryDiffCreator) process(ctx context.Context, out io.Writer) (h *header.DiffMetadata, e error) { - defer func() { - err := r.doneHook(ctx) - if err != nil { - e = errors.Join(e, err) - } - }() - - memfileSource, err := os.Open(r.memfile.Path()) - if err != nil { - return nil, fmt.Errorf("failed to open memfile: %w", err) - } - defer memfileSource.Close() - - return header.WriteDiffWithTrace( - ctx, - memfileSource, - r.blockSize, - r.dirtyPages, - out, - ) -} diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index 33449b5159..1396980d77 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -4,14 +4,17 @@ import ( "context" "fmt" + "github.com/bits-and-blooms/bitset" "github.com/firecracker-microvm/firecracker-go-sdk" "github.com/go-openapi/strfmt" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/socket" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/template" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/shared/pkg/fc/client" "github.com/e2b-dev/infra/packages/shared/pkg/fc/client/operations" "github.com/e2b-dev/infra/packages/shared/pkg/fc/models" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -126,13 +129,11 @@ func (c *apiClient) pauseVM(ctx context.Context) error { func (c *apiClient) createSnapshot( ctx context.Context, snapfilePath string, - memfilePath string, ) error { snapshotConfig := operations.CreateSnapshotParams{ Context: ctx, Body: &models.SnapshotCreateParams{ SnapshotType: models.SnapshotCreateParamsSnapshotTypeFull, - MemFilePath: &memfilePath, SnapshotPath: &snapfilePath, }, } @@ -301,3 +302,33 @@ func (c *apiClient) 
startVM(ctx context.Context) error { return nil } + +func (c *apiClient) memoryMappings(ctx context.Context) (*memory.Mapping, error) { + memoryMappingsParams := operations.GetMemoryMappingsParams{ + Context: ctx, + } + + memoryMappings, err := c.client.Operations.GetMemoryMappings(&memoryMappingsParams) + if err != nil { + return nil, fmt.Errorf("error getting memory mappings: %w", err) + } + + return memory.NewMappingFromFc(memoryMappings.Payload.Mappings) +} + +func (c *apiClient) memoryInfo(ctx context.Context, blockSize int64) (*header.DiffMetadata, error) { + memoryParams := operations.GetMemoryParams{ + Context: ctx, + } + + memoryInfo, err := c.client.Operations.GetMemory(&memoryParams) + if err != nil { + return nil, fmt.Errorf("error getting memory: %w", err) + } + + return &header.DiffMetadata{ + Dirty: bitset.From(memoryInfo.Payload.Resident), + Empty: bitset.From(memoryInfo.Payload.Empty), + BlockSize: blockSize, + }, nil +} diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go new file mode 100644 index 0000000000..ef168c4719 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -0,0 +1,147 @@ +package fc + +import ( + "context" + "errors" + "fmt" + "math/rand" + "time" + + "github.com/bits-and-blooms/bitset" + "github.com/tklauser/go-sysconf" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +var IOV_MAX = utils.Must(getIOVMax()) + +const ( + oomMinBackoff = 100 * time.Millisecond + oomMaxJitter = 100 * time.Millisecond +) + +func (p *Process) MemoryInfo(ctx context.Context, blockSize int64) (*header.DiffMetadata, error) { + return p.client.memoryInfo(ctx, blockSize) +} + +func (p *Process) ExportMemory( + ctx context.Context, + include *bitset.BitSet, + cachePath string, + blockSize int64, +) 
(*block.Cache, error) { + m, err := p.client.memoryMappings(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get memory mappings: %w", err) + } + + var remoteRanges []block.Range + + for r := range block.BitsetRanges(include) { + hostVirtRanges, err := m.GetHostVirtRanges(r.Start, int64(r.Size)) + if err != nil { + return nil, fmt.Errorf("failed to get host virt ranges: %w", err) + } + + remoteRanges = append(remoteRanges, hostVirtRanges...) + } + + size := block.GetSize(remoteRanges) + + cache, err := block.NewCache(int64(size), blockSize, cachePath, false) + if err != nil { + return nil, fmt.Errorf("failed to create cache: %w", err) + } + + pid, err := p.Pid() + if err != nil { + return nil, fmt.Errorf("failed to get pid: %w", err) + } + + err = copyProcessMemory(ctx, pid, remoteRanges, cache) + if err != nil { + return nil, fmt.Errorf("failed to copy process memory: %w", err) + } + + return cache, nil +} + +func copyProcessMemory( + ctx context.Context, + pid int, + ranges []block.Range, + local *block.Cache, +) error { + var start uint64 + + for i := 0; i < len(ranges); i += int(IOV_MAX) { + segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] + + remote := make([]unix.RemoteIovec, len(segmentRanges)) + + var segmentSize uint64 + + for j, r := range segmentRanges { + remote[j] = unix.RemoteIovec{ + Base: uintptr(r.Start), + Len: int(r.Size), + } + + segmentSize += r.Size + } + + local := []unix.Iovec{ + { + Base: local.Address(start), + // We could keep this as full cache length, but we might as well be exact here. + Len: segmentSize, + }, + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. 
+ _, err := unix.ProcessVMReadv(pid, + local, + remote, + 0, + ) + if errors.Is(err, unix.EAGAIN) { + continue + } + if errors.Is(err, unix.EINTR) { + continue + } + if errors.Is(err, unix.ENOMEM) { + time.Sleep(oomMinBackoff + time.Duration(rand.Intn(int(oomMaxJitter.Milliseconds())))*time.Millisecond) + + continue + } + + if err != nil { + return fmt.Errorf("failed to read memory: %w", err) + } + + start += segmentSize + } + } + + return nil +} + +func getIOVMax() (int64, error) { + iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) + if err != nil { + return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) + } + + return iovMax, nil +} diff --git a/packages/orchestrator/internal/sandbox/fc/process.go b/packages/orchestrator/internal/sandbox/fc/process.go index 88e919fb83..5f9d66b511 100644 --- a/packages/orchestrator/internal/sandbox/fc/process.go +++ b/packages/orchestrator/internal/sandbox/fc/process.go @@ -487,9 +487,9 @@ func (p *Process) Pause(ctx context.Context) error { } // CreateSnapshot VM needs to be paused before creating a snapshot. 
-func (p *Process) CreateSnapshot(ctx context.Context, snapfilePath string, memfilePath string) error { +func (p *Process) CreateSnapshot(ctx context.Context, snapfilePath string) error { ctx, childSpan := tracer.Start(ctx, "create-snapshot-fc") defer childSpan.End() - return p.client.createSnapshot(ctx, snapfilePath, memfilePath) + return p.client.createSnapshot(ctx, snapfilePath) } diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index ea30e3161a..ae72b4a2e7 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/bits-and-blooms/bitset" "github.com/google/uuid" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" @@ -93,6 +94,16 @@ type Resources struct { Slot *network.Slot rootfs rootfs.Provider memory uffd.MemoryBackend + // Filter to apply to the dirty bitset before creating the diff metadata. 
+ memoryDiffFilter func(ctx context.Context) (*header.DiffMetadata, error) +} + +func (r *Resources) Dirty(ctx context.Context) (*header.DiffMetadata, error) { + if r.memoryDiffFilter == nil { + return nil, fmt.Errorf("memory diff filter is not set") + } + + return r.memoryDiffFilter(ctx) } type internalConfig struct { @@ -291,6 +302,27 @@ func (f *Factory) CreateSandbox( Slot: ips.slot, rootfs: rootfsProvider, memory: uffd.NewNoopMemory(memfileSize, memfile.BlockSize()), + memoryDiffFilter: func(ctx context.Context) (*header.DiffMetadata, error) { + diffInfo, err := fcHandle.MemoryInfo(ctx, memfile.BlockSize()) + if err != nil { + return nil, err + } + + dirty := diffInfo.Dirty.Difference(diffInfo.Empty) + + numberOfPages := header.BlockOffset(memfileSize, memfile.BlockSize()) + + empty := bitset.New(uint(numberOfPages)) + empty.FlipRange(0, uint(numberOfPages)) + + empty = empty.Difference(dirty) + + return &header.DiffMetadata{ + Dirty: dirty, + Empty: empty, + BlockSize: memfile.BlockSize(), + }, nil + }, } metadata := &Metadata{ @@ -525,6 +557,19 @@ func (f *Factory) ResumeSandbox( Slot: ips.slot, rootfs: rootfsOverlay, memory: fcUffd, + memoryDiffFilter: func(ctx context.Context) (*header.DiffMetadata, error) { + dirty, err := fcUffd.Dirty(ctx) + if err != nil { + return nil, err + } + + return &header.DiffMetadata{ + Dirty: dirty, + // We don't track and filter empty pages for subsequent sandbox pauses as pages should usually not be empty. 
+ Empty: bitset.New(0), + BlockSize: memfile.BlockSize(), + }, nil + }, } metadata := &Metadata{ @@ -673,10 +718,6 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { return fmt.Errorf("failed to pause VM: %w", err) } - if _, err := s.memory.Disable(ctx); err != nil { - return fmt.Errorf("failed to disable uffd: %w", err) - } - // This is required because the FC API doesn't support passing /dev/null tf, err := storage.TemplateFiles{ BuildID: uuid.New().String(), @@ -692,28 +733,11 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { snapfile := template.NewLocalFileLink(tf.CacheSnapfilePath()) defer snapfile.Close() - // The memfile is required only because the FC API doesn't support passing /dev/null - memfile, err := storage.AcquireTmpMemfile(ctx, s.config, tf.BuildID) - if err != nil { - return fmt.Errorf("failed to acquire memfile snapshot: %w", err) - } - defer memfile.Close() - - err = s.process.CreateSnapshot( - ctx, - snapfile.Path(), - memfile.Path(), - ) + err = s.process.CreateSnapshot(ctx, snapfile.Path()) if err != nil { return fmt.Errorf("error creating snapshot: %w", err) } - // Close the memfile right after the snapshot to release the lock. - err = memfile.Close() - if err != nil { - return fmt.Errorf("error closing memfile: %w", err) - } - // This should properly flush rootfs to the underlying device. err = s.Close(ctx) if err != nil { @@ -757,37 +781,11 @@ func (s *Sandbox) Pause( return nil, fmt.Errorf("failed to pause VM: %w", err) } - // This disables the uffd and returns the dirty pages. - // With FC async io engine, there can be some further writes to the memory during the actual create snapshot process, - // but as we are still including even read pages as dirty so this should not introduce more bugs right now. 
- dirty, err := s.memory.Disable(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get dirty pages: %w", err) - } - // Snapfile is not closed as it's returned and cached for later use (like resume) snapfile := template.NewLocalFileLink(snapshotTemplateFiles.CacheSnapfilePath()) cleanup.AddNoContext(ctx, snapfile.Close) - // Memfile is also closed on diff creation processing - /* The process of snapshotting memory is as follows: - 1. Pause FC via API - 2. Snapshot FC via API—memory dump to “file on disk” that is actually tmpfs, because it is too slow - 3. Create the diff - copy the diff pages from tmpfs to normal disk file - 4. Delete tmpfs file - 5. Unlock so another snapshot can use tmpfs space - */ - memfile, err := storage.AcquireTmpMemfile(ctx, s.config, buildID.String()) - if err != nil { - return nil, fmt.Errorf("failed to acquire memfile snapshot: %w", err) - } - // Close the file even if an error occurs - defer memfile.Close() - err = s.process.CreateSnapshot( - ctx, - snapfile.Path(), - memfile.Path(), - ) + err = s.process.CreateSnapshot(ctx, snapfile.Path()) if err != nil { return nil, fmt.Errorf("error creating snapshot: %w", err) } @@ -797,25 +795,25 @@ func (s *Sandbox) Pause( if err != nil { return nil, fmt.Errorf("failed to get original memfile: %w", err) } + originalRootfs, err := s.Template.Rootfs() if err != nil { return nil, fmt.Errorf("failed to get original rootfs: %w", err) } + diffMetadata, err := s.Resources.Dirty(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get dirty memory: %w", err) + } + // Start POSTPROCESSING memfileDiff, memfileDiffHeader, err := pauseProcessMemory( ctx, buildID, originalMemfile.Header(), - &MemoryDiffCreator{ - memfile: memfile, - dirtyPages: dirty.BitSet(), - blockSize: originalMemfile.BlockSize(), - doneHook: func(context.Context) error { - return memfile.Close() - }, - }, + diffMetadata, s.config.DefaultCacheDir, + s.process, ) if err != nil { return nil, fmt.Errorf("error while post 
processing: %w", err) @@ -859,45 +857,41 @@ func (s *Sandbox) Pause( func pauseProcessMemory( ctx context.Context, - buildId uuid.UUID, + buildID uuid.UUID, originalHeader *header.Header, - diffCreator DiffCreator, + diffMetadata *header.DiffMetadata, cacheDir string, + fc *fc.Process, ) (d build.Diff, h *header.Header, e error) { ctx, span := tracer.Start(ctx, "process-memory") defer span.End() - memfileDiffFile, err := build.NewLocalDiffFile( - cacheDir, - buildId.String(), - build.Memfile, - ) - if err != nil { - return nil, nil, fmt.Errorf("failed to create memfile diff file: %w", err) - } + memfileDiffPath := build.GenerateDiffCachePath(cacheDir, buildID.String(), build.Memfile) - m, err := diffCreator.process(ctx, memfileDiffFile) + cache, err := fc.ExportMemory( + ctx, + diffMetadata.Dirty, + memfileDiffPath, + diffMetadata.BlockSize, + ) if err != nil { - err = errors.Join(err, memfileDiffFile.Close()) - - return nil, nil, fmt.Errorf("error creating diff: %w", err) + return nil, nil, fmt.Errorf("failed to export memory: %w", err) } - telemetry.ReportEvent(ctx, "created diff") - memfileDiff, err := memfileDiffFile.CloseToDiff(int64(originalHeader.Metadata.BlockSize)) + diff, err := build.NewLocalDiffFromCache( + build.GetDiffStoreKey(buildID.String(), build.Memfile), + cache, + ) if err != nil { - return nil, nil, fmt.Errorf("failed to convert memfile diff file to local diff: %w", err) + return nil, nil, fmt.Errorf("failed to create local diff from cache: %w", err) } - telemetry.ReportEvent(ctx, "converted memfile diff file to local diff") - memfileHeader, err := m.ToDiffHeader(ctx, originalHeader, buildId) + header, err := diffMetadata.ToDiffHeader(ctx, originalHeader, buildID) if err != nil { - err = errors.Join(err, memfileDiff.Close()) - return nil, nil, fmt.Errorf("failed to create memfile header: %w", err) } - return memfileDiff, memfileHeader, nil + return diff, header, nil } func pauseProcessRootfs( diff --git 
a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go index 824a1e4adb..d82ad56069 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -2,6 +2,9 @@ package memory import ( "fmt" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/fc/models" ) type AddressNotFoundError struct { @@ -9,7 +12,15 @@ type AddressNotFoundError struct { } func (e AddressNotFoundError) Error() string { - return fmt.Sprintf("address %d not found in any mapping", e.hostVirtAddr) + return fmt.Sprintf("host virtual address %d not found in any mapping", e.hostVirtAddr) +} + +type OffsetNotFoundError struct { + offset int64 +} + +func (e OffsetNotFoundError) Error() string { + return fmt.Sprintf("offset %d not found in any mapping", e.offset) } type Mapping struct { @@ -20,13 +31,77 @@ func NewMapping(regions []Region) *Mapping { return &Mapping{Regions: regions} } -// GetOffset returns the relative offset and the page size of the mapped range for a given address. -func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { +func NewMappingFromFc(regions []*models.GuestMemoryRegionMapping) (*Mapping, error) { + r := make([]Region, len(regions)) + + for i, infoRegion := range regions { + if infoRegion.BaseHostVirtAddr == nil || infoRegion.Size == nil || infoRegion.Offset == nil || infoRegion.PageSize == nil { + return nil, fmt.Errorf("missing required fields for memory region %d", i) + } + + r[i] = Region{ + BaseHostVirtAddr: uintptr(*infoRegion.BaseHostVirtAddr), + Size: uintptr(*infoRegion.Size), + Offset: uintptr(*infoRegion.Offset), + PageSize: uintptr(*infoRegion.PageSize), + } + } + + return NewMapping(r), nil +} + +// GetOffset returns the relative offset and the pagesize of the mapped range for a given address. 
+func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) { for _, r := range m.Regions { if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { - return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + return r.shiftedOffset(hostVirtAddr), r.PageSize, nil } } return 0, 0, AddressNotFoundError{hostVirtAddr: hostVirtAddr} } + +// GetHostVirtAddr returns the host virtual address and size of the remaining contiguous mapped host range for the given offset. +func (m *Mapping) GetHostVirtAddr(off int64) (uintptr, int64, error) { + for _, r := range m.Regions { + if off >= int64(r.Offset) && off < r.endOffset() { + return r.shiftedHostVirtAddr(off), r.endOffset() - off, nil + } + } + + return 0, 0, OffsetNotFoundError{offset: off} +} + +// GetHostVirtRanges returns the host virtual addresses and sizes (ranges) that cover exactly the given [offset, offset+length) range in the host virtual address space. +func (m *Mapping) GetHostVirtRanges(off int64, size int64) (hostVirtRanges []block.Range, err error) { + for n := int64(0); n < size; { + currentOff := off + n + + region, err := m.getHostVirtRegion(currentOff) + if err != nil { + return nil, err + } + + start := region.shiftedHostVirtAddr(currentOff) + remainingSize := min(int64(region.endHostVirtAddr()-start), size-n) + + r := block.NewRange(int64(start), uint64(remainingSize)) + + hostVirtRanges = append(hostVirtRanges, r) + + n += int64(r.Size) + } + + return hostVirtRanges, nil +} + +// getHostVirtRegion returns the region that contains the given offset. 
+func (m *Mapping) getHostVirtRegion(off int64) (*Region, error) { + for _, r := range m.Regions { + if off >= int64(r.Offset) && off < r.endOffset() { + return &r, nil + } + } + + return nil, OffsetNotFoundError{offset: off} +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go new file mode 100644 index 0000000000..05e8cc75f5 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go @@ -0,0 +1,230 @@ +package memory + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestMapping_GetHostVirtAddr(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + offset int64 + expectedHostVirt uintptr + remainingRegionSize int64 + expectError error + }{ + { + name: "valid offset in first region", + offset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedHostVirt: 0x1500, // 0x1000 + (0x5500 - 0x5000) + // region ends at 0x7000; remaining = 0x7000 - 0x5500 = 0x1b00 + remainingRegionSize: 0x1b00, + }, + { + name: "valid offset at start of first region", + offset: 0x5000, + expectedHostVirt: 0x1000, // 0x1000 + (0x5000 - 0x5000) + remainingRegionSize: 0x2000, // 0x7000 - 0x5000 + }, + { + name: "valid offset near end of first region", + offset: 0x6FFF, // 0x7000 - 1 + expectedHostVirt: 0x2FFF, // 0x1000 + (0x6FFF - 0x5000) + remainingRegionSize: 0x1, // 0x7000 - 0x6FFF + }, + { + name: "valid offset at start of second region", + offset: 0x8000, + expectedHostVirt: 0x5000, // 0x5000 + (0x8000 - 0x8000) + 
remainingRegionSize: 0x1000, // 0x9000 - 0x8000 + }, + { + name: "offset before first region", + offset: 0x4000, + expectError: OffsetNotFoundError{offset: 0x4000}, + }, + { + name: "offset after last region", + offset: 0xA000, + expectError: OffsetNotFoundError{offset: 0xA000}, + }, + { + name: "offset in gap between regions", + offset: 0x7000, + expectError: OffsetNotFoundError{offset: 0x7000}, + }, + { + name: "offset at exact end of first region (exclusive)", + offset: 0x7000, // 0x5000 + 0x2000 + expectError: OffsetNotFoundError{offset: 0x7000}, + }, + { + name: "offset at exact end of second region (exclusive)", + offset: 0x9000, // 0x8000 + 0x1000 + expectError: OffsetNotFoundError{offset: 0x9000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + hostVirt, size, err := mapping.GetHostVirtAddr(tt.offset) + if tt.expectError != nil { + require.ErrorIs(t, err, tt.expectError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedHostVirt, hostVirt, "hostVirt: %d, expectedHostVirt: %d", hostVirt, tt.expectedHostVirt) + assert.Equal(t, tt.remainingRegionSize, size, "size: %d, expectedSize: %d", size, tt.remainingRegionSize) + } + }) + } +} + +func TestMapping_GetHostVirtAddr_EmptyRegions(t *testing.T) { + t.Parallel() + + mapping := NewMapping([]Region{}) + + // Test GetHostVirtAddr with empty regions + _, _, err := mapping.GetHostVirtAddr(0x1000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1000}) +} + +func TestMapping_GetHostVirtAddr_BoundaryConditions(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + } + + mapping := NewMapping(regions) + + // Test exact start boundary + hostVirt, size, err := mapping.GetHostVirtAddr(0x5000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 + (0x5000 - 0x5000) + assert.Equal(t, int64(0x7000-0x5000), size) // 0x2000 + + 
// Test offset before end boundary + hostVirt, size, err = mapping.GetHostVirtAddr(0x6FFF) // just before end + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000+(0x6FFF-0x5000)), hostVirt) + assert.Equal(t, int64(0x7000-0x6FFF), size) + + // Test exact end boundary (should fail - exclusive) + _, _, err = mapping.GetHostVirtAddr(0x7000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x7000}) + + // Test below start boundary (should fail) + _, _, err = mapping.GetHostVirtAddr(0x4000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x4000}) +} + +func TestMapping_GetHostVirtAddr_SingleLargeRegion(t *testing.T) { + t.Parallel() + + // Entire 64-bit address space region + regions := []Region{ + { + BaseHostVirtAddr: 0x0, + Size: math.MaxInt64 - 0x100, + Offset: 0x100, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + hostVirt, size, err := mapping.GetHostVirtAddr(0x100 + 0x1000) // Offset 0x1100 + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 + assert.Equal(t, int64(math.MaxInt64-0x100-0x1000), size) +} + +func TestMapping_GetHostVirtAddr_ZeroSizeRegion(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x2000, + Size: 0, + Offset: 0x1000, + PageSize: header.PageSize, + }, + } + + mapping := NewMapping(regions) + + _, _, err := mapping.GetHostVirtAddr(0x1000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1000}) +} + +func TestMapping_GetHostVirtAddr_MultipleRegionsSparse(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x100, + Size: 0x100, + Offset: 0x1000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x10000, + Size: 0x100, + Offset: 0x2000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + // Should succeed for start of first region + hostVirt, size, err := mapping.GetHostVirtAddr(0x1000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x100), hostVirt) // 0x100 + (0x1000 - 
0x1000) + assert.Equal(t, int64(0x1100-0x1000), size) // 0x100 + + // Should succeed for just before end of first region + hostVirt, size, err = mapping.GetHostVirtAddr(0x10FF) // 0x1100 - 1 + require.NoError(t, err) + assert.Equal(t, uintptr(0x100+(0x10FF-0x1000)), hostVirt) + assert.Equal(t, int64(0x1100-0x10FF), size) // 1 + + // Should succeed for start of second region + hostVirt, size, err = mapping.GetHostVirtAddr(0x2000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x10000), hostVirt) // 0x10000 + (0x2000 - 0x2000) + assert.Equal(t, int64(0x2100-0x2000), size) // 0x100 + + // In gap + _, _, err = mapping.GetHostVirtAddr(0x1500) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1500}) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go similarity index 50% rename from packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go rename to packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go index 7b7d87e06f..4be1c1e6e0 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go @@ -10,6 +10,8 @@ import ( ) func TestMapping_GetOffset(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x1000, @@ -24,50 +26,51 @@ func TestMapping_GetOffset(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) tests := []struct { - name string - hostVirtAddr uintptr - expectedOffset int64 - expectedSize uint64 - expectError error + name string + hostVirtAddr uintptr + expectedOffset int64 + expectedPagesize uintptr + expectError error }{ { - name: "valid address in first region", - hostVirtAddr: 0x1500, - expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) - expectedSize: 0x1000, + name: "valid address in first region", + hostVirtAddr: 0x1500, + expectedOffset: 0x5500, // 0x5000 + 
(0x1500 - 0x1000) + expectedPagesize: 0x1000, }, { - name: "valid address at start of first region", - hostVirtAddr: 0x1000, - expectedOffset: 0x5000, - expectedSize: 0x1000, + name: "valid address at start of first region", + hostVirtAddr: 0x1000, + expectedOffset: 0x5000, + expectedPagesize: 0x1000, }, { - name: "valid address at end-1 of first region", - hostVirtAddr: 0x2FFF, // 0x1000 + 0x2000 - 1 - expectedOffset: 0x6FFF, // 0x5000 + (0x2FFF - 0x1000) - expectedSize: 0x1000, + name: "valid address at end-1 of first region", + hostVirtAddr: 0x2FFF, // 0x1000 + 0x2000 - 1 + expectedOffset: 0x6FFF, // 0x5000 + (0x2FFF - 0x1000) + expectedPagesize: 0x1000, }, { - name: "valid address in second region", - hostVirtAddr: 0x5500, - expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) - expectedSize: 0x1000, + name: "valid address in second region", + hostVirtAddr: 0x5500, + expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) + expectedPagesize: 0x1000, }, { - name: "valid address at start of second region", - hostVirtAddr: 0x5000, - expectedOffset: 0x8000, - expectedSize: 0x1000, + name: "valid address at start of second region", + hostVirtAddr: 0x5000, + expectedOffset: 0x8000, + expectedPagesize: 0x1000, }, { - name: "valid address at end-1 of second region", - hostVirtAddr: 0x5FFF, - expectedOffset: 0x8FFF, // 0x8000 + (0x5FFF - 0x5000) - expectedSize: 0x1000, + name: "valid address at end-1 of second region", + hostVirtAddr: 0x5FFF, + expectedOffset: 0x8FFF, // 0x8000 + (0x5FFF - 0x5000) + expectedPagesize: 0x1000, }, { name: "address before first region", @@ -98,60 +101,33 @@ func TestMapping_GetOffset(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offset, size, err := mapping.GetOffset(tt.hostVirtAddr) + t.Parallel() + + offset, pagesize, err := mapping.GetOffset(tt.hostVirtAddr) if tt.expectError != nil { require.ErrorIs(t, err, tt.expectError) } else { require.NoError(t, err) assert.Equal(t, tt.expectedOffset, offset) - 
assert.Equal(t, tt.expectedSize, size) + assert.Equal(t, tt.expectedPagesize, pagesize) } }) } } func TestMapping_EmptyRegions(t *testing.T) { + t.Parallel() + mapping := NewMapping([]Region{}) // Test GetOffset with empty regions _, _, err := mapping.GetOffset(0x1000) - require.Error(t, err) -} - -func TestMapping_OverlappingRegions(t *testing.T) { - // Test with overlapping regions (edge case) - regions := []Region{ - { - BaseHostVirtAddr: 0x1000, - Size: 0x2000, - Offset: 0x5000, - PageSize: header.PageSize, - }, - { - BaseHostVirtAddr: 0x2000, // Overlaps with first region - Size: 0x1000, - Offset: 0x8000, - PageSize: header.PageSize, - }, - } - mapping := NewMapping(regions) - - // The first matching region should be returned - offset, size, err := mapping.GetOffset(0x2500) // In overlap area - require.NoError(t, err) - - // Should get result from first region - require.Equal(t, int64(0x5000+(0x2500-0x1000)), offset) // 0x6500 - require.Equal(t, uint64(header.PageSize), size) - - // Also test that the underlying implementation prefers the first region if both regions contain the address - offset2, size2, err2 := mapping.GetOffset(0x2000) - require.NoError(t, err2) - require.Equal(t, int64(0x5000+(0x2000-0x1000)), offset2) // 0x6000 from first region - require.Equal(t, uint64(header.PageSize), size2) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x1000}) } func TestMapping_BoundaryConditions(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x1000, @@ -160,28 +136,33 @@ func TestMapping_BoundaryConditions(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) // Test exact start boundary - offset, _, err := mapping.GetOffset(0x1000) + offset, pagesize, err := mapping.GetOffset(0x1000) require.NoError(t, err) - require.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + assert.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + assert.Equal(t, uintptr(header.PageSize), 
pagesize) // Test just before end boundary (exclusive) - offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 + offset, pagesize, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 require.NoError(t, err) - require.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + assert.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + assert.Equal(t, uintptr(header.PageSize), pagesize) // Test exact end boundary (should fail - exclusive) _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x3000}) // Test below start boundary (should fail) - _, _, err = mapping.GetOffset(0x0FFF) - require.Error(t, err) + _, _, err = mapping.GetOffset(0x0FFF) // 0x1000 - 0x1000 + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x0FFF}) } func TestMapping_SingleLargeRegion(t *testing.T) { + t.Parallel() + // Entire 64-bit address space region regions := []Region{ { @@ -193,13 +174,15 @@ func TestMapping_SingleLargeRegion(t *testing.T) { } mapping := NewMapping(regions) - offset, size, err := mapping.GetOffset(0xABCDEF) + offset, pagesize, err := mapping.GetOffset(0xABCDEF) require.NoError(t, err) - require.Equal(t, int64(0x100+0xABCDEF), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x100+0xABCDEF), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) } func TestMapping_ZeroSizeRegion(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x2000, @@ -208,12 +191,16 @@ func TestMapping_ZeroSizeRegion(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) + _, _, err := mapping.GetOffset(0x2000) - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x2000}) } func TestMapping_MultipleRegionsSparse(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x100, @@ -229,19 +216,52 @@ func TestMapping_MultipleRegionsSparse(t 
*testing.T) { }, } mapping := NewMapping(regions) + // Should succeed for start of first region - offset, size, err := mapping.GetOffset(0x100) + offset, pagesize, err := mapping.GetOffset(0x100) require.NoError(t, err) - require.Equal(t, int64(0x1000), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x1000), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) // Should succeed for start of second region - offset, size, err = mapping.GetOffset(0x10000) + offset, pagesize, err = mapping.GetOffset(0x10000) require.NoError(t, err) - require.Equal(t, int64(0x2000), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x2000), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) // In gap _, _, err = mapping.GetOffset(0x5000) - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x5000}) +} + +// Additional test for hugepage page size +func TestMapping_HugepagePagesize(t *testing.T) { + t.Parallel() + + const hugepageSize = 2 * 1024 * 1024 // 2MB + regions := []Region{ + { + BaseHostVirtAddr: 0x400000, + Size: hugepageSize, + Offset: 0x800000, + PageSize: hugepageSize, + }, + } + mapping := NewMapping(regions) + + // Test valid address in region using hugepages + offset, pagesize, err := mapping.GetOffset(0x401000) + require.NoError(t, err) + assert.Equal(t, int64(0x800000+(0x401000-0x400000)), offset) + assert.Equal(t, uintptr(hugepageSize), pagesize) + + // Test start of region + offset, pagesize, err = mapping.GetOffset(0x400000) + require.NoError(t, err) + assert.Equal(t, int64(0x800000), offset) + assert.Equal(t, uintptr(hugepageSize), pagesize) + + // Test end of region (exclusive, should fail) + _, _, err = mapping.GetOffset(0x400000 + uintptr(hugepageSize)) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x400000 + uintptr(hugepageSize)}) } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go 
b/packages/orchestrator/internal/sandbox/uffd/memory/region.go index db1d4f8a3c..3de0402502 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/region.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -1,5 +1,7 @@ package memory +import "iter" + // Region is a mapping of a region of memory of the guest to a region of memory on the host. // The serialization is based on the Firecracker UFFD protocol communication. // https://github.com/firecracker-microvm/firecracker/blob/ceeca6a14284537ae0b2a192cd2ffef10d3a81e2/src/vmm/src/persist.rs#L96 @@ -11,6 +13,12 @@ type Region struct { PageSize uintptr `json:"page_size_kib"` // This is actually in bytes in the deprecated version. } +// endOffset returns the end offset of the region in bytes. +// The end offset is exclusive. +func (r *Region) endOffset() int64 { + return int64(r.Offset + r.Size) +} + // endHostVirtAddr returns the end address of the region in host virtual address. // The end address is exclusive. func (r *Region) endHostVirtAddr() uintptr { @@ -21,3 +29,18 @@ func (r *Region) endHostVirtAddr() uintptr { func (r *Region) shiftedOffset(addr uintptr) int64 { return int64(addr - r.BaseHostVirtAddr + r.Offset) } + +// shiftedHostVirtAddr returns the host virtual address of the given offset in the region. 
+func (r *Region) shiftedHostVirtAddr(off int64) uintptr { + return uintptr(off) + r.BaseHostVirtAddr - r.Offset +} + +func (r *Region) Offsets() iter.Seq[int64] { + return func(yield func(offset int64) bool) { + for i := int64(r.Offset); i < r.endOffset(); i += int64(r.PageSize) { + if !yield(i) { + return + } + } + } +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go index cf64c83620..c01585ec9c 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go @@ -3,15 +3,13 @@ package uffd import ( "context" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) type MemoryBackend interface { - // Disable unregisters the uffd from the memory mapping and returns the dirty pages. - // It must be called after FC pause finished and before FC snapshot is created. 
- Disable(ctx context.Context) (*block.Tracker, error) - + Dirty(ctx context.Context) (*bitset.BitSet, error) Start(ctx context.Context, sandboxId string) error Stop() error Ready() chan struct{} diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index c859c333e4..f74708dc95 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -5,7 +5,6 @@ import ( "github.com/bits-and-blooms/bitset" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -14,29 +13,26 @@ type NoopMemory struct { size int64 blockSize int64 - dirty *block.Tracker - exit *utils.ErrorOnce } var _ MemoryBackend = (*NoopMemory)(nil) func NewNoopMemory(size, blockSize int64) *NoopMemory { - blocks := header.TotalBlocks(size, blockSize) - - b := bitset.New(uint(blocks)) - b.FlipRange(0, b.Len()) - return &NoopMemory{ size: size, blockSize: blockSize, - dirty: block.NewTrackerFromBitset(b, blockSize), exit: utils.NewErrorOnce(), } } -func (m *NoopMemory) Disable(context.Context) (*block.Tracker, error) { - return m.dirty.Clone(), nil +func (m *NoopMemory) Dirty(context.Context) (*bitset.BitSet, error) { + blocks := uint(header.TotalBlocks(m.size, m.blockSize)) + + b := bitset.New(blocks) + b.FlipRange(0, blocks) + + return b, nil } func (m *NoopMemory) Start(context.Context, string) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index ce9728f034..3eaad96492 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "github.com/bits-and-blooms/bitset" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -181,36 +182,14 @@ func (u *Uffd) Exit() 
*utils.ErrorOnce { return u.exit } -// Disable unregisters the uffd from the memory mapping, -// allowing us to create a "diff" snapshot via FC API without dirty tracking enabled, -// and without pagefaulting all remaining missing pages. -// -// It should be called *after* Dirty(). -// -// After calling Disable(), this uffd is no longer usable—we won't be able to resume the sandbox via API. -// The uffd itself is not closed though, as that should be done by the sandbox cleanup. -func (u *Uffd) Disable(ctx context.Context) (*block.Tracker, error) { - uffd, err := u.handler.WaitWithContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get uffd: %w", err) - } - - err = uffd.Unregister() - if err != nil { - return nil, fmt.Errorf("failed to unregister uffd: %w", err) - } - - return u.dirty(ctx) -} - // Dirty waits for the current requests to finish and returns the dirty pages. // -// It *MUST* be only called after the sandbox was successfully paused via API. -func (u *Uffd) dirty(ctx context.Context) (*block.Tracker, error) { +// It *MUST* be only called after the sandbox was successfully paused via API and after the snapshot endpoint was called. 
+func (u *Uffd) Dirty(ctx context.Context) (*bitset.BitSet, error) { uffd, err := u.handler.WaitWithContext(ctx) if err != nil { return nil, fmt.Errorf("failed to get uffd: %w", err) } - return uffd.Dirty(), nil + return uffd.Dirty().BitSet(), nil } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 3ccee96802..2d37277e0e 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -35,10 +35,9 @@ const ( UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING - UFFDIO_API = C.UFFDIO_API - UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER - UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_API = C.UFFDIO_API + UFFDIO_REGISTER = C.UFFDIO_REGISTER + UFFDIO_COPY = C.UFFDIO_COPY UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE @@ -149,17 +148,6 @@ func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { return nil } -func (u uffdFd) unregister(addr uintptr, size uint64) error { - r := newUffdioRange(CULong(addr), CULong(size)) - - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) - if errno != 0 { - return fmt.Errorf("UFFDIO_UNREGISTER ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} - // mode: UFFDIO_COPY_MODE_WP // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page func (u uffdFd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index 2a1397e534..a7cb017cc8 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -271,16 +271,6 @@ func 
(u *Userfaultfd) handleMissing( return nil } -func (u *Userfaultfd) Unregister() error { - for _, r := range u.ma.Regions { - if err := u.fd.unregister(r.BaseHostVirtAddr, uint64(r.Size)); err != nil { - return fmt.Errorf("failed to unregister: %w", err) - } - } - - return nil -} - func (u *Userfaultfd) Dirty() *block.Tracker { // This will be at worst cancelled when the uffd is closed. u.settleRequests.Lock() diff --git a/packages/shared/pkg/fc/client/operations/get_memory_mappings_parameters.go b/packages/shared/pkg/fc/client/operations/get_memory_mappings_parameters.go new file mode 100644 index 0000000000..db59c6aeca --- /dev/null +++ b/packages/shared/pkg/fc/client/operations/get_memory_mappings_parameters.go @@ -0,0 +1,128 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package operations + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "net/http" + "time" + + "github.com/go-openapi/errors" + "github.com/go-openapi/runtime" + cr "github.com/go-openapi/runtime/client" + "github.com/go-openapi/strfmt" +) + +// NewGetMemoryMappingsParams creates a new GetMemoryMappingsParams object, +// with the default timeout for this client. +// +// Default values are not hydrated, since defaults are normally applied by the API server side. +// +// To enforce default values in parameter, use SetDefaults or WithDefaults. +func NewGetMemoryMappingsParams() *GetMemoryMappingsParams { + return &GetMemoryMappingsParams{ + timeout: cr.DefaultTimeout, + } +} + +// NewGetMemoryMappingsParamsWithTimeout creates a new GetMemoryMappingsParams object +// with the ability to set a timeout on a request. 
+func NewGetMemoryMappingsParamsWithTimeout(timeout time.Duration) *GetMemoryMappingsParams { + return &GetMemoryMappingsParams{ + timeout: timeout, + } +} + +// NewGetMemoryMappingsParamsWithContext creates a new GetMemoryMappingsParams object +// with the ability to set a context for a request. +func NewGetMemoryMappingsParamsWithContext(ctx context.Context) *GetMemoryMappingsParams { + return &GetMemoryMappingsParams{ + Context: ctx, + } +} + +// NewGetMemoryMappingsParamsWithHTTPClient creates a new GetMemoryMappingsParams object +// with the ability to set a custom HTTPClient for a request. +func NewGetMemoryMappingsParamsWithHTTPClient(client *http.Client) *GetMemoryMappingsParams { + return &GetMemoryMappingsParams{ + HTTPClient: client, + } +} + +/* +GetMemoryMappingsParams contains all the parameters to send to the API endpoint + + for the get memory mappings operation. + + Typically these are written to a http.Request. +*/ +type GetMemoryMappingsParams struct { + timeout time.Duration + Context context.Context + HTTPClient *http.Client +} + +// WithDefaults hydrates default values in the get memory mappings params (not the query body). +// +// All values with no default are reset to their zero value. +func (o *GetMemoryMappingsParams) WithDefaults() *GetMemoryMappingsParams { + o.SetDefaults() + return o +} + +// SetDefaults hydrates default values in the get memory mappings params (not the query body). +// +// All values with no default are reset to their zero value. 
+func (o *GetMemoryMappingsParams) SetDefaults() { + // no default values defined for this parameter +} + +// WithTimeout adds the timeout to the get memory mappings params +func (o *GetMemoryMappingsParams) WithTimeout(timeout time.Duration) *GetMemoryMappingsParams { + o.SetTimeout(timeout) + return o +} + +// SetTimeout adds the timeout to the get memory mappings params +func (o *GetMemoryMappingsParams) SetTimeout(timeout time.Duration) { + o.timeout = timeout +} + +// WithContext adds the context to the get memory mappings params +func (o *GetMemoryMappingsParams) WithContext(ctx context.Context) *GetMemoryMappingsParams { + o.SetContext(ctx) + return o +} + +// SetContext adds the context to the get memory mappings params +func (o *GetMemoryMappingsParams) SetContext(ctx context.Context) { + o.Context = ctx +} + +// WithHTTPClient adds the HTTPClient to the get memory mappings params +func (o *GetMemoryMappingsParams) WithHTTPClient(client *http.Client) *GetMemoryMappingsParams { + o.SetHTTPClient(client) + return o +} + +// SetHTTPClient adds the HTTPClient to the get memory mappings params +func (o *GetMemoryMappingsParams) SetHTTPClient(client *http.Client) { + o.HTTPClient = client +} + +// WriteToRequest writes these params to a swagger request +func (o *GetMemoryMappingsParams) WriteToRequest(r runtime.ClientRequest, reg strfmt.Registry) error { + + if err := r.SetTimeout(o.timeout); err != nil { + return err + } + var res []error + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} diff --git a/packages/shared/pkg/fc/client/operations/get_memory_mappings_responses.go b/packages/shared/pkg/fc/client/operations/get_memory_mappings_responses.go new file mode 100644 index 0000000000..97238ab640 --- /dev/null +++ b/packages/shared/pkg/fc/client/operations/get_memory_mappings_responses.go @@ -0,0 +1,187 @@ +// Code generated by go-swagger; DO NOT EDIT. 
+ +package operations + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/fc/models" +) + +// GetMemoryMappingsReader is a Reader for the GetMemoryMappings structure. +type GetMemoryMappingsReader struct { + formats strfmt.Registry +} + +// ReadResponse reads a server response into the received o. +func (o *GetMemoryMappingsReader) ReadResponse(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) { + switch response.Code() { + case 200: + result := NewGetMemoryMappingsOK() + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + return result, nil + default: + result := NewGetMemoryMappingsDefault(response.Code()) + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + if response.Code()/100 == 2 { + return result, nil + } + return nil, result + } +} + +// NewGetMemoryMappingsOK creates a GetMemoryMappingsOK with default headers values +func NewGetMemoryMappingsOK() *GetMemoryMappingsOK { + return &GetMemoryMappingsOK{} +} + +/* +GetMemoryMappingsOK describes a response with status code 200, with default header values. 
+ +OK +*/ +type GetMemoryMappingsOK struct { + Payload *models.MemoryMappingsResponse +} + +// IsSuccess returns true when this get memory mappings o k response has a 2xx status code +func (o *GetMemoryMappingsOK) IsSuccess() bool { + return true +} + +// IsRedirect returns true when this get memory mappings o k response has a 3xx status code +func (o *GetMemoryMappingsOK) IsRedirect() bool { + return false +} + +// IsClientError returns true when this get memory mappings o k response has a 4xx status code +func (o *GetMemoryMappingsOK) IsClientError() bool { + return false +} + +// IsServerError returns true when this get memory mappings o k response has a 5xx status code +func (o *GetMemoryMappingsOK) IsServerError() bool { + return false +} + +// IsCode returns true when this get memory mappings o k response a status code equal to that given +func (o *GetMemoryMappingsOK) IsCode(code int) bool { + return code == 200 +} + +// Code gets the status code for the get memory mappings o k response +func (o *GetMemoryMappingsOK) Code() int { + return 200 +} + +func (o *GetMemoryMappingsOK) Error() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory/mappings][%d] getMemoryMappingsOK %s", 200, payload) +} + +func (o *GetMemoryMappingsOK) String() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory/mappings][%d] getMemoryMappingsOK %s", 200, payload) +} + +func (o *GetMemoryMappingsOK) GetPayload() *models.MemoryMappingsResponse { + return o.Payload +} + +func (o *GetMemoryMappingsOK) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + o.Payload = new(models.MemoryMappingsResponse) + + // response payload + if err := consumer.Consume(response.Body(), o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} + +// NewGetMemoryMappingsDefault creates a GetMemoryMappingsDefault with default headers values +func 
NewGetMemoryMappingsDefault(code int) *GetMemoryMappingsDefault { + return &GetMemoryMappingsDefault{ + _statusCode: code, + } +} + +/* +GetMemoryMappingsDefault describes a response with status code -1, with default header values. + +Internal server error +*/ +type GetMemoryMappingsDefault struct { + _statusCode int + + Payload *models.Error +} + +// IsSuccess returns true when this get memory mappings default response has a 2xx status code +func (o *GetMemoryMappingsDefault) IsSuccess() bool { + return o._statusCode/100 == 2 +} + +// IsRedirect returns true when this get memory mappings default response has a 3xx status code +func (o *GetMemoryMappingsDefault) IsRedirect() bool { + return o._statusCode/100 == 3 +} + +// IsClientError returns true when this get memory mappings default response has a 4xx status code +func (o *GetMemoryMappingsDefault) IsClientError() bool { + return o._statusCode/100 == 4 +} + +// IsServerError returns true when this get memory mappings default response has a 5xx status code +func (o *GetMemoryMappingsDefault) IsServerError() bool { + return o._statusCode/100 == 5 +} + +// IsCode returns true when this get memory mappings default response a status code equal to that given +func (o *GetMemoryMappingsDefault) IsCode(code int) bool { + return o._statusCode == code +} + +// Code gets the status code for the get memory mappings default response +func (o *GetMemoryMappingsDefault) Code() int { + return o._statusCode +} + +func (o *GetMemoryMappingsDefault) Error() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory/mappings][%d] getMemoryMappings default %s", o._statusCode, payload) +} + +func (o *GetMemoryMappingsDefault) String() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory/mappings][%d] getMemoryMappings default %s", o._statusCode, payload) +} + +func (o *GetMemoryMappingsDefault) GetPayload() *models.Error { + return o.Payload +} + +func (o 
*GetMemoryMappingsDefault) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + o.Payload = new(models.Error) + + // response payload + if err := consumer.Consume(response.Body(), o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} diff --git a/packages/shared/pkg/fc/client/operations/get_memory_parameters.go b/packages/shared/pkg/fc/client/operations/get_memory_parameters.go new file mode 100644 index 0000000000..fefeefdc21 --- /dev/null +++ b/packages/shared/pkg/fc/client/operations/get_memory_parameters.go @@ -0,0 +1,128 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package operations + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "net/http" + "time" + + "github.com/go-openapi/errors" + "github.com/go-openapi/runtime" + cr "github.com/go-openapi/runtime/client" + "github.com/go-openapi/strfmt" +) + +// NewGetMemoryParams creates a new GetMemoryParams object, +// with the default timeout for this client. +// +// Default values are not hydrated, since defaults are normally applied by the API server side. +// +// To enforce default values in parameter, use SetDefaults or WithDefaults. +func NewGetMemoryParams() *GetMemoryParams { + return &GetMemoryParams{ + timeout: cr.DefaultTimeout, + } +} + +// NewGetMemoryParamsWithTimeout creates a new GetMemoryParams object +// with the ability to set a timeout on a request. +func NewGetMemoryParamsWithTimeout(timeout time.Duration) *GetMemoryParams { + return &GetMemoryParams{ + timeout: timeout, + } +} + +// NewGetMemoryParamsWithContext creates a new GetMemoryParams object +// with the ability to set a context for a request. 
+func NewGetMemoryParamsWithContext(ctx context.Context) *GetMemoryParams { + return &GetMemoryParams{ + Context: ctx, + } +} + +// NewGetMemoryParamsWithHTTPClient creates a new GetMemoryParams object +// with the ability to set a custom HTTPClient for a request. +func NewGetMemoryParamsWithHTTPClient(client *http.Client) *GetMemoryParams { + return &GetMemoryParams{ + HTTPClient: client, + } +} + +/* +GetMemoryParams contains all the parameters to send to the API endpoint + + for the get memory operation. + + Typically these are written to a http.Request. +*/ +type GetMemoryParams struct { + timeout time.Duration + Context context.Context + HTTPClient *http.Client +} + +// WithDefaults hydrates default values in the get memory params (not the query body). +// +// All values with no default are reset to their zero value. +func (o *GetMemoryParams) WithDefaults() *GetMemoryParams { + o.SetDefaults() + return o +} + +// SetDefaults hydrates default values in the get memory params (not the query body). +// +// All values with no default are reset to their zero value. 
+func (o *GetMemoryParams) SetDefaults() { + // no default values defined for this parameter +} + +// WithTimeout adds the timeout to the get memory params +func (o *GetMemoryParams) WithTimeout(timeout time.Duration) *GetMemoryParams { + o.SetTimeout(timeout) + return o +} + +// SetTimeout adds the timeout to the get memory params +func (o *GetMemoryParams) SetTimeout(timeout time.Duration) { + o.timeout = timeout +} + +// WithContext adds the context to the get memory params +func (o *GetMemoryParams) WithContext(ctx context.Context) *GetMemoryParams { + o.SetContext(ctx) + return o +} + +// SetContext adds the context to the get memory params +func (o *GetMemoryParams) SetContext(ctx context.Context) { + o.Context = ctx +} + +// WithHTTPClient adds the HTTPClient to the get memory params +func (o *GetMemoryParams) WithHTTPClient(client *http.Client) *GetMemoryParams { + o.SetHTTPClient(client) + return o +} + +// SetHTTPClient adds the HTTPClient to the get memory params +func (o *GetMemoryParams) SetHTTPClient(client *http.Client) { + o.HTTPClient = client +} + +// WriteToRequest writes these params to a swagger request +func (o *GetMemoryParams) WriteToRequest(r runtime.ClientRequest, reg strfmt.Registry) error { + + if err := r.SetTimeout(o.timeout); err != nil { + return err + } + var res []error + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} diff --git a/packages/shared/pkg/fc/client/operations/get_memory_responses.go b/packages/shared/pkg/fc/client/operations/get_memory_responses.go new file mode 100644 index 0000000000..43f4632d7f --- /dev/null +++ b/packages/shared/pkg/fc/client/operations/get_memory_responses.go @@ -0,0 +1,187 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package operations + +// This file was generated by the swagger tool. 
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/strfmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/fc/models" +) + +// GetMemoryReader is a Reader for the GetMemory structure. +type GetMemoryReader struct { + formats strfmt.Registry +} + +// ReadResponse reads a server response into the received o. +func (o *GetMemoryReader) ReadResponse(response runtime.ClientResponse, consumer runtime.Consumer) (interface{}, error) { + switch response.Code() { + case 200: + result := NewGetMemoryOK() + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + return result, nil + default: + result := NewGetMemoryDefault(response.Code()) + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + if response.Code()/100 == 2 { + return result, nil + } + return nil, result + } +} + +// NewGetMemoryOK creates a GetMemoryOK with default headers values +func NewGetMemoryOK() *GetMemoryOK { + return &GetMemoryOK{} +} + +/* +GetMemoryOK describes a response with status code 200, with default header values. 
+ +OK +*/ +type GetMemoryOK struct { + Payload *models.MemoryResponse +} + +// IsSuccess returns true when this get memory o k response has a 2xx status code +func (o *GetMemoryOK) IsSuccess() bool { + return true +} + +// IsRedirect returns true when this get memory o k response has a 3xx status code +func (o *GetMemoryOK) IsRedirect() bool { + return false +} + +// IsClientError returns true when this get memory o k response has a 4xx status code +func (o *GetMemoryOK) IsClientError() bool { + return false +} + +// IsServerError returns true when this get memory o k response has a 5xx status code +func (o *GetMemoryOK) IsServerError() bool { + return false +} + +// IsCode returns true when this get memory o k response a status code equal to that given +func (o *GetMemoryOK) IsCode(code int) bool { + return code == 200 +} + +// Code gets the status code for the get memory o k response +func (o *GetMemoryOK) Code() int { + return 200 +} + +func (o *GetMemoryOK) Error() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory][%d] getMemoryOK %s", 200, payload) +} + +func (o *GetMemoryOK) String() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory][%d] getMemoryOK %s", 200, payload) +} + +func (o *GetMemoryOK) GetPayload() *models.MemoryResponse { + return o.Payload +} + +func (o *GetMemoryOK) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + o.Payload = new(models.MemoryResponse) + + // response payload + if err := consumer.Consume(response.Body(), o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} + +// NewGetMemoryDefault creates a GetMemoryDefault with default headers values +func NewGetMemoryDefault(code int) *GetMemoryDefault { + return &GetMemoryDefault{ + _statusCode: code, + } +} + +/* +GetMemoryDefault describes a response with status code -1, with default header values. 
+ +Internal server error +*/ +type GetMemoryDefault struct { + _statusCode int + + Payload *models.Error +} + +// IsSuccess returns true when this get memory default response has a 2xx status code +func (o *GetMemoryDefault) IsSuccess() bool { + return o._statusCode/100 == 2 +} + +// IsRedirect returns true when this get memory default response has a 3xx status code +func (o *GetMemoryDefault) IsRedirect() bool { + return o._statusCode/100 == 3 +} + +// IsClientError returns true when this get memory default response has a 4xx status code +func (o *GetMemoryDefault) IsClientError() bool { + return o._statusCode/100 == 4 +} + +// IsServerError returns true when this get memory default response has a 5xx status code +func (o *GetMemoryDefault) IsServerError() bool { + return o._statusCode/100 == 5 +} + +// IsCode returns true when this get memory default response a status code equal to that given +func (o *GetMemoryDefault) IsCode(code int) bool { + return o._statusCode == code +} + +// Code gets the status code for the get memory default response +func (o *GetMemoryDefault) Code() int { + return o._statusCode +} + +func (o *GetMemoryDefault) Error() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory][%d] getMemory default %s", o._statusCode, payload) +} + +func (o *GetMemoryDefault) String() string { + payload, _ := json.Marshal(o.Payload) + return fmt.Sprintf("[GET /memory][%d] getMemory default %s", o._statusCode, payload) +} + +func (o *GetMemoryDefault) GetPayload() *models.Error { + return o.Payload +} + +func (o *GetMemoryDefault) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + o.Payload = new(models.Error) + + // response payload + if err := consumer.Consume(response.Body(), o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} diff --git a/packages/shared/pkg/fc/client/operations/operations_client.go 
b/packages/shared/pkg/fc/client/operations/operations_client.go index 3fb32aae0b..6cd55d39e6 100644 --- a/packages/shared/pkg/fc/client/operations/operations_client.go +++ b/packages/shared/pkg/fc/client/operations/operations_client.go @@ -70,6 +70,10 @@ type ClientService interface { GetMachineConfiguration(params *GetMachineConfigurationParams, opts ...ClientOption) (*GetMachineConfigurationOK, error) + GetMemory(params *GetMemoryParams, opts ...ClientOption) (*GetMemoryOK, error) + + GetMemoryMappings(params *GetMemoryMappingsParams, opts ...ClientOption) (*GetMemoryMappingsOK, error) + GetMmds(params *GetMmdsParams, opts ...ClientOption) (*GetMmdsOK, error) LoadSnapshot(params *LoadSnapshotParams, opts ...ClientOption) (*LoadSnapshotNoContent, error) @@ -417,6 +421,82 @@ func (a *Client) GetMachineConfiguration(params *GetMachineConfigurationParams, return nil, runtime.NewAPIError("unexpected success response: content available as default response in error", unexpectedSuccess, unexpectedSuccess.Code()) } +/* +GetMemory gets the memory info resident and empty pages + +Returns an object with resident and empty bitmaps. The resident bitmap marks all pages that are resident. The empty bitmap marks zero pages (subset of resident pages). This is checked at the pageSize of each region. All regions must have the same page size. 
+*/ +func (a *Client) GetMemory(params *GetMemoryParams, opts ...ClientOption) (*GetMemoryOK, error) { + // TODO: Validate the params before sending + if params == nil { + params = NewGetMemoryParams() + } + op := &runtime.ClientOperation{ + ID: "getMemory", + Method: "GET", + PathPattern: "/memory", + ProducesMediaTypes: []string{"application/json"}, + ConsumesMediaTypes: []string{"application/json"}, + Schemes: []string{"http"}, + Params: params, + Reader: &GetMemoryReader{formats: a.formats}, + Context: params.Context, + Client: params.HTTPClient, + } + for _, opt := range opts { + opt(op) + } + + result, err := a.transport.Submit(op) + if err != nil { + return nil, err + } + success, ok := result.(*GetMemoryOK) + if ok { + return success, nil + } + // unexpected success response + unexpectedSuccess := result.(*GetMemoryDefault) + return nil, runtime.NewAPIError("unexpected success response: content available as default response in error", unexpectedSuccess, unexpectedSuccess.Code()) +} + +/* +GetMemoryMappings gets the memory mappings with skippable pages bitmap +*/ +func (a *Client) GetMemoryMappings(params *GetMemoryMappingsParams, opts ...ClientOption) (*GetMemoryMappingsOK, error) { + // TODO: Validate the params before sending + if params == nil { + params = NewGetMemoryMappingsParams() + } + op := &runtime.ClientOperation{ + ID: "getMemoryMappings", + Method: "GET", + PathPattern: "/memory/mappings", + ProducesMediaTypes: []string{"application/json"}, + ConsumesMediaTypes: []string{"application/json"}, + Schemes: []string{"http"}, + Params: params, + Reader: &GetMemoryMappingsReader{formats: a.formats}, + Context: params.Context, + Client: params.HTTPClient, + } + for _, opt := range opts { + opt(op) + } + + result, err := a.transport.Submit(op) + if err != nil { + return nil, err + } + success, ok := result.(*GetMemoryMappingsOK) + if ok { + return success, nil + } + // unexpected success response + unexpectedSuccess := result.(*GetMemoryMappingsDefault) 
+ return nil, runtime.NewAPIError("unexpected success response: content available as default response in error", unexpectedSuccess, unexpectedSuccess.Code()) +} + /* GetMmds gets the m m d s data store */ diff --git a/packages/shared/pkg/fc/firecracker.yml b/packages/shared/pkg/fc/firecracker.yml index 525fb1677b..2cd2d92f20 100644 --- a/packages/shared/pkg/fc/firecracker.yml +++ b/packages/shared/pkg/fc/firecracker.yml @@ -5,10 +5,10 @@ info: The API is accessible through HTTP calls on specific URLs carrying JSON modeled data. The transport medium is a Unix Domain Socket. - version: 1.7.0-dev + version: 1.12.2 termsOfService: "" contact: - email: "compute-capsule@amazon.com" + email: "firecracker-maintainers@amazon.com" license: name: "Apache 2.0" url: "http://www.apache.org/licenses/LICENSE-2.0.html" @@ -85,12 +85,12 @@ paths: Will fail if update is not possible. operationId: putBalloon parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/Balloon" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/Balloon" responses: 204: description: Balloon device created/updated @@ -109,12 +109,12 @@ paths: Will fail if update is not possible. operationId: patchBalloon parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/BalloonUpdate" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/BalloonUpdate" responses: 204: description: Balloon device updated @@ -151,12 +151,12 @@ paths: Will fail if update is not possible. 
operationId: patchBalloonStatsInterval parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/BalloonStatsUpdate" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/BalloonStatsUpdate" responses: 204: description: Balloon statistics interval updated @@ -220,6 +220,7 @@ paths: schema: $ref: "#/definitions/Error" + /drives/{drive_id}: put: summary: Creates or updates a drive. Pre-boot only. @@ -487,7 +488,8 @@ paths: /entropy: put: summary: Creates an entropy device. Pre-boot only. - description: Enables an entropy device that provides high-quality random data to the guest. + description: + Enables an entropy device that provides high-quality random data to the guest. operationId: putEntropyDevice parameters: - name: body @@ -504,10 +506,12 @@ paths: schema: $ref: "#/definitions/Error" + /network-interfaces/{iface_id}: put: summary: Creates a network interface. Pre-boot only. - description: Creates new network interface with ID specified by iface_id path parameter. + description: + Creates new network interface with ID specified by iface_id path parameter. operationId: putGuestNetworkInterfaceByID parameters: - name: iface_id @@ -534,7 +538,8 @@ paths: $ref: "#/definitions/Error" patch: summary: Updates the rate limiters applied to a network interface. Post-boot only. - description: Updates the rate limiters applied to a network interface. + description: + Updates the rate limiters applied to a network interface. operationId: patchGuestNetworkInterfaceByID parameters: - name: iface_id @@ -589,7 +594,8 @@ paths: /snapshot/load: put: summary: Loads a snapshot. Pre-boot only. - description: Loads the microVM state from a snapshot. + description: + Loads the microVM state from a snapshot. Only accepted on a fresh Firecracker process (before configuring any resource other than the Logger and Metrics). 
operationId: loadSnapshot @@ -612,6 +618,35 @@ paths: schema: $ref: "#/definitions/Error" + /memory/mappings: + get: + summary: Gets the memory mappings with skippable pages bitmap. + operationId: getMemoryMappings + responses: + 200: + description: OK + schema: + $ref: "#/definitions/MemoryMappingsResponse" + default: + description: Internal server error + schema: + $ref: "#/definitions/Error" + + /memory: + get: + summary: Gets the memory info (resident and empty pages). + description: Returns an object with resident and empty bitmaps. The resident bitmap marks all pages that are resident. The empty bitmap marks zero pages (subset of resident pages). This is checked at the pageSize of each region. All regions must have the same page size. + operationId: getMemory + responses: + 200: + description: OK + schema: + $ref: "#/definitions/MemoryResponse" + default: + description: Internal server error + schema: + $ref: "#/definitions/Error" + /version: get: summary: Gets the Firecracker version. @@ -629,7 +664,8 @@ paths: /vm: patch: summary: Updates the microVM state. - description: Sets the desired state (Paused or Resumed) for the microVM. + description: + Sets the desired state (Paused or Resumed) for the microVM. operationId: patchVm parameters: - name: body @@ -700,7 +736,8 @@ definitions: required: - amount_mib - deflate_on_oom - description: Balloon device descriptor. + description: + Balloon device descriptor. properties: amount_mib: type: integer @@ -716,7 +753,8 @@ definitions: type: object required: - amount_mib - description: Balloon device descriptor. + description: + Balloon device descriptor. properties: amount_mib: type: integer @@ -724,7 +762,8 @@ definitions: BalloonStats: type: object - description: Describes the balloon device statistics. + description: + Describes the balloon device statistics. 
required: - target_pages - actual_pages @@ -788,7 +827,8 @@ definitions: type: object required: - stats_polling_interval_s - description: Update the statistics polling interval, with the first statistics update scheduled immediately. Statistics cannot be turned on/off after boot. + description: + Update the statistics polling interval, with the first statistics update scheduled immediately. Statistics cannot be turned on/off after boot. properties: stats_polling_interval_s: type: integer @@ -798,7 +838,8 @@ definitions: type: object required: - kernel_image_path - description: Boot source descriptor. + description: + Boot source descriptor. properties: boot_args: type: string @@ -828,7 +869,7 @@ definitions: default: "None" CpuConfig: - type: string + type: object description: The CPU configuration template defines a set of bit maps as modifiers of flags accessed by register to be disabled/enabled for the microvm. @@ -842,6 +883,12 @@ definitions: reg_modifiers: type: object description: A collection of registers to be modified. (aarch64) + vcpu_features: + type: object + description: A collection of vcpu features to be modified. (aarch64) + kvm_capabilities: + type: object + description: A collection of kvm capabilities to be modified. (aarch64) Drive: type: object @@ -861,18 +908,21 @@ definitions: type: boolean cache_type: type: string - description: Represents the caching strategy for the block device. + description: + Represents the caching strategy for the block device. enum: ["Unsafe", "Writeback"] default: "Unsafe" # VirtioBlock specific parameters is_read_only: type: boolean - description: Is block read only. + description: + Is block read only. This field is required for virtio-block config and should be omitted for vhost-user-block configuration. path_on_host: type: string - description: Host level path for the guest drive. + description: + Host level path for the guest drive. 
This field is required for virtio-block config and should be omitted for vhost-user-block configuration. rate_limiter: $ref: "#/definitions/RateLimiter" @@ -888,7 +938,8 @@ definitions: # VhostUserBlock specific parameters socket: type: string - description: Path to the socket of vhost-user-block backend. + description: + Path to the socket of vhost-user-block backend. This field is required for vhost-user-block config should be omitted for virtio-block configuration. Error: @@ -911,6 +962,8 @@ definitions: $ref: "#/definitions/Drive" boot-source: $ref: "#/definitions/BootSource" + cpu-config: + $ref: "#/definitions/CpuConfig" logger: $ref: "#/definitions/Logger" machine-config: @@ -926,10 +979,13 @@ definitions: $ref: "#/definitions/NetworkInterface" vsock: $ref: "#/definitions/Vsock" + entropy: + $ref: "#/definitions/EntropyDevice" InstanceActionInfo: type: object - description: Variant wrapper containing the real action. + description: + Variant wrapper containing the real action. required: - action_type properties: @@ -943,7 +999,8 @@ definitions: InstanceInfo: type: object - description: Describes MicroVM instance information. + description: + Describes MicroVM instance information. required: - app_name - id @@ -969,9 +1026,63 @@ definitions: description: MicroVM hypervisor build version. type: string + GuestMemoryRegionMapping: + type: object + description: Describes the region of guest memory that can be used for creating the memfile. + required: + - base_host_virt_addr + - size + - offset + - page_size + properties: + base_host_virt_addr: + type: integer + size: + description: The size of the region in bytes. + type: integer + offset: + description: The offset of the region in bytes. + type: integer + page_size: + description: The page size in bytes. + type: integer + + MemoryMappingsResponse: + type: object + description: Response containing memory region mappings. 
+ required: + - mappings + properties: + mappings: + type: array + description: The memory region mappings. + items: + $ref: "#/definitions/GuestMemoryRegionMapping" + + MemoryResponse: + type: object + description: Response containing the memory info (resident and empty pages). + required: + - resident + - empty + properties: + resident: + type: array + description: The resident bitmap as a vector of u64 values. Each bit represents if the page is resident. + items: + type: integer + format: uint64 + empty: + type: array + description: The empty bitmap as a vector of u64 values. Each bit represents if the page is zero (empty). This is a subset of the resident pages. + items: + type: integer + format: uint64 + Logger: type: object - description: Describes the configuration option for the logging capability. + description: + Describes the configuration option for the logging capability. properties: level: type: string @@ -1005,6 +1116,9 @@ definitions: properties: cpu_template: $ref: "#/definitions/CpuTemplate" + # gdb_socket_path: + # type: string + # description: Path to the GDB socket. Requires the gdb feature to be enabled. smt: type: boolean description: Flag for enabling/disabling simultaneous multithreading. Can be enabled only on x86. @@ -1053,7 +1167,8 @@ definitions: Metrics: type: object - description: Describes the configuration option for the metrics capability. + description: + Describes the configuration option for the metrics capability. required: - metrics_path properties: @@ -1063,7 +1178,8 @@ definitions: MmdsConfig: type: object - description: Defines the MMDS configuration. + description: + Defines the MMDS configuration. required: - network_interfaces properties: @@ -1094,11 +1210,13 @@ definitions: MmdsContentsObject: type: object - description: Describes the contents of MMDS in JSON format. + description: + Describes the contents of MMDS in JSON format. NetworkInterface: type: object - description: Defines a network interface. 
+ description: + Defines a network interface. required: - host_dev_name - iface_id @@ -1124,7 +1242,8 @@ definitions: type: string path_on_host: type: string - description: Host level path for the guest drive. + description: + Host level path for the guest drive. This field is optional for virtio-block config and should be omitted for vhost-user-block configuration. rate_limiter: $ref: "#/definitions/RateLimiter" @@ -1161,7 +1280,6 @@ definitions: SnapshotCreateParams: type: object required: - - mem_file_path - snapshot_path properties: mem_file_path: @@ -1179,6 +1297,24 @@ definitions: Type of snapshot to create. It is optional and by default, a full snapshot is created. + NetworkOverride: + type: object + description: + Allows for changing the backing TAP device of a network interface + during snapshot restore. + required: + - iface_id + - host_dev_name + properties: + iface_id: + type: string + description: + The name of the interface to modify + host_dev_name: + type: string + description: + The new host device of the interface + SnapshotLoadParams: type: object description: @@ -1189,7 +1325,8 @@ definitions: properties: enable_diff_snapshots: type: boolean - description: Enable support for incremental (diff) snapshots by tracking dirty guest pages. + description: + Enable support for incremental (diff) snapshots by tracking dirty guest pages. mem_file_path: type: string description: @@ -1207,7 +1344,14 @@ definitions: description: Path to the file that contains the microVM state to be loaded. resume_vm: type: boolean - description: When set to true, the vm is also resumed if the snapshot load is successful. + description: + When set to true, the vm is also resumed if the snapshot load is successful. 
+ network_overrides: + type: array + description: Network host device names to override + items: + $ref: "#/definitions/NetworkOverride" + TokenBucket: type: object @@ -1242,7 +1386,8 @@ definitions: Vm: type: object - description: Defines the microVM running state. It is especially useful in the snapshotting context. + description: + Defines the microVM running state. It is especially useful in the snapshotting context. required: - state properties: @@ -1254,14 +1399,16 @@ definitions: EntropyDevice: type: object - description: Defines an entropy device. + description: + Defines an entropy device. properties: rate_limiter: $ref: "#/definitions/RateLimiter" FirecrackerVersion: type: object - description: Describes the Firecracker version. + description: + Describes the Firecracker version. required: - firecracker_version properties: diff --git a/packages/shared/pkg/fc/models/cpu_config.go b/packages/shared/pkg/fc/models/cpu_config.go index d68bea978c..11007527a3 100644 --- a/packages/shared/pkg/fc/models/cpu_config.go +++ b/packages/shared/pkg/fc/models/cpu_config.go @@ -8,26 +8,55 @@ package models import ( "context" - "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // CPUConfig The CPU configuration template defines a set of bit maps as modifiers of flags accessed by register to be disabled/enabled for the microvm. // // swagger:model CpuConfig -type CPUConfig string +type CPUConfig struct { -// Validate validates this Cpu config -func (m CPUConfig) Validate(formats strfmt.Registry) error { - var res []error + // A collection of CPUIDs to be modified. (x86_64) + CpuidModifiers interface{} `json:"cpuid_modifiers,omitempty"` - if len(res) > 0 { - return errors.CompositeValidationError(res...) - } + // A collection of kvm capabilities to be modified. (aarch64) + KvmCapabilities interface{} `json:"kvm_capabilities,omitempty"` + + // A collection of model specific registers to be modified. 
(x86_64) + MsrModifiers interface{} `json:"msr_modifiers,omitempty"` + + // A collection of registers to be modified. (aarch64) + RegModifiers interface{} `json:"reg_modifiers,omitempty"` + + // A collection of vcpu features to be modified. (aarch64) + VcpuFeatures interface{} `json:"vcpu_features,omitempty"` +} + +// Validate validates this Cpu config +func (m *CPUConfig) Validate(formats strfmt.Registry) error { return nil } // ContextValidate validates this Cpu config based on context it is used -func (m CPUConfig) ContextValidate(ctx context.Context, formats strfmt.Registry) error { +func (m *CPUConfig) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *CPUConfig) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *CPUConfig) UnmarshalBinary(b []byte) error { + var res CPUConfig + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res return nil } diff --git a/packages/shared/pkg/fc/models/full_vm_configuration.go b/packages/shared/pkg/fc/models/full_vm_configuration.go index 4ae633777e..3d99a9dfd5 100644 --- a/packages/shared/pkg/fc/models/full_vm_configuration.go +++ b/packages/shared/pkg/fc/models/full_vm_configuration.go @@ -25,9 +25,15 @@ type FullVMConfiguration struct { // boot source BootSource *BootSource `json:"boot-source,omitempty"` + // cpu config + CPUConfig *CPUConfig `json:"cpu-config,omitempty"` + // Configurations for all block devices. 
Drives []*Drive `json:"drives"` + // entropy + Entropy *EntropyDevice `json:"entropy,omitempty"` + // logger Logger *Logger `json:"logger,omitempty"` @@ -59,10 +65,18 @@ func (m *FullVMConfiguration) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateCPUConfig(formats); err != nil { + res = append(res, err) + } + if err := m.validateDrives(formats); err != nil { res = append(res, err) } + if err := m.validateEntropy(formats); err != nil { + res = append(res, err) + } + if err := m.validateLogger(formats); err != nil { res = append(res, err) } @@ -131,6 +145,25 @@ func (m *FullVMConfiguration) validateBootSource(formats strfmt.Registry) error return nil } +func (m *FullVMConfiguration) validateCPUConfig(formats strfmt.Registry) error { + if swag.IsZero(m.CPUConfig) { // not required + return nil + } + + if m.CPUConfig != nil { + if err := m.CPUConfig.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("cpu-config") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("cpu-config") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) validateDrives(formats strfmt.Registry) error { if swag.IsZero(m.Drives) { // not required return nil @@ -157,6 +190,25 @@ func (m *FullVMConfiguration) validateDrives(formats strfmt.Registry) error { return nil } +func (m *FullVMConfiguration) validateEntropy(formats strfmt.Registry) error { + if swag.IsZero(m.Entropy) { // not required + return nil + } + + if m.Entropy != nil { + if err := m.Entropy.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("entropy") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("entropy") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) validateLogger(formats strfmt.Registry) error { if swag.IsZero(m.Logger) { // not required return nil @@ -290,10 +342,18 @@ 
func (m *FullVMConfiguration) ContextValidate(ctx context.Context, formats strfm res = append(res, err) } + if err := m.contextValidateCPUConfig(ctx, formats); err != nil { + res = append(res, err) + } + if err := m.contextValidateDrives(ctx, formats); err != nil { res = append(res, err) } + if err := m.contextValidateEntropy(ctx, formats); err != nil { + res = append(res, err) + } + if err := m.contextValidateLogger(ctx, formats); err != nil { res = append(res, err) } @@ -366,6 +426,27 @@ func (m *FullVMConfiguration) contextValidateBootSource(ctx context.Context, for return nil } +func (m *FullVMConfiguration) contextValidateCPUConfig(ctx context.Context, formats strfmt.Registry) error { + + if m.CPUConfig != nil { + + if swag.IsZero(m.CPUConfig) { // not required + return nil + } + + if err := m.CPUConfig.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("cpu-config") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("cpu-config") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) contextValidateDrives(ctx context.Context, formats strfmt.Registry) error { for i := 0; i < len(m.Drives); i++ { @@ -391,6 +472,27 @@ func (m *FullVMConfiguration) contextValidateDrives(ctx context.Context, formats return nil } +func (m *FullVMConfiguration) contextValidateEntropy(ctx context.Context, formats strfmt.Registry) error { + + if m.Entropy != nil { + + if swag.IsZero(m.Entropy) { // not required + return nil + } + + if err := m.Entropy.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("entropy") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("entropy") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) contextValidateLogger(ctx context.Context, formats strfmt.Registry) error { if m.Logger != nil { diff --git 
a/packages/shared/pkg/fc/models/guest_memory_region_mapping.go b/packages/shared/pkg/fc/models/guest_memory_region_mapping.go new file mode 100644 index 0000000000..308497a7cf --- /dev/null +++ b/packages/shared/pkg/fc/models/guest_memory_region_mapping.go @@ -0,0 +1,122 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// GuestMemoryRegionMapping Describes the region of guest memory that can be used for creating the memfile. +// +// swagger:model GuestMemoryRegionMapping +type GuestMemoryRegionMapping struct { + + // base host virt addr + // Required: true + BaseHostVirtAddr *int64 `json:"base_host_virt_addr"` + + // The offset of the region in bytes. + // Required: true + Offset *int64 `json:"offset"` + + // The page size in bytes. + // Required: true + PageSize *int64 `json:"page_size"` + + // The size of the region in bytes. + // Required: true + Size *int64 `json:"size"` +} + +// Validate validates this guest memory region mapping +func (m *GuestMemoryRegionMapping) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateBaseHostVirtAddr(formats); err != nil { + res = append(res, err) + } + + if err := m.validateOffset(formats); err != nil { + res = append(res, err) + } + + if err := m.validatePageSize(formats); err != nil { + res = append(res, err) + } + + if err := m.validateSize(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *GuestMemoryRegionMapping) validateBaseHostVirtAddr(formats strfmt.Registry) error { + + if err := validate.Required("base_host_virt_addr", "body", m.BaseHostVirtAddr); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validateOffset(formats strfmt.Registry) error { + + if err := validate.Required("offset", "body", m.Offset); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validatePageSize(formats strfmt.Registry) error { + + if err := validate.Required("page_size", "body", m.PageSize); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validateSize(formats strfmt.Registry) error { + + if err := validate.Required("size", "body", m.Size); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this guest memory region mapping based on context it is used +func (m *GuestMemoryRegionMapping) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *GuestMemoryRegionMapping) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *GuestMemoryRegionMapping) UnmarshalBinary(b []byte) error { + var res GuestMemoryRegionMapping + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/packages/shared/pkg/fc/models/memory_mappings_response.go b/packages/shared/pkg/fc/models/memory_mappings_response.go new file mode 100644 index 0000000000..886324b023 --- /dev/null +++ b/packages/shared/pkg/fc/models/memory_mappings_response.go @@ -0,0 +1,124 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. 
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + "strconv" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MemoryMappingsResponse Response containing memory region mappings. +// +// swagger:model MemoryMappingsResponse +type MemoryMappingsResponse struct { + + // The memory region mappings. + // Required: true + Mappings []*GuestMemoryRegionMapping `json:"mappings"` +} + +// Validate validates this memory mappings response +func (m *MemoryMappingsResponse) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateMappings(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MemoryMappingsResponse) validateMappings(formats strfmt.Registry) error { + + if err := validate.Required("mappings", "body", m.Mappings); err != nil { + return err + } + + for i := 0; i < len(m.Mappings); i++ { + if swag.IsZero(m.Mappings[i]) { // not required + continue + } + + if m.Mappings[i] != nil { + if err := m.Mappings[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("mappings" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("mappings" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// ContextValidate validate this memory mappings response based on the context it is used +func (m *MemoryMappingsResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateMappings(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *MemoryMappingsResponse) contextValidateMappings(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.Mappings); i++ { + + if m.Mappings[i] != nil { + + if swag.IsZero(m.Mappings[i]) { // not required + return nil + } + + if err := m.Mappings[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("mappings" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("mappings" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + +// MarshalBinary interface implementation +func (m *MemoryMappingsResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MemoryMappingsResponse) UnmarshalBinary(b []byte) error { + var res MemoryMappingsResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/packages/shared/pkg/fc/models/memory_response.go b/packages/shared/pkg/fc/models/memory_response.go new file mode 100644 index 0000000000..9176ceb09e --- /dev/null +++ b/packages/shared/pkg/fc/models/memory_response.go @@ -0,0 +1,88 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// MemoryResponse Response containing the memory info (resident and empty pages). +// +// swagger:model MemoryResponse +type MemoryResponse struct { + + // The empty bitmap as a vector of u64 values. Each bit represents if the page is zero (empty). This is a subset of the resident pages. 
+ // Required: true + Empty []uint64 `json:"empty"` + + // The resident bitmap as a vector of u64 values. Each bit represents if the page is resident. + // Required: true + Resident []uint64 `json:"resident"` +} + +// Validate validates this memory response +func (m *MemoryResponse) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateEmpty(formats); err != nil { + res = append(res, err) + } + + if err := m.validateResident(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *MemoryResponse) validateEmpty(formats strfmt.Registry) error { + + if err := validate.Required("empty", "body", m.Empty); err != nil { + return err + } + + return nil +} + +func (m *MemoryResponse) validateResident(formats strfmt.Registry) error { + + if err := validate.Required("resident", "body", m.Resident); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this memory response based on context it is used +func (m *MemoryResponse) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *MemoryResponse) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *MemoryResponse) UnmarshalBinary(b []byte) error { + var res MemoryResponse + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/packages/shared/pkg/fc/models/network_override.go b/packages/shared/pkg/fc/models/network_override.go new file mode 100644 index 0000000000..bcdd3c63b0 --- /dev/null +++ b/packages/shared/pkg/fc/models/network_override.go @@ -0,0 +1,88 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. 
+// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// NetworkOverride Allows for changing the backing TAP device of a network interface during snapshot restore. +// +// swagger:model NetworkOverride +type NetworkOverride struct { + + // The new host device of the interface + // Required: true + HostDevName *string `json:"host_dev_name"` + + // The name of the interface to modify + // Required: true + IfaceID *string `json:"iface_id"` +} + +// Validate validates this network override +func (m *NetworkOverride) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateHostDevName(formats); err != nil { + res = append(res, err) + } + + if err := m.validateIfaceID(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
+ } + return nil +} + +func (m *NetworkOverride) validateHostDevName(formats strfmt.Registry) error { + + if err := validate.Required("host_dev_name", "body", m.HostDevName); err != nil { + return err + } + + return nil +} + +func (m *NetworkOverride) validateIfaceID(formats strfmt.Registry) error { + + if err := validate.Required("iface_id", "body", m.IfaceID); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this network override based on context it is used +func (m *NetworkOverride) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *NetworkOverride) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *NetworkOverride) UnmarshalBinary(b []byte) error { + var res NetworkOverride + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/packages/shared/pkg/fc/models/snapshot_create_params.go b/packages/shared/pkg/fc/models/snapshot_create_params.go index 0f06bee0e9..d5aaeb286c 100644 --- a/packages/shared/pkg/fc/models/snapshot_create_params.go +++ b/packages/shared/pkg/fc/models/snapshot_create_params.go @@ -21,8 +21,7 @@ import ( type SnapshotCreateParams struct { // Path to the file that will contain the guest memory. - // Required: true - MemFilePath *string `json:"mem_file_path"` + MemFilePath string `json:"mem_file_path,omitempty"` // Path to the file that will contain the microVM state. 
// Required: true @@ -37,10 +36,6 @@ type SnapshotCreateParams struct { func (m *SnapshotCreateParams) Validate(formats strfmt.Registry) error { var res []error - if err := m.validateMemFilePath(formats); err != nil { - res = append(res, err) - } - if err := m.validateSnapshotPath(formats); err != nil { res = append(res, err) } @@ -55,15 +50,6 @@ func (m *SnapshotCreateParams) Validate(formats strfmt.Registry) error { return nil } -func (m *SnapshotCreateParams) validateMemFilePath(formats strfmt.Registry) error { - - if err := validate.Required("mem_file_path", "body", m.MemFilePath); err != nil { - return err - } - - return nil -} - func (m *SnapshotCreateParams) validateSnapshotPath(formats strfmt.Registry) error { if err := validate.Required("snapshot_path", "body", m.SnapshotPath); err != nil { diff --git a/packages/shared/pkg/fc/models/snapshot_load_params.go b/packages/shared/pkg/fc/models/snapshot_load_params.go index 4fa8f87470..48f6f44979 100644 --- a/packages/shared/pkg/fc/models/snapshot_load_params.go +++ b/packages/shared/pkg/fc/models/snapshot_load_params.go @@ -7,6 +7,7 @@ package models import ( "context" + "strconv" "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" @@ -28,6 +29,9 @@ type SnapshotLoadParams struct { // Path to the file that contains the guest memory to be loaded. It is only allowed if `mem_backend` is not present. This parameter has been deprecated and it will be removed in future Firecracker release. MemFilePath string `json:"mem_file_path,omitempty"` + // Network host device names to override + NetworkOverrides []*NetworkOverride `json:"network_overrides"` + // When set to true, the vm is also resumed if the snapshot load is successful. 
ResumeVM bool `json:"resume_vm,omitempty"` @@ -44,6 +48,10 @@ func (m *SnapshotLoadParams) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateNetworkOverrides(formats); err != nil { + res = append(res, err) + } + if err := m.validateSnapshotPath(formats); err != nil { res = append(res, err) } @@ -73,6 +81,32 @@ func (m *SnapshotLoadParams) validateMemBackend(formats strfmt.Registry) error { return nil } +func (m *SnapshotLoadParams) validateNetworkOverrides(formats strfmt.Registry) error { + if swag.IsZero(m.NetworkOverrides) { // not required + return nil + } + + for i := 0; i < len(m.NetworkOverrides); i++ { + if swag.IsZero(m.NetworkOverrides[i]) { // not required + continue + } + + if m.NetworkOverrides[i] != nil { + if err := m.NetworkOverrides[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("network_overrides" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("network_overrides" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + func (m *SnapshotLoadParams) validateSnapshotPath(formats strfmt.Registry) error { if err := validate.Required("snapshot_path", "body", m.SnapshotPath); err != nil { @@ -90,6 +124,10 @@ func (m *SnapshotLoadParams) ContextValidate(ctx context.Context, formats strfmt res = append(res, err) } + if err := m.contextValidateNetworkOverrides(ctx, formats); err != nil { + res = append(res, err) + } + if len(res) > 0 { return errors.CompositeValidationError(res...) 
} @@ -117,6 +155,31 @@ func (m *SnapshotLoadParams) contextValidateMemBackend(ctx context.Context, form return nil } +func (m *SnapshotLoadParams) contextValidateNetworkOverrides(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.NetworkOverrides); i++ { + + if m.NetworkOverrides[i] != nil { + + if swag.IsZero(m.NetworkOverrides[i]) { // not required + return nil + } + + if err := m.NetworkOverrides[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("network_overrides" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("network_overrides" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + // MarshalBinary interface implementation func (m *SnapshotLoadParams) MarshalBinary() ([]byte, error) { if m == nil { diff --git a/packages/shared/pkg/storage/temporary_memfile.go b/packages/shared/pkg/storage/temporary_memfile.go deleted file mode 100644 index a574f227e0..0000000000 --- a/packages/shared/pkg/storage/temporary_memfile.go +++ /dev/null @@ -1,61 +0,0 @@ -package storage - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sync" - - "github.com/google/uuid" - "golang.org/x/sync/semaphore" - - "github.com/e2b-dev/infra/packages/shared/pkg/env" - "github.com/e2b-dev/infra/packages/shared/pkg/utils" -) - -var maxParallelMemfileSnapshotting = utils.Must(env.GetEnvAsInt("MAX_PARALLEL_MEMFILE_SNAPSHOTTING", 8)) - -var snapshotCacheQueue = semaphore.NewWeighted(int64(maxParallelMemfileSnapshotting)) - -type TemporaryMemfile struct { - path string - closeFn func() -} - -func AcquireTmpMemfile( - ctx context.Context, - config BuilderConfig, - buildID string, -) (*TemporaryMemfile, error) { - randomID := uuid.NewString() - - err := snapshotCacheQueue.Acquire(ctx, 1) - if err != nil { - return nil, fmt.Errorf("failed to acquire cache: %w", err) - } - releaseOnce := sync.OnceFunc(func() { - 
snapshotCacheQueue.Release(1) - }) - - return &TemporaryMemfile{ - path: cacheMemfileFullSnapshotPath(config, buildID, randomID), - closeFn: releaseOnce, - }, nil -} - -func (f *TemporaryMemfile) Path() string { - return f.path -} - -func (f *TemporaryMemfile) Close() error { - defer f.closeFn() - - return os.Remove(f.path) -} - -func cacheMemfileFullSnapshotPath(config BuilderConfig, buildID string, randomID string) string { - name := fmt.Sprintf("%s-%s-%s.full", buildID, MemfileName, randomID) - - return filepath.Join(config.GetSnapshotCacheDir(), name) -} From 179cd6a9583d740caac585ba2da5b981158bed47 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 17:51:36 -0800 Subject: [PATCH 02/40] Cleanup --- packages/orchestrator/go.mod | 2 +- .../internal/sandbox/build/local_diff.go | 4 +- .../internal/sandbox/fc/memory.go | 4 ++ .../userfaultfd/cross_process_helpers_test.go | 2 +- .../internal/sandbox/uffd/userfaultfd/fd.go | 49 ++++--------------- .../uffd/userfaultfd/fd_helpers_test.go | 40 +++++++++++++++ .../sandbox/uffd/userfaultfd/userfaultfd.go | 14 +++--- 7 files changed, 64 insertions(+), 51 deletions(-) create mode 100644 packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go diff --git a/packages/orchestrator/go.mod b/packages/orchestrator/go.mod index 34ef990ce8..ae0edcc46d 100644 --- a/packages/orchestrator/go.mod +++ b/packages/orchestrator/go.mod @@ -44,6 +44,7 @@ require ( github.com/shirou/gopsutil/v4 v4.25.6 github.com/soheilhy/cmux v0.1.5 github.com/stretchr/testify v1.11.1 + github.com/tklauser/go-sysconf v0.3.14 github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 github.com/vishvananda/netns v0.0.5 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 @@ -221,7 +222,6 @@ require ( github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect - github.com/tklauser/go-sysconf v0.3.14 // 
indirect github.com/tklauser/numcpus v0.9.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/packages/orchestrator/internal/sandbox/build/local_diff.go b/packages/orchestrator/internal/sandbox/build/local_diff.go index 91757f024c..e98d9fccdc 100644 --- a/packages/orchestrator/internal/sandbox/build/local_diff.go +++ b/packages/orchestrator/internal/sandbox/build/local_diff.go @@ -86,7 +86,7 @@ var _ Diff = (*localDiff)(nil) func NewLocalDiffFromCache( cacheKey DiffStoreKey, cache *block.Cache, -) (*localDiff, error) { +) (Diff, error) { return &localDiff{ cache: cache, cacheKey: cacheKey, @@ -99,7 +99,7 @@ func newLocalDiff( size, blockSize int64, dirty bool, -) (*localDiff, error) { +) (Diff, error) { cache, err := block.NewCache(size, blockSize, cachePath, dirty) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index ef168c4719..c77f56f372 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -16,6 +16,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) +// IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. var IOV_MAX = utils.Must(getIOVMax()) const ( @@ -23,6 +24,9 @@ const ( oomMaxJitter = 100 * time.Millisecond ) +// MemoryInfo returns the memory info for the sandbox. +// The dirty field represents mincore resident pages—essentially pages that were faulted in. +// The empty field represents pages that are *resident*, but also completely empty. 
func (p *Process) MemoryInfo(ctx context.Context, blockSize int64) (*header.DiffMetadata, error) { return p.client.memoryInfo(ctx, blockSize) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index be12f71a1f..503a51bb66 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -50,7 +50,7 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error uffdFd.close() }) - err = uffdFd.configureApi(tt.pagesize) + err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 2d37277e0e..7fe3efe338 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -23,8 +23,6 @@ import ( "fmt" "syscall" "unsafe" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) const ( @@ -102,45 +100,16 @@ func getPagefaultAddress(pagefault *UffdPagefault) uintptr { return uintptr(pagefault.address) } -// uffdFd is a helper type that wraps uffd fd. 
-type uffdFd uintptr - -// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK -func newFd(flags uintptr) (uffdFd, error) { - uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) - if errno != 0 { - return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) - } - - return uffdFd(uffd), nil -} - -// features: UFFD_FEATURE_MISSING_HUGETLBFS -// This is already called by the FC -func (u uffdFd) configureApi(pagesize uint64) error { - var features CULong - - // Only set the hugepage feature if we're using hugepages - if pagesize == header.HugepageSize { - features |= UFFD_FEATURE_MISSING_HUGETLBFS - } - - api := newUffdioAPI(UFFD_API, features) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_API, uintptr(unsafe.Pointer(&api))) - if errno != 0 { - return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} +// Fd is a helper type that wraps uffd fd. +type Fd uintptr // mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING // This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING // We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { +func (f Fd) register(addr uintptr, size uint64, mode CULong) error { register := newUffdioRegister(CULong(addr), CULong(size), mode) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) if errno != 0 { return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) } @@ -150,21 +119,21 @@ func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { // mode: UFFDIO_COPY_MODE_WP // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page -func (u uffdFd) copy(addr uintptr, data []byte, pagesize 
uint64, mode CULong) error { +func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { cpy := newUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { return errno } // Check if the copied size matches the requested pagesize - if uint64(cpy.copy) != pagesize { + if cpy.copy != CLong(pagesize) { return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) } return nil } -func (u uffdFd) close() error { - return syscall.Close(int(u)) +func (f Fd) close() error { + return syscall.Close(int(f)) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go new file mode 100644 index 0000000000..17f1bf02be --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -0,0 +1,40 @@ +package userfaultfd + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// Used for testing. 
+// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newFd(flags uintptr) (Fd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return Fd(uffd), nil +} + +// Used for testing +// features: UFFD_FEATURE_MISSING_HUGETLBFS +// This is already called by the FC +func configureApi(f Fd, pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := newUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index a7cb017cc8..e248cfe581 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -23,7 +23,7 @@ const maxRequestsInProgress = 4096 var ErrUnexpectedEventType = errors.New("unexpected event type") type Userfaultfd struct { - fd uffdFd + fd Fd src block.Slicer ma *memory.Mapping @@ -50,7 +50,7 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge } u := &Userfaultfd{ - fd: uffdFd(fd), + fd: Fd(fd), src: src, missingRequests: block.NewTracker(blockSize), ma: m, @@ -187,7 +187,7 @@ outerLoop: // If the event has WRITE flag, it was a write to a missing page. // For the write to be executed, we first need to copy the page from the source to the guest memory. 
if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) + err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) if err != nil { return fmt.Errorf("failed to handle missing write: %w", err) } @@ -198,7 +198,7 @@ outerLoop: // Handle read to missing page ("MISSING" flag) // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. if flags == 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) + err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) if err != nil { return fmt.Errorf("failed to handle missing: %w", err) } @@ -214,9 +214,9 @@ outerLoop: func (u *Userfaultfd) handleMissing( ctx context.Context, onFailure func() error, - addr uintptr, + addr, + pagesize uintptr, offset int64, - pagesize uint64, ) error { u.wg.Go(func() error { // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, @@ -245,7 +245,7 @@ func (u *Userfaultfd) handleMissing( var copyMode CULong - copyErr := u.fd.copy(addr, b, pagesize, copyMode) + copyErr := u.fd.copy(addr, pagesize, b, copyMode) if errors.Is(copyErr, unix.EEXIST) { // Page is already mapped From 934c15148426da67b6fc601c28c8253d7b095fc2 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 17:53:54 -0800 Subject: [PATCH 03/40] Reduce changes --- packages/orchestrator/internal/sandbox/build/local_diff.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/build/local_diff.go b/packages/orchestrator/internal/sandbox/build/local_diff.go index e98d9fccdc..b1ecf9df7c 100644 --- a/packages/orchestrator/internal/sandbox/build/local_diff.go +++ b/packages/orchestrator/internal/sandbox/build/local_diff.go @@ -72,7 +72,6 @@ func (f *LocalDiffFile) CloseToDiff( f.cachePath, size.Size(), blockSize, - true, ) } @@ -98,9 +97,8 @@ func newLocalDiff( cachePath 
string, size, blockSize int64, - dirty bool, ) (Diff, error) { - cache, err := block.NewCache(size, blockSize, cachePath, dirty) + cache, err := block.NewCache(size, blockSize, cachePath, true) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) } From 2b944be4195eb2f8da6c343fb46b625c50e1af71 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 17:55:26 -0800 Subject: [PATCH 04/40] Clarify call --- packages/orchestrator/internal/sandbox/sandbox.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index ae72b4a2e7..ed72e1ffab 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -310,7 +310,7 @@ func (f *Factory) CreateSandbox( dirty := diffInfo.Dirty.Difference(diffInfo.Empty) - numberOfPages := header.BlockOffset(memfileSize, memfile.BlockSize()) + numberOfPages := header.TotalBlocks(memfileSize, memfile.BlockSize()) empty := bitset.New(uint(numberOfPages)) empty.FlipRange(0, uint(numberOfPages)) From 02ceff0e048873ff488c0db152192dc7b60d7e24 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 17:57:05 -0800 Subject: [PATCH 05/40] Clarify --- packages/orchestrator/internal/sandbox/sandbox.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index ed72e1ffab..3dfb709261 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -98,7 +98,7 @@ type Resources struct { memoryDiffFilter func(ctx context.Context) (*header.DiffMetadata, error) } -func (r *Resources) Dirty(ctx context.Context) (*header.DiffMetadata, error) { +func (r *Resources) MemfileDiffMetadata(ctx context.Context) (*header.DiffMetadata, error) { if r.memoryDiffFilter == nil { return 
nil, fmt.Errorf("memory diff filter is not set") } @@ -801,9 +801,9 @@ func (s *Sandbox) Pause( return nil, fmt.Errorf("failed to get original rootfs: %w", err) } - diffMetadata, err := s.Resources.Dirty(ctx) + memfileDiffMetadata, err := s.Resources.MemfileDiffMetadata(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dirty memory: %w", err) + return nil, fmt.Errorf("failed to get memfile metadata: %w", err) } // Start POSTPROCESSING @@ -811,7 +811,7 @@ func (s *Sandbox) Pause( ctx, buildID, originalMemfile.Header(), - diffMetadata, + memfileDiffMetadata, s.config.DefaultCacheDir, s.process, ) From 4ef608dc7ea1931aa74bb0dbb7c7bc14d71d023f Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:00:51 -0800 Subject: [PATCH 06/40] Add break --- packages/orchestrator/internal/sandbox/fc/memory.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index c77f56f372..146c311dba 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -135,6 +135,8 @@ func copyProcessMemory( } start += segmentSize + + break } } From 7e8ad280a0dfca5b5948c3f0aff51f5ffde485e7 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:03:20 -0800 Subject: [PATCH 07/40] Add explicit error --- packages/orchestrator/internal/sandbox/fc/memory.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 146c311dba..3bc02ff061 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -113,7 +113,7 @@ func copyProcessMemory( } // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. 
- _, err := unix.ProcessVMReadv(pid, + n, err := unix.ProcessVMReadv(pid, local, remote, 0, @@ -134,6 +134,10 @@ func copyProcessMemory( return fmt.Errorf("failed to read memory: %w", err) } + if uint64(n) != segmentSize { + return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) + } + start += segmentSize break From 03df796c88075eca1ea3e4df81cc06467a0a056e Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:04:11 -0800 Subject: [PATCH 08/40] Fix shadowing --- packages/orchestrator/internal/sandbox/fc/memory.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 3bc02ff061..5121724708 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -77,7 +77,7 @@ func copyProcessMemory( ctx context.Context, pid int, ranges []block.Range, - local *block.Cache, + cache *block.Cache, ) error { var start uint64 @@ -99,7 +99,7 @@ func copyProcessMemory( local := []unix.Iovec{ { - Base: local.Address(start), + Base: cache.Address(start), // We could keep this as full cache length, but we might as well be exact here. 
Len: segmentSize, }, From d551f5731260204aab0ec4b084bf6cdfdb7fde52 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:06:42 -0800 Subject: [PATCH 09/40] Fix range --- packages/orchestrator/internal/sandbox/block/range.go | 6 +++--- packages/orchestrator/internal/sandbox/fc/memory.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go index b871341ad0..da66c6769c 100644 --- a/packages/orchestrator/internal/sandbox/block/range.go +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -49,19 +49,19 @@ func NewRangeFromBlocks(startIdx, numberOfBlocks, blockSize int64) Range { } // bitsetRanges returns a sequence of the ranges of the set bits of the bitset. -func BitsetRanges(b *bitset.BitSet) iter.Seq[Range] { +func BitsetRanges(b *bitset.BitSet, blockSize int64) iter.Seq[Range] { return func(yield func(Range) bool) { start, ok := b.NextSet(0) for ok { end, endOk := b.NextClear(start) if !endOk { - yield(NewRange(int64(start), uint64(b.Len()-start))) + yield(NewRangeFromBlocks(int64(start), int64(b.Len()-start), blockSize)) return } - if !yield(NewRange(int64(start), uint64(end-start))) { + if !yield(NewRangeFromBlocks(int64(start), int64(end-start), blockSize)) { return } diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 5121724708..b178d7a247 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -44,7 +44,7 @@ func (p *Process) ExportMemory( var remoteRanges []block.Range - for r := range block.BitsetRanges(include) { + for r := range block.BitsetRanges(include, blockSize) { hostVirtRanges, err := m.GetHostVirtRanges(r.Start, int64(r.Size)) if err != nil { return nil, fmt.Errorf("failed to get host virt ranges: %w", err) From 01dfa38f24281a8f3cabea91098970c7974e5afb Mon 
Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:10:21 -0800 Subject: [PATCH 10/40] Improve export cleanup --- packages/orchestrator/internal/sandbox/fc/memory.go | 11 ++++++----- packages/orchestrator/internal/sandbox/sandbox.go | 13 +++++++------ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index b178d7a247..5b1b5f462a 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -55,19 +55,20 @@ func (p *Process) ExportMemory( size := block.GetSize(remoteRanges) - cache, err := block.NewCache(int64(size), blockSize, cachePath, false) + pid, err := p.Pid() if err != nil { - return nil, fmt.Errorf("failed to create cache: %w", err) + return nil, fmt.Errorf("failed to get pid: %w", err) } - pid, err := p.Pid() + cache, err := block.NewCache(int64(size), blockSize, cachePath, false) if err != nil { - return nil, fmt.Errorf("failed to get pid: %w", err) + return nil, fmt.Errorf("failed to create cache: %w", err) } err = copyProcessMemory(ctx, pid, remoteRanges, cache) if err != nil { - return nil, fmt.Errorf("failed to copy process memory: %w", err) + // Close the cache even if the copy fails. 
+ return nil, fmt.Errorf("failed to copy process memory: %w", errors.Join(err, cache.Close())) } return cache, nil diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 3dfb709261..db6cdc12c2 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -868,6 +868,11 @@ func pauseProcessMemory( memfileDiffPath := build.GenerateDiffCachePath(cacheDir, buildID.String(), build.Memfile) + header, err := diffMetadata.ToDiffHeader(ctx, originalHeader, buildID) + if err != nil { + return nil, nil, fmt.Errorf("failed to create memfile header: %w", err) + } + cache, err := fc.ExportMemory( ctx, diffMetadata.Dirty, @@ -883,12 +888,8 @@ func pauseProcessMemory( cache, ) if err != nil { - return nil, nil, fmt.Errorf("failed to create local diff from cache: %w", err) - } - - header, err := diffMetadata.ToDiffHeader(ctx, originalHeader, buildID) - if err != nil { - return nil, nil, fmt.Errorf("failed to create memfile header: %w", err) + // Close the cache even if the diff creation fails. 
+ return nil, nil, fmt.Errorf("failed to create local diff from cache: %w", errors.Join(err, cache.Close())) } return diff, header, nil From fe86cbea419ba3606f44479a503e25638bb83099 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 9 Dec 2025 18:33:51 -0800 Subject: [PATCH 11/40] Refactor --- .../internal/sandbox/block/cache.go | 238 ++++++++++++------ .../internal/sandbox/block/range.go | 10 +- .../internal/sandbox/fc/client.go | 14 +- .../internal/sandbox/fc/memory.go | 98 +------- 4 files changed, 179 insertions(+), 181 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 1c8ae893a0..290bc7bdfc 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -6,21 +6,33 @@ import ( "fmt" "io" "math" + "math/rand" "os" "sort" "sync" "sync/atomic" "syscall" + "time" "github.com/edsrzf/mmap-go" + "github.com/tklauser/go-sysconf" "go.opentelemetry.io/otel" "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +const ( + oomMinBackoff = 100 * time.Millisecond + oomMaxJitter = 100 * time.Millisecond ) var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block") +// IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. 
+var IOV_MAX = utils.Must(getIOVMax()) + type CacheClosedError struct { filePath string } @@ -79,19 +91,19 @@ func NewCache(size, blockSize int64, filePath string, dirtyFile bool) (*Cache, e }, nil } -func (m *Cache) isClosed() bool { - return m.closed.Load() +func (c *Cache) isClosed() bool { + return c.closed.Load() } -func (m *Cache) Sync() error { - m.mu.Lock() - defer m.mu.Unlock() +func (c *Cache) Sync() error { + c.mu.Lock() + defer c.mu.Unlock() - if m.isClosed() { - return NewErrCacheClosed(m.filePath) + if c.isClosed() { + return NewErrCacheClosed(c.filePath) } - err := m.mmap.Flush() + err := c.mmap.Flush() if err != nil { return fmt.Errorf("error syncing cache: %w", err) } @@ -99,26 +111,26 @@ func (m *Cache) Sync() error { return nil } -func (m *Cache) ExportToDiff(ctx context.Context, out io.Writer) (*header.DiffMetadata, error) { +func (c *Cache) ExportToDiff(ctx context.Context, out io.Writer) (*header.DiffMetadata, error) { ctx, childSpan := tracer.Start(ctx, "export-to-diff") defer childSpan.End() - m.mu.Lock() - defer m.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - if m.isClosed() { - return nil, NewErrCacheClosed(m.filePath) + if c.isClosed() { + return nil, NewErrCacheClosed(c.filePath) } - err := m.mmap.Flush() + err := c.mmap.Flush() if err != nil { return nil, fmt.Errorf("error flushing mmap: %w", err) } - builder := header.NewDiffMetadataBuilder(m.size, m.blockSize) + builder := header.NewDiffMetadataBuilder(c.size, c.blockSize) - for _, offset := range m.dirtySortedKeys() { - block := (*m.mmap)[offset : offset+m.blockSize] + for _, offset := range c.dirtySortedKeys() { + block := (*c.mmap)[offset : offset+c.blockSize] err := builder.Process(ctx, block, out, offset) if err != nil { @@ -129,15 +141,15 @@ func (m *Cache) ExportToDiff(ctx context.Context, out io.Writer) (*header.DiffMe return builder.Build(), nil } -func (m *Cache) ReadAt(b []byte, off int64) (int, error) { - m.mu.RLock() - defer m.mu.RUnlock() +func (c *Cache) ReadAt(b 
[]byte, off int64) (int, error) { + c.mu.RLock() + defer c.mu.RUnlock() - if m.isClosed() { - return 0, NewErrCacheClosed(m.filePath) + if c.isClosed() { + return 0, NewErrCacheClosed(c.filePath) } - slice, err := m.Slice(off, int64(len(b))) + slice, err := c.Slice(off, int64(len(b))) if err != nil { return 0, fmt.Errorf("error slicing mmap: %w", err) } @@ -145,67 +157,67 @@ func (m *Cache) ReadAt(b []byte, off int64) (int, error) { return copy(b, slice), nil } -func (m *Cache) WriteAt(b []byte, off int64) (int, error) { - m.mu.Lock() - defer m.mu.Unlock() +func (c *Cache) WriteAt(b []byte, off int64) (int, error) { + c.mu.Lock() + defer c.mu.Unlock() - if m.isClosed() { - return 0, NewErrCacheClosed(m.filePath) + if c.isClosed() { + return 0, NewErrCacheClosed(c.filePath) } - return m.WriteAtWithoutLock(b, off) + return c.WriteAtWithoutLock(b, off) } -func (m *Cache) Close() (e error) { - m.mu.Lock() - defer m.mu.Unlock() +func (c *Cache) Close() (e error) { + c.mu.Lock() + defer c.mu.Unlock() - succ := m.closed.CompareAndSwap(false, true) + succ := c.closed.CompareAndSwap(false, true) if !succ { - return NewErrCacheClosed(m.filePath) + return NewErrCacheClosed(c.filePath) } - err := m.mmap.Unmap() + err := c.mmap.Unmap() if err != nil { e = errors.Join(e, fmt.Errorf("error unmapping mmap: %w", err)) } // TODO: Move to to the scope of the caller - e = errors.Join(e, os.RemoveAll(m.filePath)) + e = errors.Join(e, os.RemoveAll(c.filePath)) return e } -func (m *Cache) Size() (int64, error) { - if m.isClosed() { - return 0, NewErrCacheClosed(m.filePath) +func (c *Cache) Size() (int64, error) { + if c.isClosed() { + return 0, NewErrCacheClosed(c.filePath) } - return m.size, nil + return c.size, nil } // Slice returns a slice of the mmap. // When using Slice you must ensure thread safety, ideally by only writing to the same block once and the exposing the slice. 
-func (m *Cache) Slice(off, length int64) ([]byte, error) { - if m.isClosed() { - return nil, NewErrCacheClosed(m.filePath) +func (c *Cache) Slice(off, length int64) ([]byte, error) { + if c.isClosed() { + return nil, NewErrCacheClosed(c.filePath) } - if m.dirtyFile || m.isCached(off, length) { + if c.dirtyFile || c.isCached(off, length) { end := off + length - if end > m.size { - end = m.size + if end > c.size { + end = c.size } - return (*m.mmap)[off:end], nil + return (*c.mmap)[off:end], nil } return nil, BytesNotAvailableError{} } -func (m *Cache) isCached(off, length int64) bool { - for _, blockOff := range header.BlocksOffsets(length, m.blockSize) { - _, dirty := m.dirty.Load(off + blockOff) +func (c *Cache) isCached(off, length int64) bool { + for _, blockOff := range header.BlocksOffsets(length, c.blockSize) { + _, dirty := c.dirty.Load(off + blockOff) if !dirty { return false } @@ -214,35 +226,35 @@ func (m *Cache) isCached(off, length int64) bool { return true } -func (m *Cache) setIsCached(off, length int64) { - for _, blockOff := range header.BlocksOffsets(length, m.blockSize) { - m.dirty.Store(off+blockOff, struct{}{}) +func (c *Cache) setIsCached(off, length int64) { + for _, blockOff := range header.BlocksOffsets(length, c.blockSize) { + c.dirty.Store(off+blockOff, struct{}{}) } } // When using WriteAtWithoutLock you must ensure thread safety, ideally by only writing to the same block once and the exposing the slice. 
-func (m *Cache) WriteAtWithoutLock(b []byte, off int64) (int, error) { - if m.isClosed() { - return 0, NewErrCacheClosed(m.filePath) +func (c *Cache) WriteAtWithoutLock(b []byte, off int64) (int, error) { + if c.isClosed() { + return 0, NewErrCacheClosed(c.filePath) } end := off + int64(len(b)) - if end > m.size { - end = m.size + if end > c.size { + end = c.size } - n := copy((*m.mmap)[off:end], b) + n := copy((*c.mmap)[off:end], b) - m.setIsCached(off, end-off) + c.setIsCached(off, end-off) return n, nil } // dirtySortedKeys returns a sorted list of dirty keys. // Key represents a block offset. -func (m *Cache) dirtySortedKeys() []int64 { +func (c *Cache) dirtySortedKeys() []int64 { var keys []int64 - m.dirty.Range(func(key, _ any) bool { + c.dirty.Range(func(key, _ any) bool { keys = append(keys, key.(int64)) return true @@ -256,30 +268,112 @@ func (m *Cache) dirtySortedKeys() []int64 { // FileSize returns the size of the cache on disk. // The size might differ from the dirty size, as it may not be fully on disk. 
-func (m *Cache) FileSize() (int64, error) { +func (c *Cache) FileSize() (int64, error) { var stat syscall.Stat_t - err := syscall.Stat(m.filePath, &stat) + err := syscall.Stat(c.filePath, &stat) if err != nil { return 0, fmt.Errorf("failed to get file stats: %w", err) } var fsStat syscall.Statfs_t - err = syscall.Statfs(m.filePath, &fsStat) + err = syscall.Statfs(c.filePath, &fsStat) if err != nil { - return 0, fmt.Errorf("failed to get disk stats for path %s: %w", m.filePath, err) + return 0, fmt.Errorf("failed to get disk stats for path %s: %w", c.filePath, err) } return stat.Blocks * fsStat.Bsize, nil } -func (m *Cache) Address(off uint64) *byte { - return &(*m.mmap)[off] +func (c *Cache) Address(off uint64) *byte { + return &(*c.mmap)[off] } -func (m *Cache) BlockSize() int64 { - return m.blockSize +func (c *Cache) BlockSize() int64 { + return c.blockSize } -func (m *Cache) Path() string { - return m.filePath +func (c *Cache) Path() string { + return c.filePath +} + +func (c *Cache) CopyFromProcess( + ctx context.Context, + pid int, + ranges []Range, +) error { + var start uint64 + + for i := 0; i < len(ranges); i += int(IOV_MAX) { + segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] + + remote := make([]unix.RemoteIovec, len(segmentRanges)) + + var segmentSize uint64 + + for j, r := range segmentRanges { + remote[j] = unix.RemoteIovec{ + Base: uintptr(r.Start), + Len: int(r.Size), + } + + segmentSize += r.Size + } + + local := []unix.Iovec{ + { + Base: c.Address(start), + // We could keep this as full cache length, but we might as well be exact here. + Len: segmentSize, + }, + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. 
+ n, err := unix.ProcessVMReadv(pid, + local, + remote, + 0, + ) + if errors.Is(err, unix.EAGAIN) { + continue + } + if errors.Is(err, unix.EINTR) { + continue + } + if errors.Is(err, unix.ENOMEM) { + time.Sleep(oomMinBackoff + time.Duration(rand.Intn(int(oomMaxJitter.Milliseconds())))*time.Millisecond) + + continue + } + + if err != nil { + return fmt.Errorf("failed to read memory: %w", err) + } + + if uint64(n) != segmentSize { + return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) + } + + start += segmentSize + + break + } + } + + return nil +} + +func getIOVMax() (int64, error) { + iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) + if err != nil { + return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) + } + + return iovMax, nil } diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go index da66c6769c..52e21f55f3 100644 --- a/packages/orchestrator/internal/sandbox/block/range.go +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -24,8 +24,8 @@ func (r *Range) End() int64 { // This assumes the Range.Start is a multiple of the blockSize. func (r *Range) Offsets(blockSize int64) iter.Seq[int64] { return func(yield func(offset int64) bool) { - for i := r.Start; i < r.End(); i += blockSize { - if !yield(i) { + for off := r.Start; off < r.End(); off += blockSize { + if !yield(off) { return } } @@ -51,9 +51,9 @@ func NewRangeFromBlocks(startIdx, numberOfBlocks, blockSize int64) Range { // bitsetRanges returns a sequence of the ranges of the set bits of the bitset. 
func BitsetRanges(b *bitset.BitSet, blockSize int64) iter.Seq[Range] { return func(yield func(Range) bool) { - start, ok := b.NextSet(0) + start, found := b.NextSet(0) - for ok { + for found { end, endOk := b.NextClear(start) if !endOk { yield(NewRangeFromBlocks(int64(start), int64(b.Len()-start), blockSize)) @@ -65,7 +65,7 @@ func BitsetRanges(b *bitset.BitSet, blockSize int64) iter.Seq[Range] { return } - start, ok = b.NextSet(end + 1) + start, found = b.NextSet(end + 1) } } } diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index 1396980d77..fdec37f928 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -304,31 +304,31 @@ func (c *apiClient) startVM(ctx context.Context) error { } func (c *apiClient) memoryMappings(ctx context.Context) (*memory.Mapping, error) { - memoryMappingsParams := operations.GetMemoryMappingsParams{ + params := operations.GetMemoryMappingsParams{ Context: ctx, } - memoryMappings, err := c.client.Operations.GetMemoryMappings(&memoryMappingsParams) + res, err := c.client.Operations.GetMemoryMappings(¶ms) if err != nil { return nil, fmt.Errorf("error getting memory mappings: %w", err) } - return memory.NewMappingFromFc(memoryMappings.Payload.Mappings) + return memory.NewMappingFromFc(res.Payload.Mappings) } func (c *apiClient) memoryInfo(ctx context.Context, blockSize int64) (*header.DiffMetadata, error) { - memoryParams := operations.GetMemoryParams{ + params := operations.GetMemoryParams{ Context: ctx, } - memoryInfo, err := c.client.Operations.GetMemory(&memoryParams) + res, err := c.client.Operations.GetMemory(¶ms) if err != nil { return nil, fmt.Errorf("error getting memory: %w", err) } return &header.DiffMetadata{ - Dirty: bitset.From(memoryInfo.Payload.Resident), - Empty: bitset.From(memoryInfo.Payload.Empty), + Dirty: bitset.From(res.Payload.Resident), + Empty: bitset.From(res.Payload.Empty), 
BlockSize: blockSize, }, nil } diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 5b1b5f462a..e4f7dbead1 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -4,24 +4,11 @@ import ( "context" "errors" "fmt" - "math/rand" - "time" "github.com/bits-and-blooms/bitset" - "github.com/tklauser/go-sysconf" - "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" - "github.com/e2b-dev/infra/packages/shared/pkg/utils" -) - -// IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. -var IOV_MAX = utils.Must(getIOVMax()) - -const ( - oomMinBackoff = 100 * time.Millisecond - oomMaxJitter = 100 * time.Millisecond ) // MemoryInfo returns the memory info for the sandbox. @@ -65,7 +52,7 @@ func (p *Process) ExportMemory( return nil, fmt.Errorf("failed to create cache: %w", err) } - err = copyProcessMemory(ctx, pid, remoteRanges, cache) + err = cache.CopyFromProcess(ctx, pid, remoteRanges) if err != nil { // Close the cache even if the copy fails. 
return nil, fmt.Errorf("failed to copy process memory: %w", errors.Join(err, cache.Close())) @@ -73,86 +60,3 @@ func (p *Process) ExportMemory( return cache, nil } - -func copyProcessMemory( - ctx context.Context, - pid int, - ranges []block.Range, - cache *block.Cache, -) error { - var start uint64 - - for i := 0; i < len(ranges); i += int(IOV_MAX) { - segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] - - remote := make([]unix.RemoteIovec, len(segmentRanges)) - - var segmentSize uint64 - - for j, r := range segmentRanges { - remote[j] = unix.RemoteIovec{ - Base: uintptr(r.Start), - Len: int(r.Size), - } - - segmentSize += r.Size - } - - local := []unix.Iovec{ - { - Base: cache.Address(start), - // We could keep this as full cache length, but we might as well be exact here. - Len: segmentSize, - }, - } - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. 
- n, err := unix.ProcessVMReadv(pid, - local, - remote, - 0, - ) - if errors.Is(err, unix.EAGAIN) { - continue - } - if errors.Is(err, unix.EINTR) { - continue - } - if errors.Is(err, unix.ENOMEM) { - time.Sleep(oomMinBackoff + time.Duration(rand.Intn(int(oomMaxJitter.Milliseconds())))*time.Millisecond) - - continue - } - - if err != nil { - return fmt.Errorf("failed to read memory: %w", err) - } - - if uint64(n) != segmentSize { - return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) - } - - start += segmentSize - - break - } - } - - return nil -} - -func getIOVMax() (int64, error) { - iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) - if err != nil { - return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) - } - - return iovMax, nil -} From 9d7bdd60809080b60f806eb98eaf1daecd95463f Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Wed, 10 Dec 2025 12:45:54 -0800 Subject: [PATCH 12/40] Cleanup --- packages/orchestrator/internal/sandbox/fc/client.go | 2 +- packages/orchestrator/internal/sandbox/fc/memory.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index fdec37f928..fa800b7e5e 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -303,7 +303,7 @@ func (c *apiClient) startVM(ctx context.Context) error { return nil } -func (c *apiClient) memoryMappings(ctx context.Context) (*memory.Mapping, error) { +func (c *apiClient) memoryMapping(ctx context.Context) (*memory.Mapping, error) { params := operations.GetMemoryMappingsParams{ Context: ctx, } diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index e4f7dbead1..e1536b2974 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -24,7 +24,7 @@ func 
(p *Process) ExportMemory( cachePath string, blockSize int64, ) (*block.Cache, error) { - m, err := p.client.memoryMappings(ctx) + m, err := p.client.memoryMapping(ctx) if err != nil { return nil, fmt.Errorf("failed to get memory mappings: %w", err) } @@ -40,13 +40,13 @@ func (p *Process) ExportMemory( remoteRanges = append(remoteRanges, hostVirtRanges...) } - size := block.GetSize(remoteRanges) - pid, err := p.Pid() if err != nil { return nil, fmt.Errorf("failed to get pid: %w", err) } + size := block.GetSize(remoteRanges) + cache, err := block.NewCache(int64(size), blockSize, cachePath, false) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) From 2081e9b659ea184c03f27575309ddb7ab862b8fd Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 16 Dec 2025 15:41:03 -0800 Subject: [PATCH 13/40] WIP --- packages/api/internal/cfg/model.go | 2 +- .../internal/sandbox/block/cache.go | 82 ++- .../internal/sandbox/block/iov.go | 21 + .../internal/sandbox/block/range.go | 16 +- .../internal/sandbox/block/range_test.go | 475 ++++++++++++++++++ .../internal/sandbox/block/tracker.go | 22 + .../internal/sandbox/fc/client.go | 1 + .../internal/sandbox/fc/memory.go | 8 +- .../uffd/memory/cache_copyfromprocess_test.go | 459 +++++++++++++++++ .../internal/sandbox/uffd/memory/mapping.go | 13 +- .../uffd/memory/mapping_host_virt_test.go | 187 ++++--- 11 files changed, 1185 insertions(+), 101 deletions(-) create mode 100644 packages/orchestrator/internal/sandbox/block/iov.go create mode 100644 packages/orchestrator/internal/sandbox/block/range_test.go create mode 100644 packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go diff --git a/packages/api/internal/cfg/model.go b/packages/api/internal/cfg/model.go index 59745406fd..61a7eee988 100644 --- a/packages/api/internal/cfg/model.go +++ b/packages/api/internal/cfg/model.go @@ -6,7 +6,7 @@ const ( DefaultKernelVersion = "vmlinux-6.1.158" // The Firecracker version the last tag + the 
short SHA (so we can build our dev previews) // TODO: The short tag here has only 7 characters — the one from our build pipeline will likely have exactly 8 so this will break. - DefaultFirecrackerVersion = "v1.12.2_g1133bd6cd" + DefaultFirecrackerVersion = "v1.12.2_ga3608adc9" ) type Config struct { diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 290bc7bdfc..d51ca2fb81 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -15,12 +15,10 @@ import ( "time" "github.com/edsrzf/mmap-go" - "github.com/tklauser/go-sysconf" "go.opentelemetry.io/otel" "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" - "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) const ( @@ -30,9 +28,6 @@ const ( var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block") -// IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. -var IOV_MAX = utils.Must(getIOVMax()) - type CacheClosedError struct { filePath string } @@ -369,11 +364,78 @@ func (c *Cache) CopyFromProcess( return nil } -func getIOVMax() (int64, error) { - iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) - if err != nil { - return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) +func (c *Cache) CopyProcessMemory( + ctx context.Context, + pid int, + ranges []Range, +) error { + var start int64 + + for i := 0; i < len(ranges); i += int(IOV_MAX) { + segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] + + remote := make([]unix.RemoteIovec, len(segmentRanges)) + + var segmentSize int64 + + for j, r := range segmentRanges { + remote[j] = unix.RemoteIovec{ + Base: uintptr(r.Start), + Len: int(r.Size), + } + + segmentSize += r.Size + } + + local := []unix.Iovec{ + { + Base: c.Address(start), + // We could keep this as full cache length, but we might as well be exact here. 
+ Len: uint64(segmentSize), + }, + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. + n, err := unix.ProcessVMReadv(pid, + local, + remote, + 0, + ) + if errors.Is(err, unix.EAGAIN) { + continue + } + if errors.Is(err, unix.EINTR) { + continue + } + if errors.Is(err, unix.ENOMEM) { + time.Sleep(oomMinBackoff + time.Duration(rand.Intn(int(oomMaxJitter.Milliseconds())))*time.Millisecond) + + continue + } + + if err != nil { + return fmt.Errorf("failed to read memory: %w", err) + } + + if int64(n) != segmentSize { + return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) + } + + // Mark the copied data as cached so it can be read via Slice/ReadAt + c.dirty.AddOffsets(start, segmentSize) + + start += segmentSize + + break + } } - return iovMax, nil + return nil } diff --git a/packages/orchestrator/internal/sandbox/block/iov.go b/packages/orchestrator/internal/sandbox/block/iov.go new file mode 100644 index 0000000000..49e95837f3 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/iov.go @@ -0,0 +1,21 @@ +package block + +import ( + "fmt" + + "github.com/tklauser/go-sysconf" + + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +// IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. +var IOV_MAX = utils.Must(getIOVMax()) + +func getIOVMax() (int64, error) { + iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) + if err != nil { + return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) + } + + return iovMax, nil +} diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go index 52e21f55f3..53eb8c41f1 100644 --- a/packages/orchestrator/internal/sandbox/block/range.go +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -13,7 +13,7 @@ type Range struct { // Start is inclusive. 
Start int64 // Size is the size of the range in bytes. - Size uint64 + Size int64 } func (r *Range) End() int64 { @@ -24,7 +24,13 @@ func (r *Range) End() int64 { // This assumes the Range.Start is a multiple of the blockSize. func (r *Range) Offsets(blockSize int64) iter.Seq[int64] { return func(yield func(offset int64) bool) { - for off := r.Start; off < r.End(); off += blockSize { + getOffsets(r.Start, r.End(), blockSize)(yield) + } +} + +func getOffsets(start, end int64, blockSize int64) iter.Seq[int64] { + return func(yield func(offset int64) bool) { + for off := start; off < end; off += blockSize { if !yield(off) { return } @@ -33,7 +39,7 @@ func (r *Range) Offsets(blockSize int64) iter.Seq[int64] { } // NewRange creates a new range from a start address and size in bytes. -func NewRange(start int64, size uint64) Range { +func NewRange(start int64, size int64) Range { return Range{ Start: start, Size: size, @@ -44,7 +50,7 @@ func NewRange(start int64, size uint64) Range { func NewRangeFromBlocks(startIdx, numberOfBlocks, blockSize int64) Range { return Range{ Start: header.BlockOffset(startIdx, blockSize), - Size: uint64(header.BlockOffset(numberOfBlocks, blockSize)), + Size: header.BlockOffset(numberOfBlocks, blockSize), } } @@ -70,7 +76,7 @@ func BitsetRanges(b *bitset.BitSet, blockSize int64) iter.Seq[Range] { } } -func GetSize(rs []Range) (size uint64) { +func GetSize(rs []Range) (size int64) { for _, r := range rs { size += r.Size } diff --git a/packages/orchestrator/internal/sandbox/block/range_test.go b/packages/orchestrator/internal/sandbox/block/range_test.go new file mode 100644 index 0000000000..3de26f6df0 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/range_test.go @@ -0,0 +1,475 @@ +package block + +import ( + "slices" + "testing" + + "github.com/bits-and-blooms/bitset" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRange_End(t *testing.T) { + tests := []struct { + name string + 
start    int64
+		size     int64
+		expected int64
+	}{
+		{
+			name:     "zero size",
+			start:    100,
+			size:     0,
+			expected: 100,
+		},
+		{
+			name:     "single byte",
+			start:    0,
+			size:     1,
+			expected: 1,
+		},
+		{
+			name:     "multiple bytes",
+			start:    10,
+			size:     20,
+			expected: 30,
+		},
+		{
+			name:     "large size",
+			start:    0,
+			size:     1024 * 1024,
+			expected: 1024 * 1024,
+		},
+		{
+			name:     "negative start",
+			start:    -100,
+			size:     50,
+			expected: -50,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			r := Range{
+				Start: tt.start,
+				Size:  tt.size,
+			}
+			assert.Equal(t, tt.expected, r.End())
+		})
+	}
+}
+
+func TestNewRange(t *testing.T) {
+	tests := []struct {
+		name     string
+		start    int64
+		size     int64
+		expected Range
+	}{
+		{
+			name:  "basic range",
+			start: 0,
+			size:  4096,
+			expected: Range{
+				Start: 0,
+				Size:  4096,
+			},
+		},
+		{
+			name:  "non-zero start",
+			start: 8192,
+			size:  2048,
+			expected: Range{
+				Start: 8192,
+				Size:  2048,
+			},
+		},
+		{
+			name:  "zero size",
+			start: 100,
+			size:  0,
+			expected: Range{
+				Start: 100,
+				Size:  0,
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			r := NewRange(tt.start, tt.size)
+			assert.Equal(t, tt.expected, r)
+		})
+	}
+}
+
+func TestNewRangeFromBlocks(t *testing.T) {
+	tests := []struct {
+		name           string
+		startIdx       int64
+		numberOfBlocks int64
+		blockSize      int64
+		expected       Range
+	}{
+		{
+			name:           "single block at start",
+			startIdx:       0,
+			numberOfBlocks: 1,
+			blockSize:      4096,
+			expected: Range{
+				Start: 0,
+				Size:  4096,
+			},
+		},
+		{
+			name:           "multiple blocks",
+			startIdx:       0,
+			numberOfBlocks: 3,
+			blockSize:      4096,
+			expected: Range{
+				Start: 0,
+				Size:  12288, // 3 * 4096
+			},
+		},
+		{
+			name:           "blocks starting at non-zero index",
+			startIdx:       5,
+			numberOfBlocks: 2,
+			blockSize:      4096,
+			expected: Range{
+				Start: 20480, // 5 * 4096
+				Size:  8192,  // 2 * 4096
+			},
+		},
+		{
+			name:           "zero blocks",
+			startIdx:       10,
+			numberOfBlocks: 0,
+			blockSize:      4096,
+			expected: Range{
+				Start: 40960, // 10 * 
4096 + Size: 0, + }, + }, + { + name: "different block size", + startIdx: 0, + numberOfBlocks: 4, + blockSize: 8192, + expected: Range{ + Start: 0, + Size: 32768, // 4 * 8192 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewRangeFromBlocks(tt.startIdx, tt.numberOfBlocks, tt.blockSize) + assert.Equal(t, tt.expected, r) + }) + } +} + +func TestRange_Offsets(t *testing.T) { + tests := []struct { + name string + range_ Range + blockSize int64 + expected []int64 + }{ + { + name: "single block", + range_: Range{ + Start: 0, + Size: 4096, + }, + blockSize: 4096, + expected: []int64{0}, + }, + { + name: "multiple blocks", + range_: Range{ + Start: 0, + Size: 12288, // 3 * 4096 + }, + blockSize: 4096, + expected: []int64{0, 4096, 8192}, + }, + { + name: "non-zero start", + range_: Range{ + Start: 8192, + Size: 8192, // 2 * 4096 + }, + blockSize: 4096, + expected: []int64{8192, 12288}, + }, + { + name: "zero size", + range_: Range{ + Start: 4096, + Size: 0, + }, + blockSize: 4096, + expected: []int64{}, + }, + { + name: "smaller than block size", + range_: Range{ + Start: 0, + Size: 1024, + }, + blockSize: 4096, + expected: []int64{0}, + }, + { + name: "different block size", + range_: Range{ + Start: 0, + Size: 16384, // 4 * 4096 + }, + blockSize: 8192, + expected: []int64{0, 8192}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offsets := slices.Collect(tt.range_.Offsets(tt.blockSize)) + if len(tt.expected) == 0 { + assert.Empty(t, offsets) + } else { + assert.Equal(t, tt.expected, offsets) + } + }) + } +} + +func TestRange_Offsets_Iteration(t *testing.T) { + // Test that iteration can be stopped early + r := Range{ + Start: 0, + Size: 40960, // 10 * 4096 + } + blockSize := int64(4096) + + var collected []int64 + for offset := range r.Offsets(blockSize) { + collected = append(collected, offset) + if len(collected) >= 3 { + break + } + } + + assert.Len(t, collected, 3) + assert.Equal(t, []int64{0, 
4096, 8192}, collected) +} + +func TestBitsetRanges_Empty(t *testing.T) { + b := bitset.New(100) + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + assert.Empty(t, ranges) +} + +func TestBitsetRanges_SingleBit(t *testing.T) { + b := bitset.New(100) + b.Set(5) + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + require.Len(t, ranges, 1) + assert.Equal(t, Range{ + Start: 20480, // 5 * 4096 + Size: 4096, + }, ranges[0]) +} + +func TestBitsetRanges_Contiguous(t *testing.T) { + b := bitset.New(100) + // Set bits 2, 3, 4, 5 + b.Set(2) + b.Set(3) + b.Set(4) + b.Set(5) + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + require.Len(t, ranges, 1) + assert.Equal(t, Range{ + Start: 8192, // 2 * 4096 + Size: 16384, // 4 * 4096 + }, ranges[0]) +} + +func TestBitsetRanges_MultipleRanges(t *testing.T) { + b := bitset.New(100) + // Set bits 1, 2, 3 (contiguous) + b.Set(1) + b.Set(2) + b.Set(3) + // Gap + // Set bits 7, 8 (contiguous) + b.Set(7) + b.Set(8) + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + require.Len(t, ranges, 2) + assert.Equal(t, Range{ + Start: 4096, // 1 * 4096 + Size: 12288, // 3 * 4096 + }, ranges[0]) + assert.Equal(t, Range{ + Start: 28672, // 7 * 4096 + Size: 8192, // 2 * 4096 + }, ranges[1]) +} + +func TestBitsetRanges_AllSet(t *testing.T) { + b := bitset.New(10) + for i := uint(0); i < 10; i++ { + b.Set(i) + } + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + require.Len(t, ranges, 1) + assert.Equal(t, Range{ + Start: 0, + Size: 40960, // 10 * 4096 + }, ranges[0]) +} + +func TestBitsetRanges_EndOfBitset(t *testing.T) { + b := bitset.New(20) + // Set bits 15, 16, 17, 18, 19 (at the end) + for i := uint(15); i < 20; i++ { + b.Set(i) + } + blockSize := int64(4096) + + ranges := slices.Collect(BitsetRanges(b, blockSize)) + require.Len(t, ranges, 1) + assert.Equal(t, Range{ + Start: 
61440, // 15 * 4096
+		Size:  20480, // 5 * 4096
+	}, ranges[0])
+}
+
+func TestBitsetRanges_Sparse(t *testing.T) {
+	b := bitset.New(100)
+	// Set individual bits with gaps
+	b.Set(0)
+	b.Set(10)
+	b.Set(20)
+	b.Set(30)
+	blockSize := int64(4096)
+
+	ranges := slices.Collect(BitsetRanges(b, blockSize))
+	require.Len(t, ranges, 4)
+	assert.Equal(t, Range{Start: 0, Size: 4096}, ranges[0])
+	assert.Equal(t, Range{Start: 40960, Size: 4096}, ranges[1])
+	assert.Equal(t, Range{Start: 81920, Size: 4096}, ranges[2])
+	assert.Equal(t, Range{Start: 122880, Size: 4096}, ranges[3])
+}
+
+func TestGetSize(t *testing.T) {
+	tests := []struct {
+		name     string
+		ranges   []Range
+		expected int64
+	}{
+		{
+			name:     "empty",
+			ranges:   []Range{},
+			expected: 0,
+		},
+		{
+			name: "single range",
+			ranges: []Range{
+				{Start: 0, Size: 4096},
+			},
+			expected: 4096,
+		},
+		{
+			name: "multiple ranges",
+			ranges: []Range{
+				{Start: 0, Size: 4096},
+				{Start: 8192, Size: 8192},
+				{Start: 16384, Size: 4096},
+			},
+			expected: 16384, // 4096 + 8192 + 4096
+		},
+		{
+			name: "zero size ranges",
+			ranges: []Range{
+				{Start: 0, Size: 0},
+				{Start: 4096, Size: 4096},
+				{Start: 8192, Size: 0},
+			},
+			expected: 4096,
+		},
+		{
+			name: "large sizes",
+			ranges: []Range{
+				{Start: 0, Size: 1024 * 1024},
+				{Start: 1024 * 1024, Size: 2 * 1024 * 1024},
+			},
+			expected: 3 * 1024 * 1024,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			size := GetSize(tt.ranges)
+			assert.Equal(t, tt.expected, size)
+		})
+	}
+}
+
+func TestRange_Offsets_EdgeCases(t *testing.T) {
+	tests := []struct {
+		name      string
+		range_    Range
+		blockSize int64
+		expected  []int64
+	}{
+		{
+			name: "exact block boundary end",
+			range_: Range{
+				Start: 0,
+				Size:  12288, // exactly 3 blocks
+			},
+			blockSize: 4096,
+			expected:  []int64{0, 4096, 8192},
+		},
+		{
+			name: "one byte over block boundary",
+			range_: Range{
+				Start: 0,
+				Size:  12289, // 3 blocks + 1 byte
+			},
+			blockSize: 4096,
+			expected:  []int64{0, 4096, 8192, 
12288}, + }, + { + name: "one byte less than block boundary", + range_: Range{ + Start: 0, + Size: 12287, // 3 blocks - 1 byte + }, + blockSize: 4096, + expected: []int64{0, 4096, 8192}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offsets := slices.Collect(tt.range_.Offsets(tt.blockSize)) + assert.Equal(t, tt.expected, offsets) + }) + } +} diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go index dc0c74e853..e0408acd0f 100644 --- a/packages/orchestrator/internal/sandbox/block/tracker.go +++ b/packages/orchestrator/internal/sandbox/block/tracker.go @@ -39,6 +39,19 @@ func (t *Tracker) Has(off int64) bool { return t.b.Test(uint(header.BlockIdx(off, t.blockSize))) } +func (t *Tracker) HasOffsets(off, length int64) bool { + t.mu.RLock() + defer t.mu.RUnlock() + + for off := range getOffsets(off, off+length, t.blockSize) { + if !t.b.Test(uint(header.BlockIdx(off, t.blockSize))) { + return false + } + } + + return true +} + func (t *Tracker) Add(off int64) { t.mu.Lock() defer t.mu.Unlock() @@ -46,6 +59,15 @@ func (t *Tracker) Add(off int64) { t.b.Set(uint(header.BlockIdx(off, t.blockSize))) } +func (t *Tracker) AddOffsets(off, length int64) { + t.mu.Lock() + defer t.mu.Unlock() + + for off := range getOffsets(off, off+length, t.blockSize) { + t.b.Set(uint(header.BlockIdx(off, t.blockSize))) + } +} + func (t *Tracker) Reset() { t.mu.Lock() defer t.mu.Unlock() diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index fa800b7e5e..43b62ee6ae 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -67,6 +67,7 @@ func (c *apiClient) loadSnapshot( EnableDiffSnapshots: false, MemBackend: backend, SnapshotPath: &snapfilePath, + NetworkOverrides: []*models.NetworkOverride{}, }, } diff --git 
a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index e1536b2974..7158dea527 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -47,12 +47,16 @@ func (p *Process) ExportMemory( size := block.GetSize(remoteRanges) - cache, err := block.NewCache(int64(size), blockSize, cachePath, false) + cache, err := block.NewCache( + cachePath, + size, + blockSize, + ) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) } - err = cache.CopyFromProcess(ctx, pid, remoteRanges) + err = cache.CopyProcessMemory(ctx, pid, remoteRanges) if err != nil { // Close the cache even if the copy fails. return nil, fmt.Errorf("failed to copy process memory: %w", errors.Join(err, cache.Close())) diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go new file mode 100644 index 0000000000..06d9d2bf8a --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go @@ -0,0 +1,459 @@ +package memory + +import ( + "context" + "encoding/binary" + "fmt" + "os" + "os/exec" + "os/signal" + "strconv" + "syscall" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// TestCopyFromProcess_HelperProcess is a helper process that allocates memory +// with known content and waits for the test to read it. 
+func TestCopyFromProcess_HelperProcess(t *testing.T) { + if os.Getenv("GO_TEST_COPY_HELPER_PROCESS") != "1" { + t.Skip("this is a helper process, skipping direct execution") + } + + // Allocate memory with known content + sizeStr := os.Getenv("GO_TEST_MEMORY_SIZE") + size, err := strconv.ParseUint(sizeStr, 10, 64) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to parse memory size: %v\n", err) + os.Exit(1) + } + + // Allocate memory using mmap (similar to testutils.NewPageMmap but simpler) + mem, err := unix.Mmap(-1, 0, int(size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_ANON|unix.MAP_PRIVATE) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to mmap memory: %v\n", err) + os.Exit(1) + } + defer unix.Munmap(mem) + + // Fill memory with a pattern: each byte is its offset modulo 256 + for i := range mem { + mem[i] = byte(i % 256) + } + + // Write the memory address to stdout (as 8 bytes, little endian) + addr := uint64(0) + if len(mem) > 0 { + addr = uint64(uintptr(unsafe.Pointer(&mem[0]))) + } + err = binary.Write(os.Stdout, binary.LittleEndian, addr) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to write address: %v\n", err) + os.Exit(1) + } + + // Signal ready by closing stdout + os.Stdout.Close() + + // Wait for SIGTERM to exit + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM) + defer signal.Stop(sigChan) + + select { + case <-sigChan: + // Exit cleanly + case <-time.After(30 * time.Second): + // Timeout after 30 seconds + fmt.Fprintf(os.Stderr, "helper process timeout\n") + os.Exit(1) + } +} + +func TestCopyFromProcess_Success(t *testing.T) { + t.Parallel() + + ctx := context.Background() + size := int64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache( + tmpFile, + size, + header.PageSize, + ) + require.NoError(t, err) + defer cache.Close() + + // Start helper process + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") + cmd.Env = 
append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") + cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", size)) + cmd.Stderr = os.Stderr + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + + t.Cleanup(func() { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + }) + + // Read the memory address from the helper process + var addr uint64 + err = binary.Read(stdout, binary.LittleEndian, &addr) + require.NoError(t, err) + stdout.Close() + + // Wait a bit for the process to be ready + time.Sleep(100 * time.Millisecond) + + // Test copying a single range + ranges := []block.Range{ + {Start: int64(addr), Size: size}, + } + + err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + require.NoError(t, err) + + // Verify the copied data + data := make([]byte, size) + n, err := cache.ReadAt(data, 0) + require.NoError(t, err) + require.Equal(t, int(size), n) + + // Verify pattern: each byte should be its offset modulo 256 + for i := range data { + expected := byte(i % 256) + assert.Equal(t, expected, data[i], "byte at offset %d should be %d, got %d", i, expected, data[i]) + } +} + +func TestCopyFromProcess_MultipleRanges(t *testing.T) { + t.Parallel() + + ctx := context.Background() + segmentSize := uint64(1024) + totalSize := segmentSize * 3 + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(tmpFile, int64(totalSize), header.PageSize) + require.NoError(t, err) + defer cache.Close() + + // Start helper process + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") + cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") + cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", totalSize)) + cmd.Stderr = os.Stderr + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + + t.Cleanup(func() { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + }) + + // Read 
the memory address from the helper process + var baseAddr uint64 + err = binary.Read(stdout, binary.LittleEndian, &baseAddr) + require.NoError(t, err) + stdout.Close() + + // Wait a bit for the process to be ready + time.Sleep(100 * time.Millisecond) + + // Test copying multiple non-contiguous ranges + ranges := []block.Range{ + {Start: int64(baseAddr), Size: segmentSize}, + {Start: int64(baseAddr + segmentSize*2), Size: segmentSize}, + {Start: int64(baseAddr + segmentSize), Size: segmentSize}, + } + + err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + require.NoError(t, err) + + // Verify the first segment + data1 := make([]byte, segmentSize) + n, err := cache.ReadAt(data1, 0) + require.NoError(t, err) + require.Equal(t, int(segmentSize), n) + for i := range data1 { + expected := byte(i % 256) + assert.Equal(t, expected, data1[i], "first segment, byte at offset %d", i) + } + + // Verify the second segment (copied to offset segmentSize*2 in cache) + data2 := make([]byte, segmentSize) + n, err = cache.ReadAt(data2, int64(segmentSize*2)) + require.NoError(t, err) + require.Equal(t, int(segmentSize), n) + for i := range data2 { + expected := byte((int(segmentSize*2) + i) % 256) + assert.Equal(t, expected, data2[i], "second segment, byte at offset %d", i) + } + + // Verify the third segment (copied to offset segmentSize in cache) + data3 := make([]byte, segmentSize) + n, err = cache.ReadAt(data3, int64(segmentSize)) + require.NoError(t, err) + require.Equal(t, int(segmentSize), n) + for i := range data3 { + expected := byte((int(segmentSize) + i) % 256) + assert.Equal(t, expected, data3[i], "third segment, byte at offset %d", i) + } +} + +func TestCopyFromProcess_EmptyRanges(t *testing.T) { + t.Parallel() + + ctx := context.Background() + size := uint64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) + require.NoError(t, err) + defer cache.Close() + + // Test with 
empty ranges + ranges := []block.Range{} + err = cache.CopyFromProcess(ctx, os.Getpid(), ranges) + require.NoError(t, err) +} + +func TestCopyFromProcess_ContextCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + size := uint64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) + require.NoError(t, err) + defer cache.Close() + + // Start helper process + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") + cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") + cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", size)) + cmd.Stderr = os.Stderr + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + + t.Cleanup(func() { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + }) + + // Read the memory address from the helper process + var addr uint64 + err = binary.Read(stdout, binary.LittleEndian, &addr) + require.NoError(t, err) + stdout.Close() + + // Wait a bit for the process to be ready + time.Sleep(100 * time.Millisecond) + + // Cancel context immediately + cancel() + + // Test copying with cancelled context + ranges := []block.Range{ + {Start: int64(addr), Size: size}, + } + + err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + require.Error(t, err) + assert.Equal(t, context.Canceled, err) +} + +func TestCopyFromProcess_InvalidPID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + size := uint64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(tmpFile, int64(size), header.PageSize) + require.NoError(t, err) + defer cache.Close() + + // Test with invalid PID (very high PID that doesn't exist) + invalidPID := 999999999 + ranges := []block.Range{ + {Start: 0x1000, Size: 1024}, + } + + err = cache.CopyFromProcess(ctx, invalidPID, ranges) + 
require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read memory") +} + +func TestCopyFromProcess_InvalidAddress(t *testing.T) { + t.Parallel() + + ctx := context.Background() + size := uint64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) + require.NoError(t, err) + defer cache.Close() + + // Test with invalid memory address (very high address that fits in int64) + invalidAddr := int64(0x7FFFFFFF00000000) // Large but valid int64 value + ranges := []block.Range{ + {Start: invalidAddr, Size: 1024}, + } + + err = cache.CopyFromProcess(ctx, os.Getpid(), ranges) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to read memory") +} + +func TestCopyFromProcess_ZeroSizeRange(t *testing.T) { + t.Parallel() + + ctx := context.Background() + size := uint64(4096) + + // Create cache + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) + require.NoError(t, err) + defer cache.Close() + + // Start helper process + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") + cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") + cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", size)) + cmd.Stderr = os.Stderr + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + + t.Cleanup(func() { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + }) + + // Read the memory address from the helper process + var addr uint64 + err = binary.Read(stdout, binary.LittleEndian, &addr) + require.NoError(t, err) + stdout.Close() + + // Wait a bit for the process to be ready + time.Sleep(100 * time.Millisecond) + + // Test with zero-size range + ranges := []block.Range{ + {Start: int64(addr), Size: 0}, + } + + err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + require.NoError(t, err) +} + 
+func TestCopyFromProcess_LargeRanges(t *testing.T) { + t.Parallel() + + ctx := context.Background() + // Use a size that exceeds IOV_MAX (typically 1024 on Linux) if we have many small ranges + // We'll use 1500 ranges to ensure we exceed IOV_MAX + numRanges := 1500 + rangeSize := uint64(64) // Small ranges + + // Create cache large enough for all ranges + totalSize := rangeSize * uint64(numRanges) + tmpFile := t.TempDir() + "/cache" + cache, err := block.NewCache(int64(totalSize), int64(header.PageSize), tmpFile, false) + require.NoError(t, err) + defer cache.Close() + + // Start helper process + cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") + cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") + cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", totalSize)) + cmd.Stderr = os.Stderr + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + err = cmd.Start() + require.NoError(t, err) + + t.Cleanup(func() { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + }) + + // Read the memory address from the helper process + var baseAddr uint64 + err = binary.Read(stdout, binary.LittleEndian, &baseAddr) + require.NoError(t, err) + stdout.Close() + + // Wait a bit for the process to be ready + time.Sleep(100 * time.Millisecond) + + // Create many small ranges that exceed IOV_MAX + ranges := make([]block.Range, numRanges) + for i := 0; i < numRanges; i++ { + ranges[i] = block.Range{ + Start: int64(baseAddr) + int64(i)*int64(rangeSize), + Size: rangeSize, + } + } + + err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + require.NoError(t, err) + + // Verify the data was copied correctly + // Check a few ranges to ensure they were copied + checkCount := 10 + if numRanges < checkCount { + checkCount = numRanges + } + for i := 0; i < checkCount; i++ { + offset := int64(i) * int64(rangeSize) + data := make([]byte, rangeSize) + n, err := cache.ReadAt(data, offset) + require.NoError(t, err) + 
require.Equal(t, int(rangeSize), n) + + // Verify pattern + for j := range data { + expected := byte((int(offset) + j) % 256) + assert.Equal(t, expected, data[j], "range %d, byte at offset %d", i, j) + } + } +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go index d82ad56069..9a5103b070 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -61,17 +61,6 @@ func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) { return 0, 0, AddressNotFoundError{hostVirtAddr: hostVirtAddr} } -// GetHostVirtAddr returns the host virtual address and size of the remaining contiguous mapped host range for the given offset. -func (m *Mapping) GetHostVirtAddr(off int64) (uintptr, int64, error) { - for _, r := range m.Regions { - if off >= int64(r.Offset) && off < r.endOffset() { - return r.shiftedHostVirtAddr(off), r.endOffset() - off, nil - } - } - - return 0, 0, OffsetNotFoundError{offset: off} -} - // GetHostVirtRanges returns the host virtual addresses and sizes (ranges) that cover exactly the given [offset, offset+length) range in the host virtual address space. 
func (m *Mapping) GetHostVirtRanges(off int64, size int64) (hostVirtRanges []block.Range, err error) { for n := int64(0); n < size; { @@ -85,7 +74,7 @@ func (m *Mapping) GetHostVirtRanges(off int64, size int64) (hostVirtRanges []blo start := region.shiftedHostVirtAddr(currentOff) remainingSize := min(int64(region.endHostVirtAddr()-start), size-n) - r := block.NewRange(int64(start), uint64(remainingSize)) + r := block.NewRange(int64(start), remainingSize) hostVirtRanges = append(hostVirtRanges, r) diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go index 05e8cc75f5..f004708039 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go @@ -7,10 +7,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -func TestMapping_GetHostVirtAddr(t *testing.T) { +func TestMapping_GetHostVirtRanges(t *testing.T) { t.Parallel() regions := []Region{ @@ -30,61 +31,108 @@ func TestMapping_GetHostVirtAddr(t *testing.T) { mapping := NewMapping(regions) tests := []struct { - name string - offset int64 - expectedHostVirt uintptr - remainingRegionSize int64 - expectError error + name string + offset int64 + size int64 + expectedRanges []block.Range + expectError error + expectErrorAt int64 // offset where error should occur }{ { - name: "valid offset in first region", - offset: 0x5500, // 0x5000 + (0x1500 - 0x1000) - expectedHostVirt: 0x1500, // 0x1000 + (0x5500 - 0x5000) - // region ends at 0x7000; remaining = 0x7000 - 0x5500 = 0x1b00 - remainingRegionSize: 0x1b00, + name: "valid offset in first region, single byte", + offset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + size: 0x1, + 
expectedRanges: []block.Range{ + {Start: 0x1500, Size: 0x1}, // 0x1000 + (0x5500 - 0x5000) + }, }, { - name: "valid offset at start of first region", - offset: 0x5000, - expectedHostVirt: 0x1000, // 0x1000 + (0x5000 - 0x5000) - remainingRegionSize: 0x2000, // 0x7000 - 0x5000 + name: "valid offset at start of first region, full region size", + offset: 0x5000, + size: 0x2000, + expectedRanges: []block.Range{ + {Start: 0x1000, Size: 0x2000}, // 0x1000 + (0x5000 - 0x5000) + }, }, { - name: "valid offset near end of first region", - offset: 0x6FFF, // 0x7000 - 1 - expectedHostVirt: 0x2FFF, // 0x1000 + (0x6FFF - 0x5000) - remainingRegionSize: 0x1, // 0x7000 - 0x6FFF + name: "valid offset near end of first region, single byte", + offset: 0x6FFF, // 0x7000 - 1 + size: 0x1, + expectedRanges: []block.Range{ + {Start: 0x2FFF, Size: 0x1}, // 0x1000 + (0x6FFF - 0x5000) + }, }, { - name: "valid offset at start of second region", - offset: 0x8000, - expectedHostVirt: 0x5000, // 0x5000 + (0x8000 - 0x8000) - remainingRegionSize: 0x1000, // 0x9000 - 0x8000 + name: "valid offset at start of second region, full region size", + offset: 0x8000, + size: 0x1000, + expectedRanges: []block.Range{ + {Start: 0x5000, Size: 0x1000}, // 0x5000 + (0x8000 - 0x8000) + }, }, { - name: "offset before first region", - offset: 0x4000, - expectError: OffsetNotFoundError{offset: 0x4000}, + name: "offset before first region", + offset: 0x4000, + size: 0x100, + expectError: OffsetNotFoundError{offset: 0x4000}, + expectErrorAt: 0x4000, }, { - name: "offset after last region", - offset: 0xA000, - expectError: OffsetNotFoundError{offset: 0xA000}, + name: "offset after last region", + offset: 0xA000, + size: 0x100, + expectError: OffsetNotFoundError{offset: 0xA000}, + expectErrorAt: 0xA000, }, { - name: "offset in gap between regions", - offset: 0x7000, - expectError: OffsetNotFoundError{offset: 0x7000}, + name: "offset in gap between regions", + offset: 0x7000, + size: 0x100, + expectError: 
OffsetNotFoundError{offset: 0x7000}, + expectErrorAt: 0x7000, }, { - name: "offset at exact end of first region (exclusive)", - offset: 0x7000, // 0x5000 + 0x2000 - expectError: OffsetNotFoundError{offset: 0x7000}, + name: "offset at exact end of first region (exclusive)", + offset: 0x7000, // 0x5000 + 0x2000 + size: 0x100, + expectError: OffsetNotFoundError{offset: 0x7000}, + expectErrorAt: 0x7000, }, { - name: "offset at exact end of second region (exclusive)", - offset: 0x9000, // 0x8000 + 0x1000 - expectError: OffsetNotFoundError{offset: 0x9000}, + name: "offset at exact end of second region (exclusive)", + offset: 0x9000, // 0x8000 + 0x1000 + size: 0x100, + expectError: OffsetNotFoundError{offset: 0x9000}, + expectErrorAt: 0x9000, + }, + { + name: "range spanning from first region into gap (should fail at gap)", + offset: 0x6F00, + size: 0x200, // extends to 0x7100, crossing gap at 0x7000 + expectError: OffsetNotFoundError{offset: 0x7000}, + expectErrorAt: 0x7000, + }, + { + name: "range spanning both regions (fails due to gap)", + offset: 0x6F00, + size: 0x1100, // from 0x6F00 to 0x8000, but gap at 0x7000 + expectError: OffsetNotFoundError{offset: 0x7000}, + expectErrorAt: 0x7000, + }, + { + name: "range within first region, partial", + offset: 0x5500, + size: 0x500, // 0x5500 to 0x5A00 + expectedRanges: []block.Range{ + {Start: 0x1500, Size: 0x500}, // 0x1000 + (0x5500 - 0x5000) + }, + }, + { + name: "range from end of first region to start of second (fails at gap)", + offset: 0x6FFF, + size: 0x1001, // from 0x6FFF to 0x8000, crossing gap + expectError: OffsetNotFoundError{offset: 0x7000}, + expectErrorAt: 0x7000, }, } @@ -92,29 +140,32 @@ func TestMapping_GetHostVirtAddr(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - hostVirt, size, err := mapping.GetHostVirtAddr(tt.offset) + ranges, err := mapping.GetHostVirtRanges(tt.offset, tt.size) if tt.expectError != nil { - require.ErrorIs(t, err, tt.expectError) + require.Error(t, err) + var 
offsetErr OffsetNotFoundError + require.ErrorAs(t, err, &offsetErr) + assert.Equal(t, tt.expectErrorAt, offsetErr.offset) + assert.Nil(t, ranges) } else { require.NoError(t, err) - assert.Equal(t, tt.expectedHostVirt, hostVirt, "hostVirt: %d, expectedHostVirt: %d", hostVirt, tt.expectedHostVirt) - assert.Equal(t, tt.remainingRegionSize, size, "size: %d, expectedSize: %d", size, tt.remainingRegionSize) + assert.Equal(t, tt.expectedRanges, ranges) } }) } } -func TestMapping_GetHostVirtAddr_EmptyRegions(t *testing.T) { +func TestMapping_GetHostVirtRanges_EmptyRegions(t *testing.T) { t.Parallel() mapping := NewMapping([]Region{}) - // Test GetHostVirtAddr with empty regions - _, _, err := mapping.GetHostVirtAddr(0x1000) + // Test GetHostVirtRanges with empty regions + _, err := mapping.GetHostVirtRanges(0x1000, 0x100) require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1000}) } -func TestMapping_GetHostVirtAddr_BoundaryConditions(t *testing.T) { +func TestMapping_GetHostVirtRanges_BoundaryConditions(t *testing.T) { t.Parallel() regions := []Region{ @@ -129,27 +180,25 @@ func TestMapping_GetHostVirtAddr_BoundaryConditions(t *testing.T) { mapping := NewMapping(regions) // Test exact start boundary - hostVirt, size, err := mapping.GetHostVirtAddr(0x5000) + ranges, err := mapping.GetHostVirtRanges(0x5000, 0x2000) require.NoError(t, err) - assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 + (0x5000 - 0x5000) - assert.Equal(t, int64(0x7000-0x5000), size) // 0x2000 + assert.Equal(t, []block.Range{{Start: 0x1000, Size: 0x2000}}, ranges) // Test offset before end boundary - hostVirt, size, err = mapping.GetHostVirtAddr(0x6FFF) // just before end + ranges, err = mapping.GetHostVirtRanges(0x6FFF, 0x1) // just before end require.NoError(t, err) - assert.Equal(t, uintptr(0x1000+(0x6FFF-0x5000)), hostVirt) - assert.Equal(t, int64(0x7000-0x6FFF), size) + assert.Equal(t, []block.Range{{Start: 0x2FFF, Size: 0x1}}, ranges) // Test exact end boundary (should fail - exclusive) - _, 
_, err = mapping.GetHostVirtAddr(0x7000) + _, err = mapping.GetHostVirtRanges(0x7000, 0x100) require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x7000}) // Test below start boundary (should fail) - _, _, err = mapping.GetHostVirtAddr(0x4000) + _, err = mapping.GetHostVirtRanges(0x4000, 0x100) require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x4000}) } -func TestMapping_GetHostVirtAddr_SingleLargeRegion(t *testing.T) { +func TestMapping_GetHostVirtRanges_SingleLargeRegion(t *testing.T) { t.Parallel() // Entire 64-bit address space region @@ -163,13 +212,12 @@ func TestMapping_GetHostVirtAddr_SingleLargeRegion(t *testing.T) { } mapping := NewMapping(regions) - hostVirt, size, err := mapping.GetHostVirtAddr(0x100 + 0x1000) // Offset 0x1100 + ranges, err := mapping.GetHostVirtRanges(0x100+0x1000, 0x1000) // Offset 0x1100, size 0x1000 require.NoError(t, err) - assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 - assert.Equal(t, int64(math.MaxInt64-0x100-0x1000), size) + assert.Equal(t, []block.Range{{Start: 0x1000, Size: 0x1000}}, ranges) } -func TestMapping_GetHostVirtAddr_ZeroSizeRegion(t *testing.T) { +func TestMapping_GetHostVirtRanges_ZeroSizeRegion(t *testing.T) { t.Parallel() regions := []Region{ @@ -183,11 +231,11 @@ func TestMapping_GetHostVirtAddr_ZeroSizeRegion(t *testing.T) { mapping := NewMapping(regions) - _, _, err := mapping.GetHostVirtAddr(0x1000) + _, err := mapping.GetHostVirtRanges(0x1000, 0x100) require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1000}) } -func TestMapping_GetHostVirtAddr_MultipleRegionsSparse(t *testing.T) { +func TestMapping_GetHostVirtRanges_MultipleRegionsSparse(t *testing.T) { t.Parallel() regions := []Region{ @@ -207,24 +255,21 @@ func TestMapping_GetHostVirtAddr_MultipleRegionsSparse(t *testing.T) { mapping := NewMapping(regions) // Should succeed for start of first region - hostVirt, size, err := mapping.GetHostVirtAddr(0x1000) + ranges, err := mapping.GetHostVirtRanges(0x1000, 0x100) require.NoError(t, err) - 
assert.Equal(t, uintptr(0x100), hostVirt) // 0x100 + (0x1000 - 0x1000) - assert.Equal(t, int64(0x1100-0x1000), size) // 0x100 + assert.Equal(t, []block.Range{{Start: 0x100, Size: 0x100}}, ranges) // Should succeed for just before end of first region - hostVirt, size, err = mapping.GetHostVirtAddr(0x10FF) // 0x1100 - 1 + ranges, err = mapping.GetHostVirtRanges(0x10FF, 0x1) // 0x1100 - 1 require.NoError(t, err) - assert.Equal(t, uintptr(0x100+(0x10FF-0x1000)), hostVirt) - assert.Equal(t, int64(0x1100-0x10FF), size) // 1 + assert.Equal(t, []block.Range{{Start: 0x1FF, Size: 0x1}}, ranges) // Should succeed for start of second region - hostVirt, size, err = mapping.GetHostVirtAddr(0x2000) + ranges, err = mapping.GetHostVirtRanges(0x2000, 0x100) require.NoError(t, err) - assert.Equal(t, uintptr(0x10000), hostVirt) // 0x10000 + (0x2000 - 0x2000) - assert.Equal(t, int64(0x2100-0x2000), size) // 0x100 + assert.Equal(t, []block.Range{{Start: 0x10000, Size: 0x100}}, ranges) // In gap - _, _, err = mapping.GetHostVirtAddr(0x1500) + _, err = mapping.GetHostVirtRanges(0x1500, 0x100) require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1500}) } From 565caf70abaed3c5fb8ed37bebfa5b2011d5451b Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Thu, 18 Dec 2025 01:15:07 -0800 Subject: [PATCH 14/40] Fix error --- .../internal/sandbox/block/cache.go | 20 ++++++++++--------- .../internal/sandbox/block/iov.go | 4 ++-- .../internal/sandbox/fc/memory.go | 3 ++- .../uffd/memory/cache_copyfromprocess_test.go | 17 ++++++++-------- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index d51ca2fb81..d3f1094885 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -279,7 +279,7 @@ func (c *Cache) FileSize() (int64, error) { return stat.Blocks * fsStat.Bsize, nil } -func (c *Cache) Address(off uint64) 
*byte { +func (c *Cache) Address(off int64) *byte { return &(*c.mmap)[off] } @@ -296,14 +296,15 @@ func (c *Cache) CopyFromProcess( pid int, ranges []Range, ) error { - var start uint64 + var start int64 - for i := 0; i < len(ranges); i += int(IOV_MAX) { - segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] + for i := 0; i < len(ranges); i += IOV_MAX { + // TODO: Is this accumulation correct? + segmentRanges := ranges[i:min(i+IOV_MAX, len(ranges))] remote := make([]unix.RemoteIovec, len(segmentRanges)) - var segmentSize uint64 + var segmentSize int64 for j, r := range segmentRanges { remote[j] = unix.RemoteIovec{ @@ -318,7 +319,7 @@ func (c *Cache) CopyFromProcess( { Base: c.Address(start), // We could keep this as full cache length, but we might as well be exact here. - Len: segmentSize, + Len: uint64(segmentSize), }, } @@ -351,7 +352,7 @@ func (c *Cache) CopyFromProcess( return fmt.Errorf("failed to read memory: %w", err) } - if uint64(n) != segmentSize { + if int64(n) != segmentSize { return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) } @@ -428,8 +429,9 @@ func (c *Cache) CopyProcessMemory( return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) } - // Mark the copied data as cached so it can be read via Slice/ReadAt - c.dirty.AddOffsets(start, segmentSize) + for _, blockOff := range header.BlocksOffsets(segmentSize, c.blockSize) { + c.dirty.Store(start+blockOff, struct{}{}) + } start += segmentSize diff --git a/packages/orchestrator/internal/sandbox/block/iov.go b/packages/orchestrator/internal/sandbox/block/iov.go index 49e95837f3..99b922b3c5 100644 --- a/packages/orchestrator/internal/sandbox/block/iov.go +++ b/packages/orchestrator/internal/sandbox/block/iov.go @@ -11,11 +11,11 @@ import ( // IOV_MAX is the limit of the vectors that can be passed in a single ioctl call. 
var IOV_MAX = utils.Must(getIOVMax()) -func getIOVMax() (int64, error) { +func getIOVMax() (int, error) { iovMax, err := sysconf.Sysconf(sysconf.SC_IOV_MAX) if err != nil { return 0, fmt.Errorf("failed to get IOV_MAX: %w", err) } - return iovMax, nil + return int(iovMax), nil } diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 7158dea527..f0f6a70a9f 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -48,9 +48,10 @@ func (p *Process) ExportMemory( size := block.GetSize(remoteRanges) cache, err := block.NewCache( - cachePath, size, blockSize, + cachePath, + false, ) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go index 06d9d2bf8a..d2bc88c38b 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go @@ -87,9 +87,10 @@ func TestCopyFromProcess_Success(t *testing.T) { // Create cache tmpFile := t.TempDir() + "/cache" cache, err := block.NewCache( - tmpFile, size, header.PageSize, + tmpFile, + false, ) require.NoError(t, err) defer cache.Close() @@ -150,7 +151,7 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { // Create cache tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(tmpFile, int64(totalSize), header.PageSize) + cache, err := block.NewCache(int64(totalSize), header.PageSize, tmpFile, false) require.NoError(t, err) defer cache.Close() @@ -182,9 +183,9 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { // Test copying multiple non-contiguous ranges ranges := []block.Range{ - {Start: int64(baseAddr), Size: segmentSize}, - {Start: int64(baseAddr + segmentSize*2), 
Size: segmentSize}, - {Start: int64(baseAddr + segmentSize), Size: segmentSize}, + {Start: int64(baseAddr), Size: int64(segmentSize)}, + {Start: int64(baseAddr + segmentSize*2), Size: int64(segmentSize)}, + {Start: int64(baseAddr + segmentSize), Size: int64(segmentSize)}, } err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) @@ -282,7 +283,7 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { // Test copying with cancelled context ranges := []block.Range{ - {Start: int64(addr), Size: size}, + {Start: int64(addr), Size: int64(size)}, } err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) @@ -298,7 +299,7 @@ func TestCopyFromProcess_InvalidPID(t *testing.T) { // Create cache tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(tmpFile, int64(size), header.PageSize) + cache, err := block.NewCache(int64(size), header.PageSize, tmpFile, false) require.NoError(t, err) defer cache.Close() @@ -430,7 +431,7 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { for i := 0; i < numRanges; i++ { ranges[i] = block.Range{ Start: int64(baseAddr) + int64(i)*int64(rangeSize), - Size: rangeSize, + Size: int64(rangeSize), } } From 73857b68b79495bef4116028169a94e7b70cd2f5 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Thu, 18 Dec 2025 02:16:50 -0800 Subject: [PATCH 15/40] Fix test offsets --- .../internal/sandbox/block/cache.go | 90 +++--------- .../cache_copyfromprocess_test.go | 137 +++++++----------- .../internal/sandbox/block/range_test.go | 6 +- .../internal/sandbox/fc/memory.go | 16 +- packages/shared/pkg/fc/firecracker.yml | 36 +---- packages/shared/pkg/fc/models/cpu_config.go | 47 ++---- .../pkg/fc/models/snapshot_load_params.go | 63 -------- 7 files changed, 91 insertions(+), 304 deletions(-) rename packages/orchestrator/internal/sandbox/{uffd/memory => block}/cache_copyfromprocess_test.go (80%) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 
d3f1094885..6d97a352df 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -26,6 +26,8 @@ const ( oomMaxJitter = 100 * time.Millisecond ) +var ErrNoRanges = errors.New("no ranges (or ranges with total size 0) provided") + var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block") type CacheClosedError struct { @@ -291,89 +293,41 @@ func (c *Cache) Path() string { return c.filePath } -func (c *Cache) CopyFromProcess( +func NewCacheFromProcessMemory( ctx context.Context, + blockSize int64, + filePath string, pid int, ranges []Range, -) error { - var start int64 - - for i := 0; i < len(ranges); i += IOV_MAX { - // TODO: Is this accumulation correct? - segmentRanges := ranges[i:min(i+IOV_MAX, len(ranges))] - - remote := make([]unix.RemoteIovec, len(segmentRanges)) - - var segmentSize int64 - - for j, r := range segmentRanges { - remote[j] = unix.RemoteIovec{ - Base: uintptr(r.Start), - Len: int(r.Size), - } - - segmentSize += r.Size - } - - local := []unix.Iovec{ - { - Base: c.Address(start), - // We could keep this as full cache length, but we might as well be exact here. - Len: uint64(segmentSize), - }, - } - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // We could retry only on the remaining segment size, but for simplicity we retry the whole segment. 
- n, err := unix.ProcessVMReadv(pid, - local, - remote, - 0, - ) - if errors.Is(err, unix.EAGAIN) { - continue - } - if errors.Is(err, unix.EINTR) { - continue - } - if errors.Is(err, unix.ENOMEM) { - time.Sleep(oomMinBackoff + time.Duration(rand.Intn(int(oomMaxJitter.Milliseconds())))*time.Millisecond) +) (*Cache, error) { + size := GetSize(ranges) - continue - } - - if err != nil { - return fmt.Errorf("failed to read memory: %w", err) - } - - if int64(n) != segmentSize { - return fmt.Errorf("failed to read memory: expected %d bytes, got %d", segmentSize, n) - } + if size == 0 { + return nil, ErrNoRanges + } - start += segmentSize + cache, err := NewCache(size, blockSize, filePath, false) + if err != nil { + return nil, err + } - break - } + err = cache.copyProcessMemory(ctx, pid, ranges) + if err != nil { + return nil, err } - return nil + return cache, nil } -func (c *Cache) CopyProcessMemory( +func (c *Cache) copyProcessMemory( ctx context.Context, pid int, ranges []Range, ) error { var start int64 - for i := 0; i < len(ranges); i += int(IOV_MAX) { - segmentRanges := ranges[i:min(i+int(IOV_MAX), len(ranges))] + for i := 0; i < len(ranges); i += IOV_MAX { + segmentRanges := ranges[i:min(i+IOV_MAX, len(ranges))] remote := make([]unix.RemoteIovec, len(segmentRanges)) @@ -430,6 +384,8 @@ func (c *Cache) CopyProcessMemory( } for _, blockOff := range header.BlocksOffsets(segmentSize, c.blockSize) { + fmt.Println("setting dirty", start+blockOff) + c.dirty.Store(start+blockOff, struct{}{}) } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go similarity index 80% rename from packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go rename to packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go index d2bc88c38b..7b4b3c2176 100644 --- 
a/packages/orchestrator/internal/sandbox/uffd/memory/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go @@ -1,4 +1,4 @@ -package memory +package block import ( "context" @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sys/unix" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -84,17 +83,6 @@ func TestCopyFromProcess_Success(t *testing.T) { ctx := context.Background() size := int64(4096) - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache( - size, - header.PageSize, - tmpFile, - false, - ) - require.NoError(t, err) - defer cache.Close() - // Start helper process cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") @@ -122,13 +110,17 @@ func TestCopyFromProcess_Success(t *testing.T) { time.Sleep(100 * time.Millisecond) // Test copying a single range - ranges := []block.Range{ + ranges := []Range{ {Start: int64(addr), Size: size}, } - err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + tmpFile := t.TempDir() + "/cache" + + cache, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) require.NoError(t, err) + defer cache.Close() + // Verify the copied data data := make([]byte, size) n, err := cache.ReadAt(data, 0) @@ -146,15 +138,9 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { t.Parallel() ctx := context.Background() - segmentSize := uint64(1024) + segmentSize := uint64(header.PageSize) // Use PageSize to ensure alignment totalSize := segmentSize * 3 - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(totalSize), header.PageSize, tmpFile, false) - require.NoError(t, err) - defer cache.Close() - // Start helper process cmd := exec.CommandContext(ctx, os.Args[0], 
"-test.run=TestCopyFromProcess_HelperProcess") cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") @@ -182,16 +168,18 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { time.Sleep(100 * time.Millisecond) // Test copying multiple non-contiguous ranges - ranges := []block.Range{ + ranges := []Range{ {Start: int64(baseAddr), Size: int64(segmentSize)}, {Start: int64(baseAddr + segmentSize*2), Size: int64(segmentSize)}, {Start: int64(baseAddr + segmentSize), Size: int64(segmentSize)}, } - err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + tmpFile := t.TempDir() + "/cache" + cache, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) require.NoError(t, err) + defer cache.Close() - // Verify the first segment + // Verify the first segment (at cache offset 0) data1 := make([]byte, segmentSize) n, err := cache.ReadAt(data1, 0) require.NoError(t, err) @@ -226,18 +214,12 @@ func TestCopyFromProcess_EmptyRanges(t *testing.T) { t.Parallel() ctx := context.Background() - size := uint64(4096) - - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) - require.NoError(t, err) - defer cache.Close() // Test with empty ranges - ranges := []block.Range{} - err = cache.CopyFromProcess(ctx, os.Getpid(), ranges) - require.NoError(t, err) + ranges := []Range{} + tmpFile := t.TempDir() + "/cache" + _, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) + require.ErrorIs(t, err, ErrNoRanges) } func TestCopyFromProcess_ContextCancellation(t *testing.T) { @@ -246,12 +228,6 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) size := uint64(4096) - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) - require.NoError(t, err) - defer cache.Close() - // Start helper 
process cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") @@ -282,11 +258,12 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { cancel() // Test copying with cancelled context - ranges := []block.Range{ + ranges := []Range{ {Start: int64(addr), Size: int64(size)}, } - err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + tmpFile := t.TempDir() + "/cache" + _, err = NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) require.Error(t, err) assert.Equal(t, context.Canceled, err) } @@ -295,21 +272,15 @@ func TestCopyFromProcess_InvalidPID(t *testing.T) { t.Parallel() ctx := context.Background() - size := uint64(4096) - - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(size), header.PageSize, tmpFile, false) - require.NoError(t, err) - defer cache.Close() // Test with invalid PID (very high PID that doesn't exist) invalidPID := 999999999 - ranges := []block.Range{ + ranges := []Range{ {Start: 0x1000, Size: 1024}, } - err = cache.CopyFromProcess(ctx, invalidPID, ranges) + tmpFile := t.TempDir() + "/cache" + _, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, invalidPID, ranges) require.Error(t, err) assert.Contains(t, err.Error(), "failed to read memory") } @@ -318,21 +289,15 @@ func TestCopyFromProcess_InvalidAddress(t *testing.T) { t.Parallel() ctx := context.Background() - size := uint64(4096) - - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) - require.NoError(t, err) - defer cache.Close() // Test with invalid memory address (very high address that fits in int64) invalidAddr := int64(0x7FFFFFFF00000000) // Large but valid int64 value - ranges := []block.Range{ + ranges := []Range{ {Start: invalidAddr, Size: 1024}, } - err = cache.CopyFromProcess(ctx, os.Getpid(), ranges) + 
tmpFile := t.TempDir() + "/cache" + _, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) require.Error(t, err) assert.Contains(t, err.Error(), "failed to read memory") } @@ -343,12 +308,6 @@ func TestCopyFromProcess_ZeroSizeRange(t *testing.T) { ctx := context.Background() size := uint64(4096) - // Create cache - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(size), int64(header.PageSize), tmpFile, false) - require.NoError(t, err) - defer cache.Close() - // Start helper process cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") @@ -376,12 +335,13 @@ func TestCopyFromProcess_ZeroSizeRange(t *testing.T) { time.Sleep(100 * time.Millisecond) // Test with zero-size range - ranges := []block.Range{ + ranges := []Range{ {Start: int64(addr), Size: 0}, } - err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) - require.NoError(t, err) + tmpFile := t.TempDir() + "/cache" + _, err = NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) + require.ErrorIs(t, err, ErrNoRanges) } func TestCopyFromProcess_LargeRanges(t *testing.T) { @@ -395,10 +355,6 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { // Create cache large enough for all ranges totalSize := rangeSize * uint64(numRanges) - tmpFile := t.TempDir() + "/cache" - cache, err := block.NewCache(int64(totalSize), int64(header.PageSize), tmpFile, false) - require.NoError(t, err) - defer cache.Close() // Start helper process cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") @@ -427,34 +383,45 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { time.Sleep(100 * time.Millisecond) // Create many small ranges that exceed IOV_MAX - ranges := make([]block.Range, numRanges) + ranges := make([]Range, numRanges) for i := 0; i < numRanges; i++ { - ranges[i] = block.Range{ + ranges[i] = Range{ 
Start: int64(baseAddr) + int64(i)*int64(rangeSize), Size: int64(rangeSize), } } - err = cache.CopyFromProcess(ctx, cmd.Process.Pid, ranges) + tmpFile := t.TempDir() + "/cache" + cache, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) require.NoError(t, err) + defer cache.Close() // Verify the data was copied correctly // Check a few ranges to ensure they were copied + // ReadAt offsets must be multiples of header.PageSize checkCount := 10 if numRanges < checkCount { checkCount = numRanges } for i := 0; i < checkCount; i++ { - offset := int64(i) * int64(rangeSize) - data := make([]byte, rangeSize) - n, err := cache.ReadAt(data, offset) + // Calculate the actual offset in cache (ranges are stored sequentially) + actualOffset := int64(i) * int64(rangeSize) + // Align offset to header.PageSize boundary + alignedOffset := (actualOffset / header.PageSize) * header.PageSize + // Calculate offset within the aligned block + offsetInBlock := actualOffset - alignedOffset + + // Read a full page to ensure we get the data + data := make([]byte, header.PageSize) + fmt.Println("reading at aligned offset", alignedOffset, "with offset in block", offsetInBlock) + n, err := cache.ReadAt(data, alignedOffset) require.NoError(t, err) - require.Equal(t, int(rangeSize), n) + require.Equal(t, int(header.PageSize), n) - // Verify pattern - for j := range data { - expected := byte((int(offset) + j) % 256) - assert.Equal(t, expected, data[j], "range %d, byte at offset %d", i, j) + // Verify pattern for the range we're checking + for j := 0; j < int(rangeSize); j++ { + expected := byte((int(actualOffset) + j) % 256) + assert.Equal(t, expected, data[offsetInBlock+int64(j)], "range %d, byte at offset %d", i, j) } } } diff --git a/packages/orchestrator/internal/sandbox/block/range_test.go b/packages/orchestrator/internal/sandbox/block/range_test.go index 3de26f6df0..fb4c0109f9 100644 --- a/packages/orchestrator/internal/sandbox/block/range_test.go +++ 
b/packages/orchestrator/internal/sandbox/block/range_test.go @@ -13,7 +13,7 @@ func TestRange_End(t *testing.T) { tests := []struct { name string start int64 - size uint64 + size int64 expected int64 }{ { @@ -63,7 +63,7 @@ func TestNewRange(t *testing.T) { tests := []struct { name string start int64 - size uint64 + size int64 expected Range }{ { @@ -380,7 +380,7 @@ func TestGetSize(t *testing.T) { tests := []struct { name string ranges []Range - expected uint64 + expected int64 }{ { name: "empty", diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index f0f6a70a9f..3452c8c7a4 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -2,7 +2,6 @@ package fc import ( "context" - "errors" "fmt" "github.com/bits-and-blooms/bitset" @@ -45,23 +44,10 @@ func (p *Process) ExportMemory( return nil, fmt.Errorf("failed to get pid: %w", err) } - size := block.GetSize(remoteRanges) - - cache, err := block.NewCache( - size, - blockSize, - cachePath, - false, - ) + cache, err := block.NewCacheFromProcessMemory(ctx, blockSize, cachePath, pid, remoteRanges) if err != nil { return nil, fmt.Errorf("failed to create cache: %w", err) } - err = cache.CopyProcessMemory(ctx, pid, remoteRanges) - if err != nil { - // Close the cache even if the copy fails. - return nil, fmt.Errorf("failed to copy process memory: %w", errors.Join(err, cache.Close())) - } - return cache, nil } diff --git a/packages/shared/pkg/fc/firecracker.yml b/packages/shared/pkg/fc/firecracker.yml index 2cd2d92f20..4200d4d0fa 100644 --- a/packages/shared/pkg/fc/firecracker.yml +++ b/packages/shared/pkg/fc/firecracker.yml @@ -5,10 +5,10 @@ info: The API is accessible through HTTP calls on specific URLs carrying JSON modeled data. The transport medium is a Unix Domain Socket. 
- version: 1.12.2 + version: 1.10.1 termsOfService: "" contact: - email: "firecracker-maintainers@amazon.com" + email: "compute-capsule@amazon.com" license: name: "Apache 2.0" url: "http://www.apache.org/licenses/LICENSE-2.0.html" @@ -869,7 +869,7 @@ definitions: default: "None" CpuConfig: - type: object + type: string description: The CPU configuration template defines a set of bit maps as modifiers of flags accessed by register to be disabled/enabled for the microvm. @@ -883,12 +883,6 @@ definitions: reg_modifiers: type: object description: A collection of registers to be modified. (aarch64) - vcpu_features: - type: object - description: A collection of vcpu features to be modified. (aarch64) - kvm_capabilities: - type: object - description: A collection of kvm capabilities to be modified. (aarch64) Drive: type: object @@ -1297,24 +1291,6 @@ definitions: Type of snapshot to create. It is optional and by default, a full snapshot is created. - NetworkOverride: - type: object - description: - Allows for changing the backing TAP device of a network interface - during snapshot restore. - required: - - iface_id - - host_dev_name - properties: - iface_id: - type: string - description: - The name of the interface to modify - host_dev_name: - type: string - description: - The new host device of the interface - SnapshotLoadParams: type: object description: @@ -1346,12 +1322,6 @@ definitions: type: boolean description: When set to true, the vm is also resumed if the snapshot load is successful. 
- network_overrides: - type: array - description: Network host device names to override - items: - $ref: "#/definitions/NetworkOverride" - TokenBucket: type: object diff --git a/packages/shared/pkg/fc/models/cpu_config.go b/packages/shared/pkg/fc/models/cpu_config.go index 11007527a3..d68bea978c 100644 --- a/packages/shared/pkg/fc/models/cpu_config.go +++ b/packages/shared/pkg/fc/models/cpu_config.go @@ -8,55 +8,26 @@ package models import ( "context" + "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" - "github.com/go-openapi/swag" ) // CPUConfig The CPU configuration template defines a set of bit maps as modifiers of flags accessed by register to be disabled/enabled for the microvm. // // swagger:model CpuConfig -type CPUConfig struct { - - // A collection of CPUIDs to be modified. (x86_64) - CpuidModifiers interface{} `json:"cpuid_modifiers,omitempty"` - - // A collection of kvm capabilities to be modified. (aarch64) - KvmCapabilities interface{} `json:"kvm_capabilities,omitempty"` - - // A collection of model specific registers to be modified. (x86_64) - MsrModifiers interface{} `json:"msr_modifiers,omitempty"` - - // A collection of registers to be modified. (aarch64) - RegModifiers interface{} `json:"reg_modifiers,omitempty"` - - // A collection of vcpu features to be modified. (aarch64) - VcpuFeatures interface{} `json:"vcpu_features,omitempty"` -} +type CPUConfig string // Validate validates this Cpu config -func (m *CPUConfig) Validate(formats strfmt.Registry) error { - return nil -} +func (m CPUConfig) Validate(formats strfmt.Registry) error { + var res []error -// ContextValidate validates this Cpu config based on context it is used -func (m *CPUConfig) ContextValidate(ctx context.Context, formats strfmt.Registry) error { - return nil -} - -// MarshalBinary interface implementation -func (m *CPUConfig) MarshalBinary() ([]byte, error) { - if m == nil { - return nil, nil + if len(res) > 0 { + return errors.CompositeValidationError(res...) 
} - return swag.WriteJSON(m) + return nil } -// UnmarshalBinary interface implementation -func (m *CPUConfig) UnmarshalBinary(b []byte) error { - var res CPUConfig - if err := swag.ReadJSON(b, &res); err != nil { - return err - } - *m = res +// ContextValidate validates this Cpu config based on context it is used +func (m CPUConfig) ContextValidate(ctx context.Context, formats strfmt.Registry) error { return nil } diff --git a/packages/shared/pkg/fc/models/snapshot_load_params.go b/packages/shared/pkg/fc/models/snapshot_load_params.go index 48f6f44979..4fa8f87470 100644 --- a/packages/shared/pkg/fc/models/snapshot_load_params.go +++ b/packages/shared/pkg/fc/models/snapshot_load_params.go @@ -7,7 +7,6 @@ package models import ( "context" - "strconv" "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" @@ -29,9 +28,6 @@ type SnapshotLoadParams struct { // Path to the file that contains the guest memory to be loaded. It is only allowed if `mem_backend` is not present. This parameter has been deprecated and it will be removed in future Firecracker release. MemFilePath string `json:"mem_file_path,omitempty"` - // Network host device names to override - NetworkOverrides []*NetworkOverride `json:"network_overrides"` - // When set to true, the vm is also resumed if the snapshot load is successful. 
ResumeVM bool `json:"resume_vm,omitempty"` @@ -48,10 +44,6 @@ func (m *SnapshotLoadParams) Validate(formats strfmt.Registry) error { res = append(res, err) } - if err := m.validateNetworkOverrides(formats); err != nil { - res = append(res, err) - } - if err := m.validateSnapshotPath(formats); err != nil { res = append(res, err) } @@ -81,32 +73,6 @@ func (m *SnapshotLoadParams) validateMemBackend(formats strfmt.Registry) error { return nil } -func (m *SnapshotLoadParams) validateNetworkOverrides(formats strfmt.Registry) error { - if swag.IsZero(m.NetworkOverrides) { // not required - return nil - } - - for i := 0; i < len(m.NetworkOverrides); i++ { - if swag.IsZero(m.NetworkOverrides[i]) { // not required - continue - } - - if m.NetworkOverrides[i] != nil { - if err := m.NetworkOverrides[i].Validate(formats); err != nil { - if ve, ok := err.(*errors.Validation); ok { - return ve.ValidateName("network_overrides" + "." + strconv.Itoa(i)) - } else if ce, ok := err.(*errors.CompositeError); ok { - return ce.ValidateName("network_overrides" + "." + strconv.Itoa(i)) - } - return err - } - } - - } - - return nil -} - func (m *SnapshotLoadParams) validateSnapshotPath(formats strfmt.Registry) error { if err := validate.Required("snapshot_path", "body", m.SnapshotPath); err != nil { @@ -124,10 +90,6 @@ func (m *SnapshotLoadParams) ContextValidate(ctx context.Context, formats strfmt res = append(res, err) } - if err := m.contextValidateNetworkOverrides(ctx, formats); err != nil { - res = append(res, err) - } - if len(res) > 0 { return errors.CompositeValidationError(res...) 
} @@ -155,31 +117,6 @@ func (m *SnapshotLoadParams) contextValidateMemBackend(ctx context.Context, form return nil } -func (m *SnapshotLoadParams) contextValidateNetworkOverrides(ctx context.Context, formats strfmt.Registry) error { - - for i := 0; i < len(m.NetworkOverrides); i++ { - - if m.NetworkOverrides[i] != nil { - - if swag.IsZero(m.NetworkOverrides[i]) { // not required - return nil - } - - if err := m.NetworkOverrides[i].ContextValidate(ctx, formats); err != nil { - if ve, ok := err.(*errors.Validation); ok { - return ve.ValidateName("network_overrides" + "." + strconv.Itoa(i)) - } else if ce, ok := err.(*errors.CompositeError); ok { - return ce.ValidateName("network_overrides" + "." + strconv.Itoa(i)) - } - return err - } - } - - } - - return nil -} - // MarshalBinary interface implementation func (m *SnapshotLoadParams) MarshalBinary() ([]byte, error) { if m == nil { From 56c6b309ba9317a12461e6c06342b487cc6f58e3 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Thu, 18 Dec 2025 02:18:51 -0800 Subject: [PATCH 16/40] Update deps --- packages/orchestrator/go.sum | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/orchestrator/go.sum b/packages/orchestrator/go.sum index a90cb09cf2..99351936b4 100644 --- a/packages/orchestrator/go.sum +++ b/packages/orchestrator/go.sum @@ -1118,7 +1118,6 @@ github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= -github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= github.com/tklauser/numcpus v0.9.0 h1:lmyCHtANi8aRUgkckBgoDk1nHCux3n2cgkJLXdQGPDo= github.com/tklauser/numcpus v0.9.0/go.mod h1:SN6Nq1O3VychhC1npsWostA+oW+VOQTxZrS604NSRyI= github.com/tmc/grpc-websocket-proxy 
v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= From 18f1b3372fd5ee86f3f33ab2932544562e14455e Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Fri, 19 Dec 2025 05:43:27 -0800 Subject: [PATCH 17/40] Fix 1.10 client call --- packages/orchestrator/internal/sandbox/fc/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index 43b62ee6ae..fa800b7e5e 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -67,7 +67,6 @@ func (c *apiClient) loadSnapshot( EnableDiffSnapshots: false, MemBackend: backend, SnapshotPath: &snapfilePath, - NetworkOverrides: []*models.NetworkOverride{}, }, } From 4da1210e2b6b851d4fcc98e04006d92a6c3dbe07 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Sun, 21 Dec 2025 08:21:18 -0800 Subject: [PATCH 18/40] Update FC versions --- packages/shared/pkg/feature-flags/flags.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/shared/pkg/feature-flags/flags.go b/packages/shared/pkg/feature-flags/flags.go index 37c0f69e6c..1b00964e20 100644 --- a/packages/shared/pkg/feature-flags/flags.go +++ b/packages/shared/pkg/feature-flags/flags.go @@ -143,8 +143,8 @@ func newStringFlag(name string, fallback string) StringFlag { // The Firecracker version the last tag + the short SHA (so we can build our dev previews) // TODO: The short tag here has only 7 characters — the one from our build pipeline will likely have exactly 8 so this will break. 
const ( - DefaultFirecackerV1_10Version = "v1.10.1_1fcdaec" - DefaultFirecackerV1_12Version = "v1.12.1_d990331" + DefaultFirecackerV1_10Version = "v1.10.1_85abaa22" + DefaultFirecackerV1_12Version = "v1.12.1_67263381" DefaultFirecrackerVersion = DefaultFirecackerV1_12Version ) From 1ad4777b47bd3048c54d011edbdf1a65c7d5d089 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 01:56:55 -0800 Subject: [PATCH 19/40] Cleanup --- .vscode/settings.json | 2 +- .../block/cache_copyfromprocess_test.go | 25 ++++++++++--------- .../internal/sandbox/block/range.go | 2 +- .../internal/sandbox/block/range_test.go | 2 +- .../internal/sandbox/fc/memory.go | 2 +- .../internal/sandbox/uffd/memory/mapping.go | 2 +- .../userfaultfd/cross_process_helpers_test.go | 2 +- .../internal/sandbox/uffd/userfaultfd/fd.go | 14 ----------- .../uffd/userfaultfd/fd_helpers_test.go | 14 +++++++++++ packages/shared/pkg/feature-flags/flags.go | 4 +-- 10 files changed, 35 insertions(+), 34 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 65b70264f8..ccba45f0ba 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -92,5 +92,5 @@ "mise.configureExtensionsAutomatically": true, "mise.configureExtensionsAutomaticallyIgnoreList": [ "ms-vscode.js-debug" - ], + ] } \ No newline at end of file diff --git a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go index 7b4b3c2176..f48eef811a 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go @@ -32,14 +32,16 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { size, err := strconv.ParseUint(sizeStr, 10, 64) if err != nil { fmt.Fprintf(os.Stderr, "failed to parse memory size: %v\n", err) - os.Exit(1) + + panic(err) } // Allocate memory using mmap (similar to testutils.NewPageMmap but simpler) mem, 
err := unix.Mmap(-1, 0, int(size), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_ANON|unix.MAP_PRIVATE) if err != nil { fmt.Fprintf(os.Stderr, "failed to mmap memory: %v\n", err) - os.Exit(1) + + panic(err) } defer unix.Munmap(mem) @@ -56,7 +58,8 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { err = binary.Write(os.Stdout, binary.LittleEndian, addr) if err != nil { fmt.Fprintf(os.Stderr, "failed to write address: %v\n", err) - os.Exit(1) + + panic(err) } // Signal ready by closing stdout @@ -73,7 +76,8 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { case <-time.After(30 * time.Second): // Timeout after 30 seconds fmt.Fprintf(os.Stderr, "helper process timeout\n") - os.Exit(1) + + panic("helper process timeout") } } @@ -384,7 +388,7 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { // Create many small ranges that exceed IOV_MAX ranges := make([]Range, numRanges) - for i := 0; i < numRanges; i++ { + for i := range numRanges { ranges[i] = Range{ Start: int64(baseAddr) + int64(i)*int64(rangeSize), Size: int64(rangeSize), @@ -399,11 +403,8 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { // Verify the data was copied correctly // Check a few ranges to ensure they were copied // ReadAt offsets must be multiples of header.PageSize - checkCount := 10 - if numRanges < checkCount { - checkCount = numRanges - } - for i := 0; i < checkCount; i++ { + checkCount := min(numRanges, 10) + for i := range checkCount { // Calculate the actual offset in cache (ranges are stored sequentially) actualOffset := int64(i) * int64(rangeSize) // Align offset to header.PageSize boundary @@ -419,8 +420,8 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { require.Equal(t, int(header.PageSize), n) // Verify pattern for the range we're checking - for j := 0; j < int(rangeSize); j++ { - expected := byte((int(actualOffset) + j) % 256) + for j := range rangeSize { + expected := byte((actualOffset + int64(j)) % 256) assert.Equal(t, expected, 
data[offsetInBlock+int64(j)], "range %d, byte at offset %d", i, j) } } diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go index 53eb8c41f1..8a243c71f6 100644 --- a/packages/orchestrator/internal/sandbox/block/range.go +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -17,7 +17,7 @@ type Range struct { } func (r *Range) End() int64 { - return r.Start + int64(r.Size) + return r.Start + r.Size } // Offsets returns the block offsets contained in the range. diff --git a/packages/orchestrator/internal/sandbox/block/range_test.go b/packages/orchestrator/internal/sandbox/block/range_test.go index fb4c0109f9..7185fd8b13 100644 --- a/packages/orchestrator/internal/sandbox/block/range_test.go +++ b/packages/orchestrator/internal/sandbox/block/range_test.go @@ -330,7 +330,7 @@ func TestBitsetRanges_MultipleRanges(t *testing.T) { func TestBitsetRanges_AllSet(t *testing.T) { b := bitset.New(10) - for i := uint(0); i < 10; i++ { + for i := range uint(10) { b.Set(i) } blockSize := int64(4096) diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go index 3452c8c7a4..0d29ae408c 100644 --- a/packages/orchestrator/internal/sandbox/fc/memory.go +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -31,7 +31,7 @@ func (p *Process) ExportMemory( var remoteRanges []block.Range for r := range block.BitsetRanges(include, blockSize) { - hostVirtRanges, err := m.GetHostVirtRanges(r.Start, int64(r.Size)) + hostVirtRanges, err := m.GetHostVirtRanges(r.Start, r.Size) if err != nil { return nil, fmt.Errorf("failed to get host virt ranges: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go index 9a5103b070..bd6c2b33c9 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go +++ 
b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -78,7 +78,7 @@ func (m *Mapping) GetHostVirtRanges(off int64, size int64) (hostVirtRanges []blo hostVirtRanges = append(hostVirtRanges, r) - n += int64(r.Size) + n += r.Size } return hostVirtRanges, nil diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index 503a51bb66..3921ee429c 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -53,7 +53,7 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) - err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) + err = register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess") diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 7fe3efe338..1a122d01d5 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -103,20 +103,6 @@ func getPagefaultAddress(pagefault *UffdPagefault) uintptr { // Fd is a helper type that wraps uffd fd. 
type Fd uintptr -// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING -// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING -// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func (f Fd) register(addr uintptr, size uint64, mode CULong) error { - register := newUffdioRegister(CULong(addr), CULong(size), mode) - - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) - if errno != 0 { - return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} - // mode: UFFDIO_COPY_MODE_WP // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go index 17f1bf02be..f9f8d953cb 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -38,3 +38,17 @@ func configureApi(f Fd, pagesize uint64) error { return nil } + +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func register(f Fd, addr uintptr, size uint64, mode CULong) error { + register := newUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} diff --git a/packages/shared/pkg/feature-flags/flags.go b/packages/shared/pkg/feature-flags/flags.go index 
1b00964e20..0f710a575c 100644 --- a/packages/shared/pkg/feature-flags/flags.go +++ b/packages/shared/pkg/feature-flags/flags.go @@ -143,8 +143,8 @@ func newStringFlag(name string, fallback string) StringFlag { // The Firecracker version the last tag + the short SHA (so we can build our dev previews) // TODO: The short tag here has only 7 characters — the one from our build pipeline will likely have exactly 8 so this will break. const ( - DefaultFirecackerV1_10Version = "v1.10.1_85abaa22" - DefaultFirecackerV1_12Version = "v1.12.1_67263381" + DefaultFirecackerV1_10Version = "v1.10.1_fb257a1" + DefaultFirecackerV1_12Version = "v1.12.1_717921c" DefaultFirecrackerVersion = DefaultFirecackerV1_12Version ) From c65e4b88939839f5ad60379d1418d915f8ad03c9 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 02:09:53 -0800 Subject: [PATCH 20/40] Hide address --- packages/orchestrator/internal/sandbox/block/cache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 15862fd610..79b8cb5a10 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -273,7 +273,7 @@ func (c *Cache) FileSize() (int64, error) { return stat.Blocks * fsStat.Bsize, nil } -func (c *Cache) Address(off int64) *byte { +func (c *Cache) address(off int64) *byte { return &(*c.mmap)[off] } @@ -336,7 +336,7 @@ func (c *Cache) copyProcessMemory( local := []unix.Iovec{ { - Base: c.Address(start), + Base: c.address(start), // We could keep this as full cache length, but we might as well be exact here. 
Len: uint64(segmentSize), }, From 7a3ccf212279d7e95adb409e0083f20dc3a268b0 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 02:52:11 -0800 Subject: [PATCH 21/40] Refactor --- .../orchestrator/internal/sandbox/sandbox.go | 49 +------------------ .../internal/sandbox/uffd/memory_backend.go | 5 +- .../internal/sandbox/uffd/noop.go | 39 +++++++++++---- .../internal/sandbox/uffd/uffd.go | 17 +++++-- 4 files changed, 47 insertions(+), 63 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 8e180603f4..7c7cc4efbf 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - "github.com/bits-and-blooms/bitset" "github.com/google/uuid" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" @@ -94,16 +93,6 @@ type Resources struct { Slot *network.Slot rootfs rootfs.Provider memory uffd.MemoryBackend - // Filter to apply to the dirty bitset before creating the diff metadata. 
- memoryDiffFilter func(ctx context.Context) (*header.DiffMetadata, error) -} - -func (r *Resources) MemfileDiffMetadata(ctx context.Context) (*header.DiffMetadata, error) { - if r.memoryDiffFilter == nil { - return nil, fmt.Errorf("memory diff filter is not set") - } - - return r.memoryDiffFilter(ctx) } type internalConfig struct { @@ -303,28 +292,7 @@ func (f *Factory) CreateSandbox( resources := &Resources{ Slot: ips.slot, rootfs: rootfsProvider, - memory: uffd.NewNoopMemory(memfileSize, memfile.BlockSize()), - memoryDiffFilter: func(ctx context.Context) (*header.DiffMetadata, error) { - diffInfo, err := fcHandle.MemoryInfo(ctx, memfile.BlockSize()) - if err != nil { - return nil, err - } - - dirty := diffInfo.Dirty.Difference(diffInfo.Empty) - - numberOfPages := header.TotalBlocks(memfileSize, memfile.BlockSize()) - - empty := bitset.New(uint(numberOfPages)) - empty.FlipRange(0, uint(numberOfPages)) - - empty = empty.Difference(dirty) - - return &header.DiffMetadata{ - Dirty: dirty, - Empty: empty, - BlockSize: memfile.BlockSize(), - }, nil - }, + memory: uffd.NewNoopMemory(memfileSize, memfile.BlockSize(), fcHandle.MemoryInfo), } metadata := &Metadata{ @@ -559,19 +527,6 @@ func (f *Factory) ResumeSandbox( Slot: ips.slot, rootfs: rootfsOverlay, memory: fcUffd, - memoryDiffFilter: func(ctx context.Context) (*header.DiffMetadata, error) { - dirty, err := fcUffd.Dirty(ctx) - if err != nil { - return nil, err - } - - return &header.DiffMetadata{ - Dirty: dirty, - // We don't track and filter empty pages for subsequent sandbox pauses as pages should usually not be empty. 
- Empty: bitset.New(0), - BlockSize: memfile.BlockSize(), - }, nil - }, } metadata := &Metadata{ @@ -808,7 +763,7 @@ func (s *Sandbox) Pause( return nil, fmt.Errorf("failed to get original rootfs: %w", err) } - memfileDiffMetadata, err := s.Resources.MemfileDiffMetadata(ctx) + memfileDiffMetadata, err := s.Resources.memory.DiffMetadata(ctx) if err != nil { return nil, fmt.Errorf("failed to get memfile metadata: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go index c01585ec9c..65bafe7477 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go @@ -3,13 +3,12 @@ package uffd import ( "context" - "github.com/bits-and-blooms/bitset" - + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) type MemoryBackend interface { - Dirty(ctx context.Context) (*bitset.BitSet, error) + DiffMetadata(ctx context.Context) (*header.DiffMetadata, error) Start(ctx context.Context, sandboxId string) error Stop() error Ready() chan struct{} diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index f4d40a23d5..62da48b691 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -13,26 +13,45 @@ type NoopMemory struct { size int64 blockSize int64 - exit *utils.ErrorOnce + exit *utils.ErrorOnce + getDiffMetadata func(ctx context.Context, blockSize int64) (*header.DiffMetadata, error) } var _ MemoryBackend = (*NoopMemory)(nil) -func NewNoopMemory(size, blockSize int64) *NoopMemory { +func NewNoopMemory( + size, + blockSize int64, + getDiffMetadata func(ctx context.Context, blockSize int64) (*header.DiffMetadata, error), +) *NoopMemory { return &NoopMemory{ - size: size, - blockSize: blockSize, - exit: 
utils.NewErrorOnce(), + size: size, + blockSize: blockSize, + exit: utils.NewErrorOnce(), + getDiffMetadata: getDiffMetadata, } } -func (m *NoopMemory) Dirty(context.Context) (*bitset.BitSet, error) { - blocks := uint(header.TotalBlocks(m.size, m.blockSize)) +func (m *NoopMemory) DiffMetadata(ctx context.Context) (*header.DiffMetadata, error) { + diffInfo, err := m.getDiffMetadata(ctx, m.blockSize) + if err != nil { + return nil, err + } + + dirty := diffInfo.Dirty.Difference(diffInfo.Empty) + + numberOfPages := header.TotalBlocks(m.size, m.blockSize) + + empty := bitset.New(uint(numberOfPages)) + empty.FlipRange(0, uint(numberOfPages)) - b := bitset.New(blocks) - b.FlipRange(0, blocks) + empty = empty.Difference(dirty) - return b, nil + return &header.DiffMetadata{ + Dirty: dirty, + Empty: empty, + BlockSize: m.blockSize, + }, nil } func (m *NoopMemory) Start(context.Context, string) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index 3eaad96492..d12bd81f78 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -19,6 +19,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -182,14 +183,24 @@ func (u *Uffd) Exit() *utils.ErrorOnce { return u.exit } -// Dirty waits for the current requests to finish and returns the dirty pages. +// DiffMetadata waits for the current requests to finish and returns the dirty pages. // // It *MUST* be only called after the sandbox was successfully paused via API and after the snapshot endpoint was called. 
-func (u *Uffd) Dirty(ctx context.Context) (*bitset.BitSet, error) { +func (u *Uffd) DiffMetadata(ctx context.Context) (*header.DiffMetadata, error) { uffd, err := u.handler.WaitWithContext(ctx) if err != nil { return nil, fmt.Errorf("failed to get uffd: %w", err) } - return uffd.Dirty().BitSet(), nil + dirty := uffd.Dirty() + if err != nil { + return nil, err + } + + return &header.DiffMetadata{ + Dirty: dirty.BitSet(), + // We don't track and filter empty pages for subsequent sandbox pauses as pages should usually not be empty. + Empty: bitset.New(0), + BlockSize: u.memfile.BlockSize(), + }, nil } From a77af2c31bc9120c12bc8976bd4fad20a3f8f006 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 02:53:12 -0800 Subject: [PATCH 22/40] Cleanup --- packages/orchestrator/internal/sandbox/uffd/uffd.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index d12bd81f78..c44c9795e3 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -192,13 +192,8 @@ func (u *Uffd) DiffMetadata(ctx context.Context) (*header.DiffMetadata, error) { return nil, fmt.Errorf("failed to get uffd: %w", err) } - dirty := uffd.Dirty() - if err != nil { - return nil, err - } - return &header.DiffMetadata{ - Dirty: dirty.BitSet(), + Dirty: uffd.Dirty().BitSet(), // We don't track and filter empty pages for subsequent sandbox pauses as pages should usually not be empty. 
Empty: bitset.New(0), BlockSize: u.memfile.BlockSize(), From f5ae9d40203aaccfb8f3f9ac918a2dfe09834314 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 02:55:52 -0800 Subject: [PATCH 23/40] Remove log --- packages/orchestrator/internal/sandbox/block/cache.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 79b8cb5a10..67548d0d58 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -376,8 +376,6 @@ func (c *Cache) copyProcessMemory( } for _, blockOff := range header.BlocksOffsets(segmentSize, c.blockSize) { - fmt.Println("setting dirty", start+blockOff) - c.dirty.Store(start+blockOff, struct{}{}) } From 8751e4185b0ddf66196223c998bc17b4324e6247 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 03:29:19 -0800 Subject: [PATCH 24/40] Fix deps --- packages/orchestrator/go.sum | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/orchestrator/go.sum b/packages/orchestrator/go.sum index 99351936b4..a90cb09cf2 100644 --- a/packages/orchestrator/go.sum +++ b/packages/orchestrator/go.sum @@ -1118,6 +1118,7 @@ github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZb78yU= +github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= github.com/tklauser/numcpus v0.9.0 h1:lmyCHtANi8aRUgkckBgoDk1nHCux3n2cgkJLXdQGPDo= github.com/tklauser/numcpus v0.9.0/go.mod h1:SN6Nq1O3VychhC1npsWostA+oW+VOQTxZrS604NSRyI= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod 
h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= From 9c23c888262425f94b48846cefd1efb8734179cd Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 03:51:29 -0800 Subject: [PATCH 25/40] Cleanup --- .../internal/sandbox/block/cache.go | 10 ++- .../block/cache_copyfromprocess_test.go | 67 ++++--------------- .../internal/sandbox/block/range_test.go | 6 -- 3 files changed, 17 insertions(+), 66 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 67548d0d58..1d57013352 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -26,8 +26,6 @@ const ( oomMaxJitter = 100 * time.Millisecond ) -var ErrNoRanges = errors.New("no ranges (or ranges with total size 0) provided") - var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block") type CacheClosedError struct { @@ -294,15 +292,15 @@ func NewCacheFromProcessMemory( ) (*Cache, error) { size := GetSize(ranges) - if size == 0 { - return nil, ErrNoRanges - } - cache, err := NewCache(size, blockSize, filePath, false) if err != nil { return nil, err } + if size == 0 { + return cache, nil + } + err = cache.copyProcessMemory(ctx, pid, ranges) if err != nil { return nil, err diff --git a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go index f48eef811a..b1c15f6502 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go @@ -214,18 +214,6 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { } } -func TestCopyFromProcess_EmptyRanges(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - // Test with empty ranges - ranges := []Range{} - tmpFile := t.TempDir() + "/cache" - _, err := 
NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) - require.ErrorIs(t, err, ErrNoRanges) -} - func TestCopyFromProcess_ContextCancellation(t *testing.T) { t.Parallel() @@ -306,48 +294,6 @@ func TestCopyFromProcess_InvalidAddress(t *testing.T) { assert.Contains(t, err.Error(), "failed to read memory") } -func TestCopyFromProcess_ZeroSizeRange(t *testing.T) { - t.Parallel() - - ctx := context.Background() - size := uint64(4096) - - // Start helper process - cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestCopyFromProcess_HelperProcess") - cmd.Env = append(os.Environ(), "GO_TEST_COPY_HELPER_PROCESS=1") - cmd.Env = append(cmd.Env, fmt.Sprintf("GO_TEST_MEMORY_SIZE=%d", size)) - cmd.Stderr = os.Stderr - - stdout, err := cmd.StdoutPipe() - require.NoError(t, err) - - err = cmd.Start() - require.NoError(t, err) - - t.Cleanup(func() { - cmd.Process.Signal(syscall.SIGTERM) - cmd.Wait() - }) - - // Read the memory address from the helper process - var addr uint64 - err = binary.Read(stdout, binary.LittleEndian, &addr) - require.NoError(t, err) - stdout.Close() - - // Wait a bit for the process to be ready - time.Sleep(100 * time.Millisecond) - - // Test with zero-size range - ranges := []Range{ - {Start: int64(addr), Size: 0}, - } - - tmpFile := t.TempDir() + "/cache" - _, err = NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) - require.ErrorIs(t, err, ErrNoRanges) -} - func TestCopyFromProcess_LargeRanges(t *testing.T) { t.Parallel() @@ -426,3 +372,16 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { } } } + +func TestEmptyRanges(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + ranges := []Range{} + tmpFile := t.TempDir() + "/cache" + c, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) + require.NoError(t, err) + + defer c.Close() +} diff --git a/packages/orchestrator/internal/sandbox/block/range_test.go 
b/packages/orchestrator/internal/sandbox/block/range_test.go index 7185fd8b13..ca0650fde1 100644 --- a/packages/orchestrator/internal/sandbox/block/range_test.go +++ b/packages/orchestrator/internal/sandbox/block/range_test.go @@ -40,12 +40,6 @@ func TestRange_End(t *testing.T) { size: 1024 * 1024, expected: 1024 * 1024, }, - { - name: "negative start", - start: -100, - size: 50, - expected: -50, - }, } for _, tt := range tests { From 4f1804db25209c8cafce785205518f96693d07eb Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 03:57:11 -0800 Subject: [PATCH 26/40] Close cache on failure --- packages/orchestrator/internal/sandbox/block/cache.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 1d57013352..8548971ec8 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -303,7 +303,7 @@ func NewCacheFromProcessMemory( err = cache.copyProcessMemory(ctx, pid, ranges) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to copy process memory: %w", errors.Join(err, cache.Close())) } return cache, nil From 44433cd54b9e6d5cc0b200e72b92d4444e9714bb Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 04:00:14 -0800 Subject: [PATCH 27/40] Cleanup --- .../internal/sandbox/block/range.go | 18 ------------- .../internal/sandbox/block/range_test.go | 25 ++++++++++++++++--- .../internal/sandbox/block/tracker.go | 22 ---------------- 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/range.go b/packages/orchestrator/internal/sandbox/block/range.go index 8a243c71f6..6a18ca5192 100644 --- a/packages/orchestrator/internal/sandbox/block/range.go +++ b/packages/orchestrator/internal/sandbox/block/range.go @@ -20,24 +20,6 @@ func (r *Range) End() int64 { return r.Start + r.Size } -// 
Offsets returns the block offsets contained in the range. -// This assumes the Range.Start is a multiple of the blockSize. -func (r *Range) Offsets(blockSize int64) iter.Seq[int64] { - return func(yield func(offset int64) bool) { - getOffsets(r.Start, r.End(), blockSize)(yield) - } -} - -func getOffsets(start, end int64, blockSize int64) iter.Seq[int64] { - return func(yield func(offset int64) bool) { - for off := start; off < end; off += blockSize { - if !yield(off) { - return - } - } - } -} - // NewRange creates a new range from a start address and size in bytes. func NewRange(start int64, size int64) Range { return Range{ diff --git a/packages/orchestrator/internal/sandbox/block/range_test.go b/packages/orchestrator/internal/sandbox/block/range_test.go index ca0650fde1..5c766d6021 100644 --- a/packages/orchestrator/internal/sandbox/block/range_test.go +++ b/packages/orchestrator/internal/sandbox/block/range_test.go @@ -1,6 +1,7 @@ package block import ( + "iter" "slices" "testing" @@ -9,6 +10,24 @@ import ( "github.com/stretchr/testify/require" ) +// rangeOffsets returns the block offsets contained in the range. +// This assumes the Range.Start is a multiple of the blockSize. 
+func rangeOffsets(r *Range, blockSize int64) iter.Seq[int64] { + return func(yield func(offset int64) bool) { + getOffsets(r.Start, r.End(), blockSize)(yield) + } +} + +func getOffsets(start, end int64, blockSize int64) iter.Seq[int64] { + return func(yield func(offset int64) bool) { + for off := start; off < end; off += blockSize { + if !yield(off) { + return + } + } + } +} + func TestRange_End(t *testing.T) { tests := []struct { name string @@ -230,7 +249,7 @@ func TestRange_Offsets(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offsets := slices.Collect(tt.range_.Offsets(tt.blockSize)) + offsets := slices.Collect(rangeOffsets(&tt.range_, tt.blockSize)) if len(tt.expected) == 0 { assert.Empty(t, offsets) } else { @@ -249,7 +268,7 @@ func TestRange_Offsets_Iteration(t *testing.T) { blockSize := int64(4096) var collected []int64 - for offset := range r.Offsets(blockSize) { + for offset := range rangeOffsets(&r, blockSize) { collected = append(collected, offset) if len(collected) >= 3 { break @@ -462,7 +481,7 @@ func TestRange_Offsets_EdgeCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offsets := slices.Collect(tt.range_.Offsets(tt.blockSize)) + offsets := slices.Collect(rangeOffsets(&tt.range_, tt.blockSize)) assert.Equal(t, tt.expected, offsets) }) } diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go index e0408acd0f..dc0c74e853 100644 --- a/packages/orchestrator/internal/sandbox/block/tracker.go +++ b/packages/orchestrator/internal/sandbox/block/tracker.go @@ -39,19 +39,6 @@ func (t *Tracker) Has(off int64) bool { return t.b.Test(uint(header.BlockIdx(off, t.blockSize))) } -func (t *Tracker) HasOffsets(off, length int64) bool { - t.mu.RLock() - defer t.mu.RUnlock() - - for off := range getOffsets(off, off+length, t.blockSize) { - if !t.b.Test(uint(header.BlockIdx(off, t.blockSize))) { - return false - } - } - - 
return true -} - func (t *Tracker) Add(off int64) { t.mu.Lock() defer t.mu.Unlock() @@ -59,15 +46,6 @@ func (t *Tracker) Add(off int64) { t.b.Set(uint(header.BlockIdx(off, t.blockSize))) } -func (t *Tracker) AddOffsets(off, length int64) { - t.mu.Lock() - defer t.mu.Unlock() - - for off := range getOffsets(off, off+length, t.blockSize) { - t.b.Set(uint(header.BlockIdx(off, t.blockSize))) - } -} - func (t *Tracker) Reset() { t.mu.Lock() defer t.mu.Unlock() From 1e12ae76511ac7d4688eddc4e84f5d2e795a2c71 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 04:01:47 -0800 Subject: [PATCH 28/40] Remove unused --- .../internal/sandbox/uffd/memory/region.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go index 3de0402502..821f642471 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/region.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -1,7 +1,5 @@ package memory -import "iter" - // Region is a mapping of a region of memory of the guest to a region of memory on the host. // The serialization is based on the Firecracker UFFD protocol communication. 
// https://github.com/firecracker-microvm/firecracker/blob/ceeca6a14284537ae0b2a192cd2ffef10d3a81e2/src/vmm/src/persist.rs#L96 @@ -34,13 +32,3 @@ func (r *Region) shiftedOffset(addr uintptr) int64 { func (r *Region) shiftedHostVirtAddr(off int64) uintptr { return uintptr(off) + r.BaseHostVirtAddr - r.Offset } - -func (r *Region) Offsets() iter.Seq[int64] { - return func(yield func(offset int64) bool) { - for i := int64(r.Offset); i < r.endOffset(); i += int64(r.PageSize) { - if !yield(i) { - return - } - } - } -} From 4fe3b23b8b7856e28d14a514bd52d7f4e687c0f7 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 06:38:41 -0800 Subject: [PATCH 29/40] Refactor --- .../actions/build-sandbox-template/action.yml | 4 +- .../internal/sandbox/block/cache.go | 46 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/.github/actions/build-sandbox-template/action.yml b/.github/actions/build-sandbox-template/action.yml index 7f5dcd4c53..92bdbdd8a8 100644 --- a/.github/actions/build-sandbox-template/action.yml +++ b/.github/actions/build-sandbox-template/action.yml @@ -7,8 +7,8 @@ runs: - name: Build Sandbox Template env: TEMPLATE_ID: "2j6ly824owf4awgai1xo" - KERNEL_VERSION: "vmlinux-6.1.102" - FIRECRACKER_VERSION: "v1.12.1_d990331" + KERNEL_VERSION: "vmlinux-6.1.158" + FIRECRACKER_VERSION: "v1.12.1_717921c" run: | # Generate an unique build ID for the template for this run export BUILD_ID=$(uuidgen) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index 8548971ec8..debce091d4 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -14,6 +14,7 @@ import ( "syscall" "time" + "github.com/bits-and-blooms/bitset" "github.com/edsrzf/mmap-go" "go.opentelemetry.io/otel" "golang.org/x/sys/unix" @@ -62,6 +63,15 @@ func NewCache(size, blockSize int64, filePath string, dirtyFile bool) (*Cache, e defer 
f.Close() + if size == 0 { + return &Cache{ + filePath: filePath, + size: size, + blockSize: blockSize, + dirtyFile: dirtyFile, + }, nil + } + // This should create a sparse file on Linux. err = f.Truncate(size) if err != nil { @@ -98,6 +108,10 @@ func (c *Cache) Sync() error { return NewErrCacheClosed(c.filePath) } + if c.mmap == nil { + return nil + } + err := c.mmap.Flush() if err != nil { return fmt.Errorf("error syncing cache: %w", err) @@ -117,6 +131,14 @@ func (c *Cache) ExportToDiff(ctx context.Context, out io.Writer) (*header.DiffMe return nil, NewErrCacheClosed(c.filePath) } + if c.mmap == nil { + return &header.DiffMetadata{ + Dirty: bitset.New(0), + Empty: bitset.New(0), + BlockSize: c.blockSize, + }, nil + } + err := c.mmap.Flush() if err != nil { return nil, fmt.Errorf("error flushing mmap: %w", err) @@ -140,6 +162,10 @@ func (c *Cache) ReadAt(b []byte, off int64) (int, error) { c.mu.RLock() defer c.mu.RUnlock() + if c.mmap == nil { + return 0, nil + } + if c.isClosed() { return 0, NewErrCacheClosed(c.filePath) } @@ -156,6 +182,10 @@ func (c *Cache) WriteAt(b []byte, off int64) (int, error) { c.mu.Lock() defer c.mu.Unlock() + if c.mmap == nil { + return 0, nil + } + if c.isClosed() { return 0, NewErrCacheClosed(c.filePath) } @@ -167,6 +197,10 @@ func (c *Cache) Close() (e error) { c.mu.Lock() defer c.mu.Unlock() + if c.mmap == nil { + return nil + } + succ := c.closed.CompareAndSwap(false, true) if !succ { return NewErrCacheClosed(c.filePath) @@ -198,6 +232,10 @@ func (c *Cache) Slice(off, length int64) ([]byte, error) { return nil, NewErrCacheClosed(c.filePath) } + if c.mmap == nil { + return nil, nil + } + if c.dirtyFile || c.isCached(off, length) { end := min(off+length, c.size) @@ -230,6 +268,10 @@ func (c *Cache) WriteAtWithoutLock(b []byte, off int64) (int, error) { return 0, NewErrCacheClosed(c.filePath) } + if c.mmap == nil { + return 0, nil + } + end := min(off+int64(len(b)), c.size) n := copy((*c.mmap)[off:end], b) @@ -272,6 +314,10 @@ func 
(c *Cache) FileSize() (int64, error) { } func (c *Cache) address(off int64) *byte { + if c.mmap == nil { + return nil + } + return &(*c.mmap)[off] } From 04f6e8c5eb3b0c6bc11d379938cc3907b13d6793 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 06:50:43 -0800 Subject: [PATCH 30/40] Cleanup --- .../block/cache_copyfromprocess_test.go | 91 ++++++++++++------- .../orchestrator/internal/sandbox/sandbox.go | 4 +- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go index b1c15f6502..f4ddcec6a9 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go @@ -2,6 +2,7 @@ package block import ( "context" + "crypto/rand" "encoding/binary" "fmt" "os" @@ -45,9 +46,12 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { } defer unix.Munmap(mem) - // Fill memory with a pattern: each byte is its offset modulo 256 - for i := range mem { - mem[i] = byte(i % 256) + // Fill memory with random data + _, err = rand.Read(mem) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to generate random data: %v\n", err) + + panic(err) } // Write the memory address to stdout (as 8 bytes, little endian) @@ -62,6 +66,14 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { panic(err) } + // Write the random data to stdout so the test can verify it + _, err = os.Stdout.Write(mem) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to write random data: %v\n", err) + + panic(err) + } + // Signal ready by closing stdout os.Stdout.Close() @@ -108,6 +120,11 @@ func TestCopyFromProcess_Success(t *testing.T) { var addr uint64 err = binary.Read(stdout, binary.LittleEndian, &addr) require.NoError(t, err) + + // Read the random data that was written to memory + expectedData := make([]byte, size) + _, err = 
stdout.Read(expectedData) + require.NoError(t, err) stdout.Close() // Wait a bit for the process to be ready @@ -131,11 +148,8 @@ func TestCopyFromProcess_Success(t *testing.T) { require.NoError(t, err) require.Equal(t, int(size), n) - // Verify pattern: each byte should be its offset modulo 256 - for i := range data { - expected := byte(i % 256) - assert.Equal(t, expected, data[i], "byte at offset %d should be %d, got %d", i, expected, data[i]) - } + // Verify the data matches the random data exactly + assert.Equal(t, expectedData, data, "copied data should match the original random data") } func TestCopyFromProcess_MultipleRanges(t *testing.T) { @@ -166,16 +180,22 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { var baseAddr uint64 err = binary.Read(stdout, binary.LittleEndian, &baseAddr) require.NoError(t, err) + + // Read the random data that was written to memory + expectedData := make([]byte, totalSize) + _, err = stdout.Read(expectedData) + require.NoError(t, err) stdout.Close() // Wait a bit for the process to be ready time.Sleep(100 * time.Millisecond) // Test copying multiple non-contiguous ranges + // Order: 0th, 2nd, then 1st segment in memory ranges := []Range{ - {Start: int64(baseAddr), Size: int64(segmentSize)}, - {Start: int64(baseAddr + segmentSize*2), Size: int64(segmentSize)}, - {Start: int64(baseAddr + segmentSize), Size: int64(segmentSize)}, + {Start: int64(baseAddr), Size: int64(segmentSize)}, // 0th segment + {Start: int64(baseAddr + segmentSize*2), Size: int64(segmentSize)}, // 2nd segment + {Start: int64(baseAddr + segmentSize), Size: int64(segmentSize)}, // 1st segment } tmpFile := t.TempDir() + "/cache" @@ -183,35 +203,29 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { require.NoError(t, err) defer cache.Close() - // Verify the first segment (at cache offset 0) + // Verify the first segment (at cache offset 0): should be from source baseAddr (0th segment of process) data1 := make([]byte, segmentSize) n, err := 
cache.ReadAt(data1, 0) require.NoError(t, err) require.Equal(t, int(segmentSize), n) - for i := range data1 { - expected := byte(i % 256) - assert.Equal(t, expected, data1[i], "first segment, byte at offset %d", i) - } + expected1 := expectedData[0:segmentSize] + assert.Equal(t, expected1, data1, "first segment should match original random data") - // Verify the second segment (copied to offset segmentSize*2 in cache) + // Verify the second segment (at cache offset segmentSize): should be from source baseAddr+segmentSize*2 (2nd segment of process) data2 := make([]byte, segmentSize) - n, err = cache.ReadAt(data2, int64(segmentSize*2)) + n, err = cache.ReadAt(data2, int64(segmentSize)) require.NoError(t, err) require.Equal(t, int(segmentSize), n) - for i := range data2 { - expected := byte((int(segmentSize*2) + i) % 256) - assert.Equal(t, expected, data2[i], "second segment, byte at offset %d", i) - } + expected2 := expectedData[segmentSize*2 : segmentSize*3] + assert.Equal(t, expected2, data2, "second segment should match original random data") - // Verify the third segment (copied to offset segmentSize in cache) + // Verify the third segment (at cache offset segmentSize*2): should be from source baseAddr+segmentSize (1st segment of process) data3 := make([]byte, segmentSize) - n, err = cache.ReadAt(data3, int64(segmentSize)) + n, err = cache.ReadAt(data3, int64(segmentSize*2)) require.NoError(t, err) require.Equal(t, int(segmentSize), n) - for i := range data3 { - expected := byte((int(segmentSize) + i) % 256) - assert.Equal(t, expected, data3[i], "third segment, byte at offset %d", i) - } + expected3 := expectedData[segmentSize : segmentSize*2] + assert.Equal(t, expected3, data3, "third segment should match original random data") } func TestCopyFromProcess_ContextCancellation(t *testing.T) { @@ -241,6 +255,11 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { var addr uint64 err = binary.Read(stdout, binary.LittleEndian, &addr) require.NoError(t, err) 
+ + // Read the random data (even though we won't use it, we need to consume it from stdout) + expectedData := make([]byte, size) + _, err = stdout.Read(expectedData) + require.NoError(t, err) stdout.Close() // Wait a bit for the process to be ready @@ -256,8 +275,7 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { tmpFile := t.TempDir() + "/cache" _, err = NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, cmd.Process.Pid, ranges) - require.Error(t, err) - assert.Equal(t, context.Canceled, err) + require.ErrorIs(t, err, context.Canceled) } func TestCopyFromProcess_InvalidPID(t *testing.T) { @@ -327,6 +345,11 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { var baseAddr uint64 err = binary.Read(stdout, binary.LittleEndian, &baseAddr) require.NoError(t, err) + + // Read the random data that was written to memory + expectedData := make([]byte, totalSize) + _, err = stdout.Read(expectedData) + require.NoError(t, err) stdout.Close() // Wait a bit for the process to be ready @@ -360,15 +383,15 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { // Read a full page to ensure we get the data data := make([]byte, header.PageSize) - fmt.Println("reading at aligned offset", alignedOffset, "with offset in block", offsetInBlock) + n, err := cache.ReadAt(data, alignedOffset) require.NoError(t, err) require.Equal(t, int(header.PageSize), n) - // Verify pattern for the range we're checking + // Verify the range we're checking matches the expected random data for j := range rangeSize { - expected := byte((actualOffset + int64(j)) % 256) - assert.Equal(t, expected, data[offsetInBlock+int64(j)], "range %d, byte at offset %d", i, j) + expectedByte := expectedData[actualOffset+int64(j)] + assert.Equal(t, expectedByte, data[offsetInBlock+int64(j)], "range %d, byte at offset %d", i, j) } } } diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 7c7cc4efbf..eba221b61d 100644 --- 
a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -828,13 +828,13 @@ func pauseProcessMemory( ctx, span := tracer.Start(ctx, "process-memory") defer span.End() - memfileDiffPath := build.GenerateDiffCachePath(cacheDir, buildID.String(), build.Memfile) - header, err := diffMetadata.ToDiffHeader(ctx, originalHeader, buildID) if err != nil { return nil, nil, fmt.Errorf("failed to create memfile header: %w", err) } + memfileDiffPath := build.GenerateDiffCachePath(cacheDir, buildID.String(), build.Memfile) + cache, err := fc.ExportMemory( ctx, diffMetadata.Dirty, From 7797ff90e3a32a56ce0af38cbe0a62093c351ac7 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 06:57:57 -0800 Subject: [PATCH 31/40] Describe sandbox pause process --- packages/orchestrator/internal/sandbox/sandbox.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index eba221b61d..3b927125ff 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -709,6 +709,19 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { return nil } + +// Pause creates a snapshot of the sandbox. +// +// +// Currently the memory snapshotting works like this: +// 1. We pause FC VM +// 2. We call FC snapshot endpoint without specifying memfile path. With our custom FC, +// with out custom Fc, this only creates the snapfile and drains and flushes the disk. +// 3. We call custom FC endpoint that returns memory addresses of the sandbox memory, that we will process after. +// 4. In case of NoopMemory (the sandbox was not a resume) we also call the custom FC endpoint, +// that returns info about resident memory pages and about empty memory pages. +// 5. Base on the info from the custom FC endpoint or from Uffd we copy the pages directly from the FC process to a local cache. 
+// 6. We then can either close the sandbox or resume it. func (s *Sandbox) Pause( ctx context.Context, m metadata.Template, From 75c965f167f4e9adcc28990549f8a5bb5893fa9e Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 06:59:17 -0800 Subject: [PATCH 32/40] Add remote to nil cache --- packages/orchestrator/internal/sandbox/block/cache.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache.go b/packages/orchestrator/internal/sandbox/block/cache.go index debce091d4..3bbf143a50 100644 --- a/packages/orchestrator/internal/sandbox/block/cache.go +++ b/packages/orchestrator/internal/sandbox/block/cache.go @@ -198,7 +198,7 @@ func (c *Cache) Close() (e error) { defer c.mu.Unlock() if c.mmap == nil { - return nil + return os.RemoveAll(c.filePath) } succ := c.closed.CompareAndSwap(false, true) From 987603922b7a0139a41d3eae84274e03be286bb0 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 07:03:23 -0800 Subject: [PATCH 33/40] Fix format --- .../orchestrator/internal/sandbox/sandbox.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 3b927125ff..c743e39daa 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -709,19 +709,17 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { return nil } - // Pause creates a snapshot of the sandbox. // -// // Currently the memory snapshotting works like this: -// 1. We pause FC VM -// 2. We call FC snapshot endpoint without specifying memfile path. With our custom FC, -// with out custom Fc, this only creates the snapfile and drains and flushes the disk. -// 3. We call custom FC endpoint that returns memory addresses of the sandbox memory, that we will process after. -// 4. 
In case of NoopMemory (the sandbox was not a resume) we also call the custom FC endpoint, -// that returns info about resident memory pages and about empty memory pages. -// 5. Base on the info from the custom FC endpoint or from Uffd we copy the pages directly from the FC process to a local cache. -// 6. We then can either close the sandbox or resume it. +// 1. We pause FC VM +// 2. We call FC snapshot endpoint without specifying memfile path. With our custom FC, +// with out custom Fc, this only creates the snapfile and drains and flushes the disk. +// 3. We call custom FC endpoint that returns memory addresses of the sandbox memory, that we will process after. +// 4. In case of NoopMemory (the sandbox was not a resume) we also call the custom FC endpoint, +// that returns info about resident memory pages and about empty memory pages. +// 5. Base on the info from the custom FC endpoint or from Uffd we copy the pages directly from the FC process to a local cache. +// 6. We then can either close the sandbox or resume it. func (s *Sandbox) Pause( ctx context.Context, m metadata.Template, From 02190a6bfdf874152739be309d50a4b82f64e98c Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 07:06:25 -0800 Subject: [PATCH 34/40] Fix comment --- packages/orchestrator/internal/sandbox/sandbox.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index c743e39daa..802a97367e 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -714,7 +714,7 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { // Currently the memory snapshotting works like this: // 1. We pause FC VM // 2. We call FC snapshot endpoint without specifying memfile path. With our custom FC, -// with out custom Fc, this only creates the snapfile and drains and flushes the disk. 
+// this only creates the snapfile and drains and flushes the disk. // 3. We call custom FC endpoint that returns memory addresses of the sandbox memory, that we will process after. // 4. In case of NoopMemory (the sandbox was not a resume) we also call the custom FC endpoint, // that returns info about resident memory pages and about empty memory pages. From 518e4ee9aee4880bcf79af4843410c1431e0e35e Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 07:51:59 -0800 Subject: [PATCH 35/40] Make test less flaky --- .../block/cache_copyfromprocess_test.go | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go index f4ddcec6a9..8186568144 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_copyfromprocess_test.go @@ -93,6 +93,23 @@ func TestCopyFromProcess_HelperProcess(t *testing.T) { } } +// waitForProcess consistently checks for process existence until it is confirmed or timeout is reached. +func waitForProcess(pid int, maxWait time.Duration) error { + start := time.Now() + for { + // Try sending signal 0 to the process: if it exists, this succeeds (unless permission is denied). 
+ err := syscall.Kill(pid, 0) + if err == nil || err == syscall.EPERM { + // Process exists (but maybe we lack permission, still good enough for tests) + return nil + } + if time.Since(start) > maxWait { + return fmt.Errorf("process %d not available after %.2fs", pid, maxWait.Seconds()) + } + time.Sleep(10 * time.Millisecond) + } +} + func TestCopyFromProcess_Success(t *testing.T) { t.Parallel() @@ -127,8 +144,8 @@ func TestCopyFromProcess_Success(t *testing.T) { require.NoError(t, err) stdout.Close() - // Wait a bit for the process to be ready - time.Sleep(100 * time.Millisecond) + // Wait until the process is up and running before copying its memory. + require.NoError(t, waitForProcess(cmd.Process.Pid, 2*time.Second)) // Test copying a single range ranges := []Range{ @@ -187,8 +204,7 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { require.NoError(t, err) stdout.Close() - // Wait a bit for the process to be ready - time.Sleep(100 * time.Millisecond) + require.NoError(t, waitForProcess(cmd.Process.Pid, 2*time.Second)) // Test copying multiple non-contiguous ranges // Order: 0th, 2nd, then 1st segment in memory @@ -262,8 +278,7 @@ func TestCopyFromProcess_ContextCancellation(t *testing.T) { require.NoError(t, err) stdout.Close() - // Wait a bit for the process to be ready - time.Sleep(100 * time.Millisecond) + require.NoError(t, waitForProcess(cmd.Process.Pid, 2*time.Second)) // Cancel context immediately cancel() @@ -290,6 +305,10 @@ func TestCopyFromProcess_InvalidPID(t *testing.T) { } tmpFile := t.TempDir() + "/cache" + + // Add a wait here, but since the PID doesn't exist, it will timeout (this is fine for this test). 
+ _ = waitForProcess(invalidPID, 10*time.Millisecond) + _, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, invalidPID, ranges) require.Error(t, err) assert.Contains(t, err.Error(), "failed to read memory") @@ -307,6 +326,9 @@ func TestCopyFromProcess_InvalidAddress(t *testing.T) { } tmpFile := t.TempDir() + "/cache" + + require.NoError(t, waitForProcess(os.Getpid(), 2*time.Second)) // Make sure our process is alive + _, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) require.Error(t, err) assert.Contains(t, err.Error(), "failed to read memory") @@ -352,8 +374,7 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { require.NoError(t, err) stdout.Close() - // Wait a bit for the process to be ready - time.Sleep(100 * time.Millisecond) + require.NoError(t, waitForProcess(cmd.Process.Pid, 2*time.Second)) // Create many small ranges that exceed IOV_MAX ranges := make([]Range, numRanges) @@ -403,6 +424,7 @@ func TestEmptyRanges(t *testing.T) { ranges := []Range{} tmpFile := t.TempDir() + "/cache" + require.NoError(t, waitForProcess(os.Getpid(), 2*time.Second)) // Make sure our process is alive c, err := NewCacheFromProcessMemory(ctx, header.PageSize, tmpFile, os.Getpid(), ranges) require.NoError(t, err) From 6fc9d9639f67be355ce14eabb04243b776b1ffc7 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 08:08:54 -0800 Subject: [PATCH 36/40] Remove unused env var --- packages/orchestrator/Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/orchestrator/Makefile b/packages/orchestrator/Makefile index 78afc911cf..9d76ae49d8 100644 --- a/packages/orchestrator/Makefile +++ b/packages/orchestrator/Makefile @@ -49,7 +49,6 @@ run-debug: GCP_DOCKER_REPOSITORY_NAME=$(GCP_DOCKER_REPOSITORY_NAME) \ GOOGLE_SERVICE_ACCOUNT_BASE64=$(GOOGLE_SERVICE_ACCOUNT_BASE64) \ OTEL_COLLECTOR_GRPC_ENDPOINT=$(OTEL_COLLECTOR_GRPC_ENDPOINT) \ - MAX_PARALLEL_MEMFILE_SNAPSHOTTING=$(MAX_PARALLEL_MEMFILE_SNAPSHOTTING) \ 
./bin/orchestrator define setup_local_env From f7ff3b26e71d4c15ab4efa7ef8d45798d2aa1ea0 Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 08:09:10 -0800 Subject: [PATCH 37/40] Remove unused env var --- .github/actions/start-services/action.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/actions/start-services/action.yml b/.github/actions/start-services/action.yml index e2c01cd3b6..5365085344 100644 --- a/.github/actions/start-services/action.yml +++ b/.github/actions/start-services/action.yml @@ -101,7 +101,6 @@ runs: STORAGE_PROVIDER: "Local" ENVIRONMENT: "local" OTEL_COLLECTOR_GRPC_ENDPOINT: "localhost:4317" - MAX_PARALLEL_MEMFILE_SNAPSHOTTING: "2" SHARED_CHUNK_CACHE_PATH: "./.e2b-chunk-cache" EDGE_TOKEN: "abdcdefghijklmnop" run: | From 44414c224daf3769006d14f6675665809ea0a3bc Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 13:17:07 -0800 Subject: [PATCH 38/40] WIP --- .../internal/sandbox/uffd/fdexit/fdexit.go | 2 + .../sandbox/uffd/testutils/diff_byte.go | 29 +- .../internal/sandbox/uffd/uffd.go | 5 +- .../userfaultfd/cross_process_helpers_test.go | 171 +++- .../internal/sandbox/uffd/userfaultfd/fd.go | 37 +- .../uffd/userfaultfd/fd_helpers_test.go | 82 ++ .../sandbox/uffd/userfaultfd/helpers_test.go | 36 +- .../sandbox/uffd/userfaultfd/missing_test.go | 76 +- .../uffd/userfaultfd/missing_write_test.go | 186 +++-- .../sandbox/uffd/userfaultfd/userfaultfd.go | 133 ++- .../uffd/userfaultfd/userfaultfd_test.go | 771 ++++++++++++++++++ 11 files changed, 1375 insertions(+), 153 deletions(-) create mode 100644 packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go diff --git a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go index 688cff5331..f4b20ad5fe 100644 --- a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go +++ b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go @@ -7,6 +7,8 @@ import ( "sync" ) 
+var ErrFdExit = errors.New("fd exit signal") + // FdExit is a wrapper around a pipe that allows to signal the exit of the uffd. type FdExit struct { r *os.File diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go index 68298ea6ea..1acc10a8ed 100644 --- a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go @@ -1,20 +1,31 @@ package testutils +import ( + "errors" + "fmt" +) + // FirstDifferentByte returns the first byte index where a and b differ. // It also returns the differing byte values (want, got). // If slices are identical, it returns idx -1. -func FirstDifferentByte(a, b []byte) (idx int, want, got byte) { - smallerSize := min(len(a), len(b)) +func ErrorFromByteSlicesDifference(expected, actual []byte) error { + var errs []error - for i := range smallerSize { - if a[i] != b[i] { - return i, b[i], a[i] - } + if len(expected) > len(actual) { + errs = append(errs, fmt.Errorf("expected slice (%d bytes) is longer than actual slice (%d bytes)", len(expected), len(actual))) + } else if len(expected) < len(actual) { + errs = append(errs, fmt.Errorf("actual slice (%d bytes) is longer than expected slice (%d bytes)", len(actual), len(expected))) } - if len(a) != len(b) { - return smallerSize, 0, 0 + smallerSize := min(len(expected), len(actual)) + + for i := range smallerSize { + if expected[i] != actual[i] { + errs = append(errs, fmt.Errorf("first different byte: want '%x', got '%x' at index %d", expected[i], actual[i], i)) + + break + } } - return -1, 0, 0 + return errors.Join(errs...) 
} diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index c44c9795e3..9383d1abe0 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -140,7 +140,7 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { m := memory.NewMapping(regions) uffd, err := userfaultfd.NewUserfaultfdFromFd( - uintptr(fds[0]), + userfaultfd.Fd(fds[0]), u.memfile, m, logger.L().With(logger.WithSandboxID(sandboxId)), @@ -164,6 +164,9 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { ctx, u.fdExit, ) + if errors.Is(err, fdexit.ErrFdExit) { + return nil + } if err != nil { return fmt.Errorf("failed handling uffd: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index 3921ee429c..ecb27bd2c2 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -16,7 +16,6 @@ import ( "os/exec" "os/signal" "strconv" - "strings" "syscall" "testing" @@ -53,10 +52,12 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) - err = register(uffdFd, memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) + err = register(Fd(uffdFd), memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) - cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess") + // We don't use t.Context() here, because we want to be able to kill the process manually and listen to the correct exit code, + // while also handling the cleanup of the uffd. The t.Context seems to trigger before the test cleanup is started. 
+ cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperServingProcess") cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) @@ -81,11 +82,18 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error assert.NoError(t, closeErr) }() - offsetsReader, offsetsWriter, err := os.Pipe() + accessedOffsetsReader, accessedOffsetsWriter, err := os.Pipe() require.NoError(t, err) t.Cleanup(func() { - offsetsReader.Close() + accessedOffsetsReader.Close() + }) + + dirtyOffsetsReader, dirtyOffsetsWriter, err := os.Pipe() + require.NoError(t, err) + + t.Cleanup(func() { + dirtyOffsetsReader.Close() }) readyReader, readyWriter, err := os.Pipe() @@ -106,8 +114,9 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error cmd.ExtraFiles = []*os.File{ uffdFile, contentReader, - offsetsWriter, + accessedOffsetsWriter, readyWriter, + dirtyOffsetsWriter, } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -116,36 +125,59 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error require.NoError(t, err) contentReader.Close() - offsetsWriter.Close() + accessedOffsetsWriter.Close() readyWriter.Close() uffdFile.Close() + dirtyOffsetsWriter.Close() + + go func() { + waitErr := cmd.Wait() + assert.NoError(t, waitErr) + + assert.NotEqual(t, -1, cmd.ProcessState.ExitCode(), "process was not terminated gracefully") + assert.NotEqual(t, 2, cmd.ProcessState.ExitCode(), "fd exit prematurely terminated the serve loop") + assert.NotEqual(t, 1, cmd.ProcessState.ExitCode(), "process exited with unexpected exit code") + + assert.Equal(t, 0, cmd.ProcessState.ExitCode()) + }() t.Cleanup(func() { - signalErr := cmd.Process.Signal(syscall.SIGUSR1) + // We are using SIGHUP to actually get exit code, not -1. 
+ signalErr := cmd.Process.Signal(syscall.SIGTERM) assert.NoError(t, signalErr) - - waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) }) - offsetsOnce := func() ([]uint, error) { + accessedOffsetsOnce := func() ([]uint, error) { err := cmd.Process.Signal(syscall.SIGUSR2) if err != nil { return nil, err } - offsetsBytes, err := io.ReadAll(offsetsReader) + offsetsBytes, err := io.ReadAll(accessedOffsetsReader) + if err != nil { + return nil, err + } + + var offsetList []uint + + if len(offsetsBytes)%8 != 0 { + return nil, fmt.Errorf("invalid offsets bytes length: %d", len(offsetsBytes)) + } + + for i := 0; i < len(offsetsBytes); i += 8 { + offsetList = append(offsetList, uint(binary.LittleEndian.Uint64(offsetsBytes[i:i+8]))) + } + + return offsetList, nil + } + + dirtyOffsetsOnce := func() ([]uint, error) { + err := cmd.Process.Signal(syscall.SIGUSR1) + if err != nil { + return nil, err + } + + offsetsBytes, err := io.ReadAll(dirtyOffsetsReader) if err != nil { return nil, err } @@ -169,23 +201,38 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error case <-readySignal: } + mapping := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: 0, + PageSize: uintptr(tt.pagesize), + }, + }) + return &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - offsetsOnce: offsetsOnce, + memoryArea: &memoryArea, + pagesize: tt.pagesize, + data: data, + accessedOffsetsOnce: accessedOffsetsOnce, + mapping: mapping, + dirtyOffsetsOnce: dirtyOffsetsOnce, }, nil } -// Secondary process, orchestrator in our 
case +// Secondary process, orchestrator in our case. func TestHelperServingProcess(t *testing.T) { if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { t.Skip("this is a helper process, skipping direct execution") } err := crossProcessServe() + if errors.Is(err, fdexit.ErrFdExit) { + os.Exit(2) + } + if err != nil { - fmt.Println("exit serving process", err) + fmt.Fprintf(os.Stderr, "error serving: %v", err) os.Exit(1) } @@ -240,29 +287,61 @@ func crossProcessServe() error { return fmt.Errorf("exit creating logger: %w", err) } - uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l) + uffd, err := NewUserfaultfdFromFd(Fd(uffdFd), data, m, l) if err != nil { return fmt.Errorf("exit creating uffd: %w", err) } - offsetsFile := os.NewFile(uintptr(5), "offsets") + accessedOffsetsFile := os.NewFile(uintptr(5), "accessed-offsets") - offsetsSignal := make(chan os.Signal, 1) - signal.Notify(offsetsSignal, syscall.SIGUSR2) - defer signal.Stop(offsetsSignal) + accessedOffsestsSignal := make(chan os.Signal, 1) + signal.Notify(accessedOffsestsSignal, syscall.SIGUSR2) + defer signal.Stop(accessedOffsestsSignal) go func() { - defer offsetsFile.Close() + defer accessedOffsetsFile.Close() for { select { case <-ctx.Done(): return - case <-offsetsSignal: + case <-accessedOffsestsSignal: + for offset := range accessed(uffd).Offsets() { + writeErr := binary.Write(accessedOffsetsFile, binary.LittleEndian, uint64(offset)) + if writeErr != nil { + msg := fmt.Errorf("error writing accessed offsets to file: %w", writeErr) + + fmt.Fprint(os.Stderr, msg.Error()) + + cancel(msg) + + return + } + } + + return + } + } + }() + + dirtyOffsetsFile := os.NewFile(uintptr(7), "dirty-offsets") + + dirtyOffsetsSignal := make(chan os.Signal, 1) + signal.Notify(dirtyOffsetsSignal, syscall.SIGUSR1) + defer signal.Stop(dirtyOffsetsSignal) + + go func() { + defer dirtyOffsetsFile.Close() + + for { + select { + case <-ctx.Done(): + return + case <-dirtyOffsetsSignal: for offset := range uffd.Dirty().Offsets() { - 
writeErr := binary.Write(offsetsFile, binary.LittleEndian, uint64(offset)) + writeErr := binary.Write(dirtyOffsetsFile, binary.LittleEndian, uint64(offset)) if writeErr != nil { - msg := fmt.Errorf("error writing offsets to file: %w", writeErr) + msg := fmt.Errorf("error writing dirty offsets to file: %w", writeErr) fmt.Fprint(os.Stderr, msg.Error()) @@ -289,6 +368,14 @@ func crossProcessServe() error { }() serverErr := uffd.Serve(ctx, fdExit) + if errors.Is(serverErr, fdexit.ErrFdExit) { + err := fmt.Errorf("serving finished via fd exit: %w", serverErr) + + cancel(err) + + return + } + if serverErr != nil { msg := fmt.Errorf("error serving: %w", serverErr) @@ -298,6 +385,8 @@ func crossProcessServe() error { return } + + fmt.Fprint(os.Stderr, "serving finished") }() cleanup := func() { @@ -318,7 +407,7 @@ func crossProcessServe() error { defer cleanup() exitSignal := make(chan os.Signal, 1) - signal.Notify(exitSignal, syscall.SIGUSR1) + signal.Notify(exitSignal, syscall.SIGTERM) defer signal.Stop(exitSignal) readyFile := os.NewFile(uintptr(6), "ready") @@ -330,7 +419,7 @@ func crossProcessServe() error { select { case <-ctx.Done(): - return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx)) + return context.Cause(ctx) case <-exitSignal: return nil } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 1a122d01d5..c25c05af26 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -32,12 +32,19 @@ const ( UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING + UFFDIO_REGISTER_MODE_WP = C.UFFDIO_REGISTER_MODE_WP - UFFDIO_API = C.UFFDIO_API - UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT_MODE_WP = C.UFFDIO_WRITEPROTECT_MODE_WP + UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP + + 
UFFDIO_API = C.UFFDIO_API + UFFDIO_REGISTER = C.UFFDIO_REGISTER + UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER + UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE + UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS ) @@ -78,6 +85,13 @@ func newUffdioRegister(start, length, mode CULong) UffdioRegister { } } +func newUffdioWriteProtect(start, length, mode CULong) UffdioWriteProtect { + return UffdioWriteProtect{ + _range: newUffdioRange(start, length), + mode: mode, + } +} + func newUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { return UffdioCopy{ src: CULong(uintptr(unsafe.Pointer(&b[0]))), @@ -120,6 +134,23 @@ func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { return nil } +// mode: UFFDIO_WRITEPROTECT_MODE_WP +// Passing 0 as the mode will remove the write protection. +func (f Fd) writeProtect(addr, size uintptr, mode CULong) error { + register := newUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + func (f Fd) close() error { return syscall.Close(int(f)) } + +func (f Fd) fd() int32 { + return int32(f) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go index f9f8d953cb..33036f346e 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -8,6 +8,71 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// mockFd is a mock implementation of the Fd interface. 
+// It allows us to test the handling methods separately from the actual uffd serve loop. +type mockFd struct { + // The channels send back the info about the uffd handled operations + // and also allows us to block the methods to test the flow. + copyCh chan *blockedEvent[UffdioCopy] + writeProtectCh chan *blockedEvent[UffdioWriteProtect] +} + +func newMockFd() *mockFd { + return &mockFd{ + copyCh: make(chan *blockedEvent[UffdioCopy]), + writeProtectCh: make(chan *blockedEvent[UffdioWriteProtect]), + } +} + +func (m *mockFd) register(_ uintptr, _ uint64, _ CULong) error { + return nil +} + +func (m *mockFd) unregister(_, _ uintptr) error { + return nil +} + +func (m *mockFd) copy(addr, pagesize uintptr, _ []byte, mode CULong) error { + // Don't use the uffdioCopy constructor as it unsafely checks slice address and fails for arbitrary pointer. + e := newBlockedEvent(UffdioCopy{ + src: 0, + dst: CULong(addr), + len: CULong(pagesize), + mode: mode, + copy: 0, + }) + + m.copyCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) writeProtect(addr, size uintptr, mode CULong) error { + e := newBlockedEvent(UffdioWriteProtect{ + _range: newUffdioRange( + CULong(addr), + CULong(size), + ), + mode: mode, + }) + + m.writeProtectCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) close() error { + return nil +} + +func (m *mockFd) fd() int32 { + return 0 +} + // Used for testing. // flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK func newFd(flags uintptr) (Fd, error) { @@ -52,3 +117,20 @@ func register(f Fd, addr uintptr, size uint64, mode CULong) error { return nil } + +// This wrapped event allows us to simulate the finish of processing of events by FC on FC API Pause. 
+type blockedEvent[T UffdioCopy | UffdioWriteProtect] struct { + event T + resolved chan struct{} +} + +func newBlockedEvent[T UffdioCopy | UffdioWriteProtect](event T) *blockedEvent[T] { + return &blockedEvent[T]{ + event: event, + resolved: make(chan struct{}), + } +} + +func (e *blockedEvent[T]) resolve() { + close(e.resolved) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go index 1a65664b86..8631457a49 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go @@ -9,6 +9,8 @@ import ( "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" ) @@ -41,8 +43,26 @@ type testHandler struct { data *testutils.MemorySlicer // Returns offsets of the pages that were faulted. // It can only be called once. - offsetsOnce func() ([]uint, error) - mutex sync.Mutex + // Sorted in ascending order. + accessedOffsetsOnce func() ([]uint, error) + // Returns offsets of the pages that were dirtied. + // It can only be called once. + // Sorted in ascending order. 
+ dirtyOffsetsOnce func() ([]uint, error) + + mutex sync.Mutex + mapping *memory.Mapping +} + +func (h *testHandler) executeOperation(ctx context.Context, op operation) error { + switch op.mode { + case operationModeRead: + return h.executeRead(ctx, op) + case operationModeWrite: + return h.executeWrite(ctx, op) + default: + return fmt.Errorf("invalid operation mode: %d", op.mode) + } } func (h *testHandler) executeRead(ctx context.Context, op operation) error { @@ -55,9 +75,7 @@ func (h *testHandler) executeRead(ctx context.Context, op operation) error { // The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here. if !bytes.Equal(readBytes, expectedBytes) { - idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) - - return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) + return fmt.Errorf("content mismatch: %w", testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes)) } return nil @@ -82,6 +100,7 @@ func (h *testHandler) executeWrite(ctx context.Context, op operation) error { } // Get a bitset of the offsets of the operations for the given mode. +// Sorted in ascending order. 
func getOperationsOffsets(ops []operation, m operationMode) []uint { b := bitset.New(0) @@ -93,3 +112,10 @@ func getOperationsOffsets(ops []operation, m operationMode) []uint { return slices.Collect(b.EachSet()) } + +func accessed(u *Userfaultfd) *block.Tracker { + u.settleRequests.Lock() + defer u.settleRequests.Unlock() + + return u.missingRequests.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go index 20c8ddeeb3..cf24e55d51 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go @@ -62,6 +62,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "standard 4k page, reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + }, + }, { name: "hugepage, operation at start", pagesize: header.HugepageSize, @@ -110,6 +137,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "hugepage, reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,15 +174,13 @@ func TestMissing(t *testing.T) { require.NoError(t, err) for _, 
operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -158,7 +210,7 @@ func TestParallelMissing(t *testing.T) { for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -167,7 +219,7 @@ func TestParallelMissing(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -191,14 +243,14 @@ func TestParallelMissingWithPrefault(t *testing.T) { mode: operationModeRead, } - err = h.executeRead(t.Context(), readOp) + err = h.executeOperation(t.Context(), readOp) require.NoError(t, err) var verr errgroup.Group for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -207,7 +259,7 @@ func TestParallelMissingWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -232,13 +284,13 @@ func 
TestSerialMissing(t *testing.T) { } for range serialOperations { - err := h.executeRead(t.Context(), readOp) + err := h.executeOperation(t.Context(), readOp) require.NoError(t, err) } expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go index 52a73e5fa2..5b987fd95c 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go @@ -10,7 +10,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -func TestMissingWrite(t *testing.T) { +func TestWriteProtection(t *testing.T) { t.Parallel() tests := []testConfig{ @@ -19,6 +19,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, { offset: 0, mode: operationModeWrite, @@ -30,6 +34,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, { offset: 15 * header.PageSize, mode: operationModeWrite, @@ -41,6 +49,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.PageSize, numberOfPages: 32, operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, { offset: 31 * header.PageSize, mode: operationModeWrite, @@ -48,18 +60,50 @@ func TestMissingWrite(t *testing.T) { }, }, { - name: "standard 4k page, read after write", + name: "standard 4k page, writes after reads, varying offsets", pagesize: header.PageSize, 
numberOfPages: 32, operations: []operation{ { - offset: 0, - mode: operationModeWrite, + offset: 4 * header.PageSize, + mode: operationModeRead, }, { - offset: 0, + offset: 5 * header.PageSize, mode: operationModeRead, }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, }, }, { @@ -67,6 +111,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, { offset: 0, mode: operationModeWrite, @@ -78,6 +126,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, { offset: 3 * header.HugepageSize, mode: operationModeWrite, @@ -89,6 +141,10 @@ func TestMissingWrite(t *testing.T) { pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, { offset: 7 * header.HugepageSize, mode: operationModeWrite, @@ -96,18 +152,50 @@ func TestMissingWrite(t *testing.T) { }, }, { - name: "hugepage, read after write", + name: "hugepage, writes after reads, varying offsets", pagesize: header.HugepageSize, numberOfPages: 8, operations: []operation{ { - offset: 0, - mode: operationModeWrite, + offset: 4 * header.HugepageSize, + mode: operationModeRead, }, { - offset: 0, + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: 
operationModeRead, + }, + { + offset: 0 * header.HugepageSize, mode: operationModeRead, }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, }, }, } @@ -120,28 +208,28 @@ func TestMissingWrite(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } - - if operation.mode == operationModeWrite { - err := h.executeWrite(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") }) } } -func TestParallelMissingWrite(t *testing.T) { +func TestParallelWriteProtection(t *testing.T) { t.Parallel() parallelOperations := 1_000_000 @@ -154,41 +242,12 @@ func TestParallelMissingWrite(t *testing.T) { h, err := configureCrossProcessTest(t, tt) require.NoError(t, err) - writeOp := operation{ + readOp := operation{ offset: 0, - 
mode: operationModeWrite, + mode: operationModeRead, } - var verr errgroup.Group - - for range parallelOperations { - verr.Go(func() error { - return h.executeWrite(t.Context(), writeOp) - }) - } - - err = verr.Wait() - require.NoError(t, err) - - expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - - accessedOffsets, err := h.offsetsOnce() - require.NoError(t, err) - - assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") -} - -func TestParallelMissingWriteWithPrefault(t *testing.T) { - t.Parallel() - - parallelOperations := 10_000 - - tt := testConfig{ - pagesize: header.PageSize, - numberOfPages: 2, - } - - h, err := configureCrossProcessTest(t, tt) + err = h.executeOperation(t.Context(), readOp) require.NoError(t, err) writeOp := operation{ @@ -196,9 +255,6 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { mode: operationModeWrite, } - err = h.executeWrite(t.Context(), writeOp) - require.NoError(t, err) - var verr errgroup.Group for range parallelOperations { @@ -212,13 +268,20 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } -func TestSerialMissingWrite(t *testing.T) { +func TestSerialWriteProtection(t *testing.T) { t.Parallel() serialOperations := 10_000 @@ -243,8 +306,15 @@ func TestSerialMissingWrite(t *testing.T) { expectedAccessedOffsets := 
getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index e248cfe581..59345c121b 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -1,5 +1,10 @@ package userfaultfd +// flowchart TD +// A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page] +// A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page] +// E -- write (WP+[WRITE] flag) --> F(remove MODE_WP) --> C + import ( "context" "errors" @@ -22,15 +27,23 @@ const maxRequestsInProgress = 4096 var ErrUnexpectedEventType = errors.New("unexpected event type") +type uffdio interface { + copy(addr, pagesize uintptr, data []byte, mode CULong) error + writeProtect(addr, size uintptr, mode CULong) error + close() error + fd() int32 +} + type Userfaultfd struct { - fd Fd + uffd uffdio src block.Slicer - ma *memory.Mapping + m *memory.Mapping // We don't skip the already mapped pages, because if the memory is swappable the page *might* under some conditions be mapped out. 
// For hugepages this should not be a problem, but might theoretically happen to normal pages with swap missingRequests *block.Tracker + writeRequests *block.Tracker // We use the settleRequests to guard the missingRequests so we can access a consistent state of the missingRequests after the requests are finished. settleRequests sync.RWMutex @@ -40,20 +53,30 @@ type Userfaultfd struct { } // NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. -func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { +func NewUserfaultfdFromFd(uffd uffdio, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { blockSize := src.BlockSize() for _, region := range m.Regions { if region.PageSize != uintptr(blockSize) { return nil, fmt.Errorf("block size mismatch: %d != %d for region %d", region.PageSize, blockSize, region.BaseHostVirtAddr) } + + // Register the WP for the regions. + // The memory region is already registered (with missing pages in FC), but registering it again with bigger flag subset should merge these registration flags. 
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := register(uffd, region.BaseHostVirtAddr, uint64(region.Size), UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING) + if err != nil { + return nil, fmt.Errorf("failed to reregister memory region with write protection %d-%d: %w", region.Offset, region.Offset+region.Size, err) + } } u := &Userfaultfd{ - fd: Fd(fd), + uffd: uffd, src: src, missingRequests: block.NewTracker(blockSize), - ma: m, + writeRequests: block.NewTracker(blockSize), + m: m, logger: logger, } @@ -66,15 +89,17 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge } func (u *Userfaultfd) Close() error { - return u.fd.close() + return u.uffd.close() } func (u *Userfaultfd) Serve( ctx context.Context, fdExit *fdexit.FdExit, ) error { + uffd := u.uffd.fd() + pollFds := []unix.PollFd{ - {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: uffd, Events: unix.POLLIN}, {Fd: fdExit.Reader(), Events: unix.POLLIN}, } @@ -113,7 +138,7 @@ outerLoop: return fmt.Errorf("failed to handle uffd: %w", errMsg) } - return nil + return fdexit.ErrFdExit } uffdFd := pollFds[0] @@ -135,7 +160,7 @@ outerLoop: buf := make([]byte, unsafe.Sizeof(UffdMsg{})) for { - _, err := syscall.Read(int(u.fd), buf) + _, err := syscall.Read(int(uffd), buf) if err == syscall.EINTR { u.logger.Debug(ctx, "uffd: interrupted read, reading again") @@ -176,21 +201,27 @@ outerLoop: addr := getPagefaultAddress(&pagefault) - offset, pagesize, err := u.ma.GetOffset(addr) + offset, pagesize, err := u.m.GetOffset(addr) if err != nil { u.logger.Error(ctx, "UFFD serve get mapping error", zap.Error(err)) return fmt.Errorf("failed to map: %w", err) } + // Handle write to write protected page (WP flag) + // The documentation does not clearly mention if the WRITE flag must be present with the WP flag, even though 
we saw it being present in the events. + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + u.handleWriteProtected(ctx, fdExit.SignalExit, addr, pagesize, offset) + + continue + } + // Handle write to missing page (WRITE flag) // If the event has WRITE flag, it was a write to a missing page. // For the write to be executed, we first need to copy the page from the source to the guest memory. if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) - if err != nil { - return fmt.Errorf("failed to handle missing write: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, true) continue } @@ -198,10 +229,7 @@ outerLoop: // Handle read to missing page ("MISSING" flag) // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. if flags == 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset) - if err != nil { - return fmt.Errorf("failed to handle missing: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, false) continue } @@ -217,18 +245,24 @@ func (u *Userfaultfd) handleMissing( addr, pagesize uintptr, offset int64, -) error { + write bool, +) { u.wg.Go(func() error { // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, // even if the errgroup is cancelled or the goroutine returns early. // This check protects us against race condition between marking the request as missing and accessing the missingRequests tracker. - // The Firecracker pause should return only after the requested memory is faulted in, so we don't need to guard the pagefault from the moment it is created. + // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. 
u.settleRequests.RLock() defer u.settleRequests.RUnlock() defer func() { if r := recover(); r != nil { u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle missing failure error", zap.Error(signalErr)) + } } }() @@ -245,7 +279,12 @@ func (u *Userfaultfd) handleMissing( var copyMode CULong - copyErr := u.fd.copy(addr, pagesize, b, copyMode) + // If the event is not WRITE, we need to add WP to the page, so we can catch the next WRITE+WP and mark the page as dirty. + if !write { + copyMode |= UFFDIO_COPY_MODE_WP + } + + copyErr := u.uffd.copy(addr, pagesize, b, copyMode) if errors.Is(copyErr, unix.EEXIST) { // Page is already mapped @@ -259,18 +298,64 @@ func (u *Userfaultfd) handleMissing( u.logger.Error(ctx, "UFFD serve uffdio copy error", zap.Error(joinedErr)) - return fmt.Errorf("failed uffdio copy %w", joinedErr) + return fmt.Errorf("failed to copy page %d-%d %w", offset, offset+int64(pagesize), joinedErr) } // Add the offset to the missing requests tracker. u.missingRequests.Add(offset) + if write { + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + } + return nil }) +} - return nil +// Userfaultfd write-protect mode currently behave differently on none ptes (when e.g. page is missing) over different types of memories (hugepages file backed, etc.). +// - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications - "there will be a userfaultfd write fault message generated when writing to a missing page" +// This should not affect the handling we have in place as all events are being handled. 
+func (u *Userfaultfd) handleWriteProtected(ctx context.Context, onFailure func() error, addr, pagesize uintptr, offset int64) { + u.wg.Go(func() error { + // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, + // even if the errgroup is cancelled or the goroutine returns early. + // This check protects us against race condition between marking the request as dirty and accessing the writeRequests tracker. + // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. + u.settleRequests.RLock() + defer u.settleRequests.RUnlock() + + defer func() { + if r := recover(); r != nil { + u.logger.Error(ctx, "UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle write protected failure error", zap.Error(signalErr)) + } + } + }() + + // Passing 0 as the mode removes the write protection. + wpErr := u.uffd.writeProtect(addr, pagesize, 0) + if wpErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(wpErr, signalErr) + + u.logger.Error(ctx, "UFFD serve write protect error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to remove write protection from page %d-%d %w", offset, offset+int64(pagesize), joinedErr) + } + + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + + return nil + }) } +// Dirty returns the dirty pages. func (u *Userfaultfd) Dirty() *block.Tracker { // This will be at worst cancelled when the uffd is closed. u.settleRequests.Lock() @@ -278,5 +363,5 @@ func (u *Userfaultfd) Dirty() *block.Tracker { // so it is consistent even if there is a another uffd call after. 
defer u.settleRequests.Unlock() - return u.missingRequests.Clone() + return u.writeRequests.Clone() } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go new file mode 100644 index 0000000000..16333261d3 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go @@ -0,0 +1,771 @@ +package userfaultfd + +import ( + "context" + "fmt" + "maps" + "math/rand" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +func TestNoOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(512) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(newMockFd(), h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, accessedOffsets, "checking which pages were faulted") + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + dirty := uffd.Dirty() + assert.Empty(t, slices.Collect(dirty.Offsets()), "checking dirty pages") +} + +func TestRandomOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + 
numberOfOperations := 2048 + repetitions := 8 + + for i := range repetitions { + t.Run(fmt.Sprintf("Run_%d_of_%d", i+1, repetitions), func(t *testing.T) { + t.Parallel() + + // Use time-based seed for each run to ensure different random sequences + // This increases the chance of catching bugs that only manifest with specific sequences + seed := time.Now().UnixNano() + int64(i) + rng := rand.New(rand.NewSource(seed)) + + t.Logf("Using random seed: %d", seed) + + // Randomly operations on the data + operations := make([]operation, 0, numberOfOperations) + for range numberOfOperations { + operations = append(operations, operation{ + offset: int64(rng.Intn(int(numberOfPages-1)) * int(pagesize)), + mode: operationMode(rng.Intn(2) + 1), + }) + } + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + operations: operations, + }) + require.NoError(t, err) + + for _, operation := range operations { + err := h.executeOperation(t.Context(), operation) + require.NoError(t, err, "for operation %+v", operation) + } + + expectedAccessedOffsets := getOperationsOffsets(operations, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted (seed: %d)", seed) + + expectedDirtyOffsets := getOperationsOffsets(operations, operationModeWrite) + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty (seed: %d)", seed) + }) + } +} + +// Badly configured uffd panic recovery caused silent close of uffd—first operation (or parallel first operations) were always handled ok, but subsequent operations would freeze. +// This behavior was flaky with the rest of the tests, because it was racy. 
+func TestUffdNotClosingAfterOperation(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + // We need different offset, because kernel would cache the same page for read. + write2 := operation{ + offset: int64(2 * header.PageSize), + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, dirtyOffsets, "checking which pages were dirty") + }) + + t.Run("missing read", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + // We need different offset, because kernel would cache the same page for read. + read2 := operation{ + offset: 2 * header.PageSize, + mode: operationModeRead, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. 
+ + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, accessedOffsets, "checking which pages were faulted") + }) + + t.Run("write protected", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + read2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeRead, + } + + // We need at least 2 wp events to check the wp handler, so we need to write to 2 different pages. + + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + write2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. 
+ + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, dirtyOffsets, "checking which pages were dirty") + }) +} + +func TestUffdEvents(t *testing.T) { + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + events := []event{ + // Same operation and offset, repeated (copies at 0), with both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + + // WriteProtect at same offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + + // Copy at next offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: 
&UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + + // WriteProtect at next offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + + // Copy at another offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + + // WriteProtect at another offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * 
header.PageSize), + }, + } + + for _, event := range events { + err := event.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", event) + } + + receivedEvents := make([]event, 0, len(events)) + + for range events { + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + copyEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioCopy: ©Event.event}) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + writeProtectEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioWriteProtect: &writeProtectEvent.event}) + case <-t.Context().Done(): + t.FailNow() + } + } + + assert.Len(t, receivedEvents, len(events), "checking received events") + assert.ElementsMatch(t, zeroOffsets(events), receivedEvents, "checking received events") + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + default: + } + + dirty := uffd.Dirty() + + expectedDirtyOffsets := make(map[int64]struct{}) + expectedAccessedOffsets := make(map[int64]struct{}) + + for _, event := range events { + if event.UffdioWriteProtect != nil { + expectedDirtyOffsets[event.offset] = struct{}{} + } + if event.UffdioCopy != nil { + if event.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + expectedDirtyOffsets[event.offset] = struct{}{} + } + + expectedAccessedOffsets[event.offset] = struct{}{} + } + } + + assert.ElementsMatch(t, 
slices.Collect(maps.Keys(expectedDirtyOffsets)), slices.Collect(dirty.Offsets()), "checking dirty pages") + + accessed := accessed(uffd) + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedAccessedOffsets)), slices.Collect(accessed.Offsets()), "checking accessed pages") +} + +func TestUffdSettleRequests(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + testEventsSettle := func(t *testing.T, events []event) { + t.Helper() + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + for _, e := range events { + err = e.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", e) + } + + var blockedCopyEvents []*blockedEvent[UffdioCopy] + var blockedWriteProtectEvents []*blockedEvent[UffdioWriteProtect] + + for range events { + // Wait until the event is blocked + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, copyEvent.event, "copy event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioCopy: ©Event.event}, "checking copy event") + + blockedCopyEvents = append(blockedCopyEvents, copyEvent) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, writeProtectEvent.event, "write protect event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioWriteProtect: &writeProtectEvent.event}, "checking write protect event") + + blockedWriteProtectEvents = append(blockedWriteProtectEvents, writeProtectEvent) + case <-t.Context().Done(): + t.FailNow() + } + } + + require.Len(t, events, 
len(blockedCopyEvents)+len(blockedWriteProtectEvents), "checking blocked events") + + simulatedFCPause := make(chan struct{}) + + d := make(chan *block.Tracker) + + go func() { + acquired := uffd.settleRequests.TryLock() + assert.False(t, acquired, "settleRequests write lock should not be acquired") + + simulatedFCPause <- struct{}{} + + // This should block, until the events are resolved. + dirty := uffd.Dirty() + + select { + case d <- dirty: + case <-t.Context().Done(): + return + } + }() + + // This would be the place where the FC API Pause would return. + <-simulatedFCPause + + // Resolve the events to unblock getting the dirty pages in the goroutine. + for _, e := range blockedCopyEvents { + e.resolve() + } + + for _, e := range blockedWriteProtectEvents { + e.resolve() + } + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + case dirty, ok := <-d: + if !ok { + t.FailNow() + } + + assert.ElementsMatch(t, dirtyOffsets(events), slices.Collect(dirty.Offsets()), "checking dirty pages") + } + } + + t.Run("missing", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("write protect", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("event mix", func(t 
*testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 0, + }, + } + + testEventsSettle(t, events) + }) +} + +type event struct { + *UffdioCopy + *UffdioWriteProtect + + offset int64 +} + +func (e event) trigger(ctx context.Context, uffd *Userfaultfd) error { + switch { + case e.UffdioCopy != nil: + triggerMissing(ctx, uffd, *e.UffdioCopy, e.offset) + case e.UffdioWriteProtect != nil: + triggerWriteProtected(ctx, uffd, *e.UffdioWriteProtect, e.offset) + default: + return fmt.Errorf("invalid event: %+v", e) + } + + return nil +} + +// Return the event copy without the offset, because the offset is propagated only through the "accessed" and "dirty" sets, so direct comparisons would fail. +func (e event) withoutOffset() event { + return event{ + UffdioCopy: e.UffdioCopy, + UffdioWriteProtect: e.UffdioWriteProtect, + } +} + +// Creates a new slice of events with the offset set to 0, so we can compare the events without the offset. 
+func zeroOffsets(events []event) []event { + return utils.Map(events, func(e event) event { + return e.withoutOffset() + }) +} + +func triggerMissing(ctx context.Context, uffd *Userfaultfd, c UffdioCopy, offset int64) { + var write bool + + if c.mode != UFFDIO_COPY_MODE_WP { + write = true + } + + uffd.handleMissing( + ctx, + func() error { return nil }, + uintptr(c.dst), + uintptr(uffd.src.BlockSize()), + offset, + write, + ) +} + +func triggerWriteProtected(ctx context.Context, uffd *Userfaultfd, c UffdioWriteProtect, offset int64) { + uffd.handleWriteProtected( + ctx, + func() error { return nil }, + uintptr(c._range.start), + uintptr(uffd.src.BlockSize()), + offset, + ) +} + +func dirtyOffsets(events []event) []int64 { + offsets := make(map[int64]struct{}) + + for _, e := range events { + if e.UffdioWriteProtect != nil { + offsets[e.offset] = struct{}{} + } + + if e.UffdioCopy != nil { + if e.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + offsets[e.offset] = struct{}{} + } + } + } + + return slices.Collect(maps.Keys(offsets)) +} From b971e46be5e8bd3d77ce1f2a36200671ff6b392d Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Mon, 22 Dec 2025 13:19:09 -0800 Subject: [PATCH 39/40] Cleanup --- packages/orchestrator/benchmark_test.go | 5 ++--- .../userfaultfd/cross_process_helpers_test.go | 2 +- .../internal/sandbox/uffd/userfaultfd/fd.go | 14 ++++++++++++++ .../uffd/userfaultfd/fd_helpers_test.go | 18 ------------------ .../sandbox/uffd/userfaultfd/userfaultfd.go | 3 ++- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/packages/orchestrator/benchmark_test.go b/packages/orchestrator/benchmark_test.go index baeecbb781..1f12ce528f 100644 --- a/packages/orchestrator/benchmark_test.go +++ b/packages/orchestrator/benchmark_test.go @@ -300,9 +300,8 @@ func BenchmarkBaseImageLaunch(b *testing.B) { type testCycle string const ( - onlyStart testCycle = "only-start" - startAndPause testCycle = "start-and-pause" - startPauseResume testCycle = "start-pause-resume" + 
onlyStart testCycle = "only-start" + startAndPause testCycle = "start-and-pause" ) type testContainer struct { diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go index ecb27bd2c2..81ddfbf6ae 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go @@ -52,7 +52,7 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) - err = register(Fd(uffdFd), memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) + err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) // We don't use t.Context() here, because we want to be able to kill the process manually and listen to the correct exit code, diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index c25c05af26..e5145ad060 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -134,6 +134,20 @@ func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { return nil } +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (f Fd) register(addr uintptr, size uint64, mode CULong) error { + register := newUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) 
+ } + + return nil +} + // mode: UFFDIO_WRITEPROTECT_MODE_WP // Passing 0 as the mode will remove the write protection. func (f Fd) writeProtect(addr, size uintptr, mode CULong) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go index 33036f346e..cfe7619399 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -28,10 +28,6 @@ func (m *mockFd) register(_ uintptr, _ uint64, _ CULong) error { return nil } -func (m *mockFd) unregister(_, _ uintptr) error { - return nil -} - func (m *mockFd) copy(addr, pagesize uintptr, _ []byte, mode CULong) error { // Don't use the uffdioCopy constructor as it unsafely checks slice address and fails for arbitrary pointer. e := newBlockedEvent(UffdioCopy{ @@ -104,20 +100,6 @@ func configureApi(f Fd, pagesize uint64) error { return nil } -// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING -// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING -// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func register(f Fd, addr uintptr, size uint64, mode CULong) error { - register := newUffdioRegister(CULong(addr), CULong(size), mode) - - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) - if errno != 0 { - return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} - // This wrapped event allows us to simulate the finish of processing of events by FC on FC API Pause. 
type blockedEvent[T UffdioCopy | UffdioWriteProtect] struct { event T diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index 59345c121b..d3b129aa6f 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -29,6 +29,7 @@ var ErrUnexpectedEventType = errors.New("unexpected event type") type uffdio interface { copy(addr, pagesize uintptr, data []byte, mode CULong) error + register(addr uintptr, size uint64, mode CULong) error writeProtect(addr, size uintptr, mode CULong) error close() error fd() int32 @@ -65,7 +66,7 @@ func NewUserfaultfdFromFd(uffd uffdio, src block.Slicer, m *memory.Mapping, logg // The memory region is already registered (with missing pages in FC), but registering it again with bigger flag subset should merge these registration flags. // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs - err := register(uffd, region.BaseHostVirtAddr, uint64(region.Size), UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING) + err := uffd.register(region.BaseHostVirtAddr, uint64(region.Size), UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING) if err != nil { return nil, fmt.Errorf("failed to reregister memory region with write protection %d-%d: %w", region.Offset, region.Offset+region.Size, err) } From d1804e7737184384276b9459141574cac883d46e Mon Sep 17 00:00:00 2001 From: ValentaTomas Date: Tue, 30 Dec 2025 08:26:54 -0800 Subject: [PATCH 40/40] Cleanup --- .../internal/sandbox/block/cache_test.go | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/packages/orchestrator/internal/sandbox/block/cache_test.go 
b/packages/orchestrator/internal/sandbox/block/cache_test.go index ed0b6c5ed2..f972bba906 100644 --- a/packages/orchestrator/internal/sandbox/block/cache_test.go +++ b/packages/orchestrator/internal/sandbox/block/cache_test.go @@ -1,9 +1,7 @@ package block import ( - "bytes" "crypto/rand" - "fmt" "os" "testing" @@ -56,7 +54,7 @@ func TestCopyFromProcess_FullRange(t *testing.T) { require.NoError(t, err) require.Equal(t, int(size), n) - require.NoError(t, compareData(data[:n], mem[:n])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[:n], data[:n])) } func TestCopyFromProcess_LargeRanges(t *testing.T) { @@ -90,19 +88,19 @@ func TestCopyFromProcess_LargeRanges(t *testing.T) { n, err := cache.ReadAt(data1, 0) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data1[:n], mem[0:pageSize])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[0:pageSize], data1[:n])) data2 := make([]byte, pageSize) n, err = cache.ReadAt(data2, int64(pageSize)) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data2[:n], mem[pageSize*3:pageSize*4])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize*3:pageSize*4], data2[:n])) data3 := make([]byte, pageSize) n, err = cache.ReadAt(data3, int64(pageSize*2)) require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data3[:n], mem[pageSize:pageSize*2])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize:pageSize*2], data3[:n])) } func TestCopyFromProcess_MultipleRanges(t *testing.T) { @@ -148,7 +146,7 @@ func TestCopyFromProcess_MultipleRanges(t *testing.T) { require.NoError(t, err) require.Equal(t, int(pageSize), n) - require.NoError(t, compareData(data[:n], mem[alignedOffset:alignedOffset+int64(pageSize)])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[alignedOffset:alignedOffset+int64(pageSize)], data[:n])) } } @@ -188,13 +186,13 @@ 
func TestCopyFromProcess_HugepageToRegularPage(t *testing.T) { n, err = cache.ReadAt(data, 0) require.NoError(t, err) require.Equal(t, int(pageSize*2), n) - require.NoError(t, compareData(data[:n], mem[0:pageSize*2])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[0:pageSize*2], data[:n])) data = make([]byte, pageSize*4) n, err = cache.ReadAt(data, pageSize*2) require.NoError(t, err) require.Equal(t, int(pageSize*4), n) - require.NoError(t, compareData(data[:n], mem[pageSize*4:pageSize*8])) + require.NoError(t, testutils.ErrorFromByteSlicesDifference(mem[pageSize*4:pageSize*8], data[:n])) } func TestEmptyRanges(t *testing.T) { @@ -213,14 +211,3 @@ func TestEmptyRanges(t *testing.T) { c.Close() }) } - -func compareData(readBytes []byte, expectedBytes []byte) error { - // The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here. - if !bytes.Equal(readBytes, expectedBytes) { - idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) - - return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) - } - - return nil -}