diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go
deleted file mode 100644
index b0caf19411..0000000000
--- a/packages/orchestrator/internal/sandbox/block/tracker.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package block
-
-import (
-	"context"
-	"fmt"
-	"sync"
-	"sync/atomic"
-
-	"github.com/bits-and-blooms/bitset"
-
-	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
-)
-
-type TrackedSliceDevice struct {
-	data      ReadonlyDevice
-	blockSize int64
-
-	nilTracking atomic.Bool
-	dirty       *bitset.BitSet
-	dirtyMu     sync.Mutex
-	empty       []byte
-}
-
-func NewTrackedSliceDevice(blockSize int64, device ReadonlyDevice) (*TrackedSliceDevice, error) {
-	return &TrackedSliceDevice{
-		data:      device,
-		empty:     make([]byte, blockSize),
-		blockSize: blockSize,
-	}, nil
-}
-
-func (t *TrackedSliceDevice) Disable() error {
-	size, err := t.data.Size()
-	if err != nil {
-		return fmt.Errorf("failed to get device size: %w", err)
-	}
-
-	t.dirty = bitset.New(uint(header.TotalBlocks(size, t.blockSize)))
-	// We are starting with all being dirty.
-	t.dirty.FlipRange(0, t.dirty.Len())
-
-	t.nilTracking.Store(true)
-
-	return nil
-}
-
-func (t *TrackedSliceDevice) Slice(ctx context.Context, off int64, length int64) ([]byte, error) {
-	if t.nilTracking.Load() {
-		t.dirtyMu.Lock()
-		t.dirty.Clear(uint(header.BlockIdx(off, t.blockSize)))
-		t.dirtyMu.Unlock()
-
-		return t.empty, nil
-	}
-
-	return t.data.Slice(ctx, off, length)
-}
-
-// Return which bytes were not read since Disable.
-// This effectively returns the bytes that have been requested after paused vm and are not dirty.
-func (t *TrackedSliceDevice) Dirty() *bitset.BitSet {
-	t.dirtyMu.Lock()
-	defer t.dirtyMu.Unlock()
-
-	return t.dirty.Clone()
-}
diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go
index 0961268089..5badd05c48 100644
--- a/packages/orchestrator/internal/sandbox/sandbox.go
+++ b/packages/orchestrator/internal/sandbox/sandbox.go
@@ -671,7 +671,8 @@ func (s *Sandbox) Pause(
 		return nil, fmt.Errorf("failed to pause VM: %w", err)
 	}
 
-	if err := s.memory.Disable(); err != nil {
+	dirtyPages, err := s.memory.Disable(ctx)
+	if err != nil {
 		return nil, fmt.Errorf("failed to disable uffd: %w", err)
 	}
 
@@ -718,7 +719,7 @@ func (s *Sandbox) Pause(
 		originalMemfile.Header(),
 		&MemoryDiffCreator{
 			memfile:    memfile,
-			dirtyPages: s.memory.Dirty(),
+			dirtyPages: dirtyPages,
 			blockSize:  originalMemfile.BlockSize(),
 			doneHook: func(context.Context) error {
 				return memfile.Close()
diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/tracker.go b/packages/orchestrator/internal/sandbox/uffd/memory/tracker.go
new file mode 100644
index 0000000000..6362b8a3e7
--- /dev/null
+++ b/packages/orchestrator/internal/sandbox/uffd/memory/tracker.go
@@ -0,0 +1,46 @@
+package memory
+
+import (
+	"sync"
+
+	"github.com/bits-and-blooms/bitset"
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+)
+
+type Tracker struct {
+	bitset    *bitset.BitSet
+	blockSize int64
+	mu        sync.RWMutex
+}
+
+func NewTracker(size, blockSize int64) *Tracker {
+	return &Tracker{
+		bitset:    bitset.New(uint(header.TotalBlocks(size, blockSize))),
+		blockSize: blockSize,
+		mu:        sync.RWMutex{},
+	}
+}
+
+// Mark records the block containing offset as dirty.
+func (t *Tracker) Mark(offset int64) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	t.bitset.Set(uint(header.BlockIdx(offset, t.blockSize)))
+}
+
+// BitSet returns a snapshot (clone) of the dirty blocks.
+func (t *Tracker) BitSet() *bitset.BitSet {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	return t.bitset.Clone()
+}
+
+// Check reports whether the block containing offset is marked dirty.
+func (t *Tracker) Check(offset int64) bool {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	return t.bitset.Test(uint(header.BlockIdx(offset, t.blockSize)))
+}
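Reviewer note: a minimal sketch of how the new memory.Tracker is meant to be driven by the fault handler and the pause path. Names come from this diff; the sizes and offsets are made-up values for illustration, not part of the change.

    // Hypothetical illustration, not part of this diff.
    tr := memory.NewTracker(8*header.HugepageSize, header.HugepageSize) // 8 hugepage-sized blocks
    tr.Mark(3 * header.HugepageSize) // a write fault on block 3 marks it dirty
    snapshot := tr.BitSet()          // clone: safe to iterate while faults keep arriving
    _ = snapshot.Test(3)             // true
    _ = tr.Check(0)                  // false: block 0 was never written
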
diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
index 4c65f5d977..22524ab55a 100644
--- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
+++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
@@ -9,8 +9,12 @@ import (
 )
 
 type MemoryBackend interface {
-	Disable() error
-	Dirty() *bitset.BitSet
+	// Dirty returns the dirty bitset.
+	// It waits for all requests in flight to finish.
+	Dirty(ctx context.Context) (*bitset.BitSet, error)
+	// Disable switches the uffd to start serving empty pages and returns the dirty bitset.
+	// It waits for all requests in flight to finish.
+	Disable(ctx context.Context) (*bitset.BitSet, error)
 
 	Start(ctx context.Context, sandboxId string) error
 	Stop() error
diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go
index 4d08459510..3555490ccf 100644
--- a/packages/orchestrator/internal/sandbox/uffd/noop.go
+++ b/packages/orchestrator/internal/sandbox/uffd/noop.go
@@ -34,12 +34,12 @@ func NewNoopMemory(size, blockSize int64) *NoopMemory {
 	}
 }
 
-func (m *NoopMemory) Disable() error {
-	return nil
+func (m *NoopMemory) Disable(context.Context) (*bitset.BitSet, error) {
+	return m.dirty, nil
 }
 
-func (m *NoopMemory) Dirty() *bitset.BitSet {
-	return m.dirty
+func (m *NoopMemory) Dirty(context.Context) (*bitset.BitSet, error) {
+	return m.dirty, nil
 }
 
 func (m *NoopMemory) Start(context.Context, string) error {
diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go
new file mode 100644
index 0000000000..8dd7866392
--- /dev/null
+++ b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go
@@ -0,0 +1,14 @@
+package testutils
+
+import "go.uber.org/zap"
+
+func NewLogger() *zap.Logger {
+	cfg := zap.NewDevelopmentConfig()
+
+	logger, err := cfg.Build()
+	if err != nil {
+		panic(err)
+	}
+
+	return logger
+}
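Reviewer note: under the new MemoryBackend contract, the pause path consumes the backend roughly like this. A sketch only, with error handling trimmed; the names match this diff.

    // Hypothetical illustration of the Pause flow, not part of this diff.
    dirtyPages, err := backend.Disable(ctx) // waits for in-flight faults, then serves empty pages
    if err != nil {
        return fmt.Errorf("failed to disable uffd: %w", err)
    }
    for i, ok := dirtyPages.NextSet(0); ok; i, ok = dirtyPages.NextSet(i + 1) {
        // only these blocks need to be written into the memory diff
    }
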
diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go
index 93a047456d..40c965294b 100644
--- a/packages/orchestrator/internal/sandbox/uffd/uffd.go
+++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"net"
 	"os"
+	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -38,16 +39,25 @@ type Uffd struct {
 
 	lis *net.UnixListener
 
-	memfile *block.TrackedSliceDevice
+	memfile block.ReadonlyDevice
+	dirty   *memory.Tracker
 
 	socketPath string
+
+	missingMap *userfaultfd.OffsetMap
+	writeMap   *userfaultfd.OffsetMap
+	wpMap      *userfaultfd.OffsetMap
+
+	disabled atomic.Bool
+
+	writeRequestCounter utils.WaitCounter
 }
 
 var _ MemoryBackend = (*Uffd)(nil)
 
 func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uffd, error) {
-	trackedMemfile, err := block.NewTrackedSliceDevice(blockSize, memfile)
+	size, err := memfile.Size()
 	if err != nil {
-		return nil, fmt.Errorf("failed to create tracked slice device: %w", err)
+		return nil, fmt.Errorf("failed to get memfile size: %w", err)
 	}
 
 	fdExit, err := fdexit.New()
@@ -59,8 +69,12 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff
 		exit:       utils.NewErrorOnce(),
 		readyCh:    make(chan struct{}, 1),
 		fdExit:     fdExit,
-		memfile:    trackedMemfile,
 		socketPath: socketPath,
+		memfile:    memfile,
+		dirty:      memory.NewTracker(size, memfile.BlockSize()),
+		missingMap: userfaultfd.NewResetMap(),
+		writeMap:   userfaultfd.NewResetMap(),
+		wpMap:      userfaultfd.NewResetMap(),
 	}, nil
 }
@@ -151,13 +165,33 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error {
 		}
 	}()
 
+	for _, region := range m {
+		// Register with WP. The memory region may already be registered (with missing pages, by FC), but registering it again with a bigger flag subset should merge these:
+		// - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477
+		// - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs
+		err := uffd.Register(
+			region.BaseHostVirtAddr+region.Offset,
+			uint64(region.Size),
+			userfaultfd.UFFDIO_REGISTER_MODE_WP|userfaultfd.UFFDIO_REGISTER_MODE_MISSING,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to reregister memory region %d-%d with write protection: %w", region.Offset, region.Offset+region.Size, err)
+		}
+	}
+
 	u.readyCh <- struct{}{}
 
 	err = uffd.Serve(
 		ctx,
 		m,
 		u.memfile,
+		u.dirty,
+		&u.writeRequestCounter,
+		u.missingMap,
+		u.writeMap,
+		u.wpMap,
 		u.fdExit,
+		&u.disabled,
 		zap.L().With(logger.WithSandboxID(sandboxId)),
 	)
 	if err != nil {
@@ -175,14 +209,32 @@ func (u *Uffd) Ready() chan struct{} {
 	return u.readyCh
 }
 
+func (u *Uffd) Disable(ctx context.Context) (*bitset.BitSet, error) {
+	dirty, err := u.Dirty(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get dirty bitset: %w", err)
+	}
+
+	u.disabled.Store(true)
+
+	return dirty, nil
+}
+
 func (u *Uffd) Exit() *utils.ErrorOnce {
 	return u.exit
 }
 
-func (u *Uffd) Disable() error {
-	return u.memfile.Disable()
-}
+// Dirty waits for all requests in flight to finish and then returns the dirty bitset.
+// Call *after* pausing the firecracker process.
+func (u *Uffd) Dirty(ctx context.Context) (*bitset.BitSet, error) {
+	err := u.writeRequestCounter.Wait(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to wait for write requests: %w", err)
+	}
+
+	u.missingMap.Reset()
+	u.writeMap.Reset()
+	u.wpMap.Reset()
 
-func (u *Uffd) Dirty() *bitset.BitSet {
-	return u.memfile.Dirty()
+	return u.dirty.BitSet(), nil
 }
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go
index 657fb14ec4..fd814515f5 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go
@@ -32,6 +32,12 @@ const (
 	UFFDIO_REGISTER = C.UFFDIO_REGISTER
 	UFFDIO_COPY     = C.UFFDIO_COPY
 
+	UFFDIO_API          = C.UFFDIO_API
+	UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT
+
+	UFFD_PAGEFAULT_FLAG_WP    = C.UFFD_PAGEFAULT_FLAG_WP
+	UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
+
 	UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS
 )
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid
new file mode 100644
index 0000000000..7b4ead1174
--- /dev/null
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid
@@ -0,0 +1,4 @@
+flowchart TD
+A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page]
+A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page]
+E -- write (WP flag) --> F(remove MODE_WP) --> C
\ No newline at end of file
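Reviewer note: the diagram above compresses the whole fault lifecycle; here is the same dispatch as a hedged Go sketch (flag names from constants.go, handler bodies elided — this restates the diagram, it is not new behavior):

    // Hypothetical illustration of the dispatch in diagram.mermaid; not part of this diff.
    switch {
    case flags&UFFD_PAGEFAULT_FLAG_WP != 0:
        // write to an already-faulted page: lift write protection, mark the page dirty
    case flags&UFFD_PAGEFAULT_FLAG_WRITE != 0:
        // write to a missing page: plain UFFDIO_COPY (no WP mode), mark the page dirty
    default:
        // read of a missing page: UFFDIO_COPY with UFFDIO_COPY_MODE_WP, so that a
        // later write raises a WP fault and the write can be tracked
    }
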
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/offset_map.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/offset_map.go
new file mode 100644
index 0000000000..0f14264bd8
--- /dev/null
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/offset_map.go
@@ -0,0 +1,37 @@
+package userfaultfd
+
+import "sync/atomic"
+
+// OffsetMap wraps a map that is not thread-safe for reads and writes, but makes it
+// safe to call Reset concurrently. TryAdd itself is still meant for a single goroutine.
+type OffsetMap struct {
+	r atomic.Pointer[map[int64]struct{}]
+}
+
+func NewResetMap() *OffsetMap {
+	m := &OffsetMap{
+		r: atomic.Pointer[map[int64]struct{}]{},
+	}
+
+	m.r.Store(&map[int64]struct{}{})
+
+	return m
+}
+
+// Reset atomically swaps in a fresh, empty map.
+func (r *OffsetMap) Reset() {
+	r.r.Store(&map[int64]struct{}{})
+}
+
+// TryAdd adds the key and reports whether it was not already present.
+func (r *OffsetMap) TryAdd(key int64) bool {
+	m := *r.r.Load()
+
+	if _, ok := m[key]; ok {
+		return false
+	}
+
+	m[key] = struct{}{}
+
+	return true
+}
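Reviewer note: the single-writer contract matters here — the poll loop is the only goroutine calling TryAdd, while Dirty may call Reset from another goroutine. A minimal sketch of the intended semantics (the offset value is hypothetical):

    // Hypothetical illustration, not part of this diff.
    m := NewResetMap()
    _ = m.TryAdd(4096) // true: first fault on this offset gets handled
    _ = m.TryAdd(4096) // false: a duplicate fault for the same offset is skipped
    m.Reset()          // after a pause, faults on the offset are handled again
    _ = m.TryAdd(4096) // true
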
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go
index ef161b6815..c9307e7000 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync/atomic"
 	"syscall"
 	"unsafe"
 
@@ -14,13 +15,21 @@ import (
 	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
 	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit"
 	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory"
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
 )
 
 func (u *userfaultfd) Serve(
 	ctx context.Context,
 	m memory.MemoryMap,
 	src block.Slicer,
+	dirty *memory.Tracker,
+	writeRequestCounter *utils.WaitCounter,
+	missingMap *OffsetMap,
+	writeMap *OffsetMap,
+	wpMap *OffsetMap,
 	fdExit *fdexit.FdExit,
+	disabled *atomic.Bool,
 	logger *zap.Logger,
 ) error {
 	pollFds := []unix.PollFd{
@@ -30,8 +39,6 @@ func (u *userfaultfd) Serve(
 
 	var eg errgroup.Group
 
-	handledPages := map[int64]struct{}{}
-
 outerLoop:
 	for {
 		if _, err := unix.Poll(
@@ -119,6 +126,7 @@ outerLoop:
 			arg := GetMsgArg(&msg)
 
 			pagefault := (*(*UffdPagefault)(unsafe.Pointer(&arg[0])))
+			flags := pagefault.flags
 
 			addr := GetPagefaultAddress(&pagefault)
 
@@ -129,33 +137,90 @@ outerLoop:
 				return fmt.Errorf("failed to map: %w", err)
 			}
 
-			// This prevents serving missing pages multiple times.
+			// The maps prevent serving pages multiple times. (Since we now add WP only once, we don't have to remove entries from any map.)
 			// For normal sized pages with swap on, the behavior seems not to be properly described in docs
 			// and it's not clear if the missing can be legitimately triggered multiple times.
-			if _, ok := handledPages[offset]; ok {
+			if flags&UFFD_PAGEFAULT_FLAG_WP != 0 {
+				if !wpMap.TryAdd(offset) {
+					continue
+				}
+
+				writeRequestCounter.Add()
+
+				eg.Go(func() error {
+					defer func() {
+						if r := recover(); r != nil {
+							logger.Error("UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r))
+						}
+					}()
+
+					defer writeRequestCounter.Done()
+
+					wpErr := u.RemoveWriteProtection(addr, pagesize)
+					if wpErr != nil {
+						return fmt.Errorf("error removing write protection from page %d: %w", addr, wpErr)
+					}
+
+					// We mark the page as dirty: this was a write to a page that was already mapped.
+					dirty.Mark(offset)
+
+					return nil
+				})
+
 				continue
 			}
 
-			handledPages[offset] = struct{}{}
+			if flags == 0 {
+				if !missingMap.TryAdd(offset) {
+					continue
+				}
+			}
+
+			if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 {
+				if !writeMap.TryAdd(offset) {
+					continue
+				}
+
+				writeRequestCounter.Add()
+			}
 
 			eg.Go(func() error {
 				defer func() {
 					if r := recover(); r != nil {
-						logger.Error("UFFD serve panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r))
+						logger.Error("UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r))
 					}
 				}()
 
-				var copyMode CULong
+				defer func() {
+					if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 {
+						writeRequestCounter.Done()
+					}
+				}()
 
-				b, sliceErr := src.Slice(ctx, offset, int64(pagesize))
-				if sliceErr != nil {
-					signalErr := fdExit.SignalExit()
+				var b []byte
 
-					joinedErr := errors.Join(sliceErr, signalErr)
+				if disabled.Load() {
+					b = header.EmptyHugePage[:pagesize]
+				} else {
+					sliceB, sliceErr := src.Slice(ctx, offset, int64(pagesize))
+					if sliceErr != nil {
+						signalErr := fdExit.SignalExit()
 
-					logger.Error("UFFD serve slice error", zap.Error(joinedErr))
+						joinedErr := errors.Join(sliceErr, signalErr)
 
-					return fmt.Errorf("failed to read from source: %w", joinedErr)
+						logger.Error("UFFD serve slice error", zap.Error(joinedErr))
+
+						return fmt.Errorf("failed to read from source: %w", joinedErr)
+					}
+
+					b = sliceB
+				}
+
+				var copyMode CULong
+
+				if flags == 0 {
+					copyMode = copyMode | UFFDIO_COPY_MODE_WP
 				}
 
 				copyErr := u.copy(addr, b, pagesize, copyMode)
@@ -177,6 +241,11 @@ outerLoop:
 					return fmt.Errorf("failed uffdio copy %w", joinedErr)
 				}
 
+				// We mark the page as dirty: this was a write to a page that was not already mapped.
+				if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 {
+					dirty.Mark(offset)
+				}
+
 				return nil
 			})
 		}
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
index 0f345b12f2..71bbc3acbd 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
@@ -65,6 +65,25 @@ func (u *userfaultfd) Register(addr uintptr, size uint64, mode CULong) error {
 	return nil
 }
 
+func (u *userfaultfd) writeProtect(addr uintptr, size uint64, mode CULong) error {
+	wp := NewUffdioWriteProtect(CULong(addr), CULong(size), mode)
+
+	ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(&wp)))
+	if errno != 0 {
+		return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret)
+	}
+
+	return nil
+}
+
+func (u *userfaultfd) RemoveWriteProtection(addr uintptr, size uint64) error {
+	return u.writeProtect(addr, size, 0)
+}
+
+func (u *userfaultfd) AddWriteProtection(addr uintptr, size uint64) error {
+	return u.writeProtect(addr, size, UFFDIO_WRITEPROTECT_MODE_WP)
+}
+
 // mode: UFFDIO_COPY_MODE_WP
 // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page
 func (u *userfaultfd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error {
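Reviewer note: a hedged sketch of how the new write-protect helpers pair with the copy mode in serve.go. The addr/offset/pageBytes variables are assumed to come from the fault-handling scope; this mirrors the kernel contract described above, it is not a new API.

    // Hypothetical illustration, not part of this diff.
    // Serving a *read* fault: copy the page in but keep it write-protected,
    // so the first write to it raises a WP fault we can observe.
    if err := u.copy(addr, pageBytes, pagesize, UFFDIO_COPY_MODE_WP); err != nil {
        return err
    }

    // Later, on the WP fault for that page: lift the protection and record the write.
    if err := u.RemoveWriteProtection(addr, pagesize); err != nil {
        return err
    }
    dirty.Mark(offset)
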
"bytes" + "context" + "fmt" + "sync/atomic" "syscall" "testing" + "github.com/bits-and-blooms/bitset" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/zap" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) -type pageTest struct { - name string - pagesize uint64 - numberOfPages uint64 - operationOffset uint64 -} - func TestUffdMissing(t *testing.T) { - tests := []pageTest{ + tests := []testConfig{ { - name: "standard 4k page, operation at start", - pagesize: header.PageSize, - numberOfPages: 32, - operationOffset: 0, + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, }, { - name: "standard 4k page, operation at middle", - pagesize: header.PageSize, - numberOfPages: 32, - operationOffset: 16 * header.PageSize, + name: "standard 4k page, operation at middle", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + }, }, { - name: "standard 4k page, operation at last page", - pagesize: header.PageSize, - numberOfPages: 32, - operationOffset: 31 * header.PageSize, + name: "standard 4k page, operation at last page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, + }, }, { - name: "hugepage, operation at start", - pagesize: header.HugepageSize, - numberOfPages: 8, - operationOffset: 0, + name: "hugepage, operation at start", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, }, { - name: "hugepage, operation at middle", - pagesize: header.HugepageSize, - numberOfPages: 8, - operationOffset: 4 * header.HugepageSize, + name: "hugepage, operation at middle", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, }, { - name: "hugepage, operation at last page", - pagesize: header.HugepageSize, - numberOfPages: 8, - operationOffset: 7 * header.HugepageSize, + name: "hugepage, operation at last page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - data, size := testutils.RandomPages(tt.pagesize, tt.numberOfPages) - - uffd, err := newUserfaultfd(syscall.O_CLOEXEC | syscall.O_NONBLOCK) - require.NoError(t, err) - - t.Cleanup(func() { - uffd.Close() - }) - - err = uffd.configureApi(tt.pagesize) - require.NoError(t, err) - - memoryArea, memoryStart, unmap, err := testutils.NewPageMmap(size, tt.pagesize) - require.NoError(t, err) + h, memoryMap, _, exitUffd, fdExit := configureTest(t, tt) - t.Cleanup(func() { - unmap() - }) - - err = uffd.Register(memoryStart, size, UFFDIO_REGISTER_MODE_MISSING) - require.NoError(t, err) - - m := testutils.NewContiguousMap(memoryStart, size, tt.pagesize) + for _, operation := range 
+
+func TestUffdWriteProtection(t *testing.T) {
+	tests := []testConfig{
+		{
+			name:          "standard 4k page, single write",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "standard 4k page, single read then write on first page",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "standard 4k page, single read then write on non-first page",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 15 * header.PageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 15 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "standard 4k page, two writes on different pages",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 15 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 16 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "hugepage, single write",
+			pagesize:      header.HugepageSize,
+			numberOfPages: 8,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "hugepage, single read then write on first page",
+			pagesize:      header.HugepageSize,
+			numberOfPages: 8,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "hugepage, single read then write on non-first page",
+			pagesize:      header.HugepageSize,
+			numberOfPages: 8,
+			operations: []operation{
+				{
+					offset: 3 * header.HugepageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 3 * header.HugepageSize,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+		{
+			name:          "hugepage, two writes on different pages",
+			pagesize:      header.HugepageSize,
+			numberOfPages: 8,
+			operations: []operation{
+				{
+					offset: 3 * header.HugepageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 4 * header.HugepageSize,
+					mode:   operationModeWrite,
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			h, memoryMap, dirty, exitUffd, fdExit := configureTest(t, tt)
+
+			for _, operation := range tt.operations {
+				if operation.mode == operationModeRead {
+					err := h.executeRead(t.Context(), operation)
+					require.NoError(t, err)
+				}
+
+				if operation.mode == operationModeWrite {
+					err := h.executeWrite(t.Context(), operation)
+					require.NoError(t, err)
+				}
+			}
+
+			writeOperations := getOperationsOffsets(tt.operations, operationModeWrite)
+			assert.Equal(t, writeOperations, getOffsetsFromBitset(dirty.BitSet(), tt.pagesize), "checking written to pages")
+
+			memoryAccesses := getOperationsOffsets(tt.operations, 0)
+			assert.Equal(t, memoryAccesses, memoryMap.Map(), "checking which pages were accessed (read and write)")
 
 			signalExitErr := fdExit.SignalExit()
 			require.NoError(t, signalExitErr)
@@ -125,3 +265,148 @@ func TestUffdMissing(t *testing.T) {
 		})
 	}
 }
+
+var logger = testutils.NewLogger()
+
+type operationMode uint
+
+const (
+	operationModeRead operationMode = iota + 1
+	operationModeWrite
+)
+
+type operation struct {
+	offset uint
+	mode   operationMode
+}
+
+type testConfig struct {
+	name          string
+	pagesize      uint64
+	numberOfPages uint64
+	operations    []operation
+}
+
+type testHandler struct {
+	memoryArea *[]byte
+	pagesize   uint64
+	data       block.Slicer
+	dirty      *memory.Tracker
+}
+
+// getOperationsOffsets collects the offsets of all operations with the given mode.
+func getOperationsOffsets(operations []operation, mode operationMode) map[uint64]struct{} {
+	count := map[uint64]struct{}{}
+
+	for _, operation := range operations {
+		// If mode is 0, we want to get all operations.
+		if operation.mode == mode || mode == 0 {
+			count[uint64(operation.offset)] = struct{}{}
+		}
+	}
+
+	return count
+}
+
+func getOffsetsFromBitset(bitset *bitset.BitSet, pagesize uint64) map[uint64]struct{} {
+	count := map[uint64]struct{}{}
+
+	for i, e := bitset.NextSet(0); e; i, e = bitset.NextSet(i + 1) {
+		count[uint64(i)*pagesize] = struct{}{}
+	}
+
+	return count
+}
+
+func (h *testHandler) executeRead(ctx context.Context, op operation) error {
+	readBytes := (*h.memoryArea)[op.offset : op.offset+uint(h.pagesize)]
+
+	expectedBytes, err := h.data.Slice(ctx, int64(op.offset), int64(h.pagesize))
+	if err != nil {
+		return err
+	}
+
+	if !bytes.Equal(readBytes, expectedBytes) {
+		idx, want, got := testutils.DiffByte(readBytes, expectedBytes)
+		return fmt.Errorf("content mismatch: want %q, got %q at index %d", want, got, idx)
+	}
+
+	return nil
+}
+
+func (h *testHandler) executeWrite(ctx context.Context, op operation) error {
+	bytesToWrite, err := h.data.Slice(ctx, int64(op.offset), int64(h.pagesize))
+	if err != nil {
+		return err
+	}
+
+	copy((*h.memoryArea)[op.offset:op.offset+uint(h.pagesize)], bytesToWrite)
+
+	if !h.dirty.Check(int64(op.offset)) {
+		return fmt.Errorf("dirty bit not set for page at offset %d", op.offset)
+	}
+
+	return nil
+}
+
+func configureTest(t *testing.T, tt testConfig) (*testHandler, *testutils.ContiguousMap, *memory.Tracker, chan struct{}, *fdexit.FdExit) {
+	data, size := testutils.RandomPages(tt.pagesize, tt.numberOfPages)
+
+	uffd, err := newUserfaultfd(syscall.O_CLOEXEC | syscall.O_NONBLOCK)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		uffd.Close()
+	})
+
+	err = uffd.configureApi(tt.pagesize)
+	require.NoError(t, err)
+
+	memoryArea, memoryStart, unmap, err := testutils.NewPageMmap(size, tt.pagesize)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		unmap()
+	})
+
+	err = uffd.Register(memoryStart, size, UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP)
+	require.NoError(t, err)
+
+	m := testutils.NewContiguousMap(memoryStart, size, tt.pagesize)
+
+	fdExit, err := fdexit.New()
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		fdExit.SignalExit()
+		fdExit.Close()
+	})
+
+	exitUffd := make(chan struct{}, 1)
+
+	dirty := memory.NewTracker(int64(size), int64(tt.pagesize))
+
+	writeRequestCounter := utils.WaitCounter{}
+	missingMap := NewResetMap()
+	writeMap := NewResetMap()
+	wpMap := NewResetMap()
+	disabled := atomic.Bool{}
+
+	go func() {
+		err := uffd.Serve(t.Context(), m, data, dirty, &writeRequestCounter, missingMap, writeMap, wpMap, fdExit, &disabled, logger)
+		assert.NoError(t, err)
+
+		exitUffd <- struct{}{}
+	}()
+
+	return &testHandler{
+		memoryArea: &memoryArea,
+		pagesize:   tt.pagesize,
+		dirty:      dirty,
+		data:       data,
+	}, m, dirty, exitUffd, fdExit
+}
+
+// TODO: Test write protection with missing
+// TODO: Test async write protection (if we decide for it)
+// TODO: Test write protection double registration (with missing) to simulate the FC situation
diff --git a/packages/shared/pkg/utils/wait_counter.go b/packages/shared/pkg/utils/wait_counter.go
new file mode 100644
index 0000000000..4ae1ea16ab
--- /dev/null
+++ b/packages/shared/pkg/utils/wait_counter.go
@@ -0,0 +1,80 @@
+package utils
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+)
+
+// WaitCounter counts operations in flight and lets callers wait until the count
+// drops to zero or the context is cancelled. The zero value is ready to use.
+type WaitCounter struct {
+	counter atomic.Int64
+
+	mu   sync.Mutex
+	cond *sync.Cond
+	once sync.Once
+}
+
+// init lazily binds the condition variable to mu so the zero value is usable.
+func (w *WaitCounter) init() {
+	w.once.Do(func() {
+		w.cond = sync.NewCond(&w.mu)
+	})
+}
+
+func (w *WaitCounter) add(delta int64) {
+	w.init()
+
+	if w.counter.Add(delta) == 0 {
+		// Broadcast under the lock so a waiter cannot check the counter and then
+		// miss the wakeup between its check and its cond.Wait.
+		w.mu.Lock()
+		w.cond.Broadcast()
+		w.mu.Unlock()
+	}
+}
+
+func (w *WaitCounter) Add() {
+	w.add(1)
+}
+
+func (w *WaitCounter) Done() {
+	w.add(-1)
+}
+
+func (w *WaitCounter) Wait(ctx context.Context) error {
+	w.init()
+
+	// Wake the waiter when the context is done so the loop below can return.
+	stop := context.AfterFunc(ctx, func() {
+		w.mu.Lock()
+		w.cond.Broadcast()
+		w.mu.Unlock()
+	})
+	defer stop()
+
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	for w.counter.Load() != 0 {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+
+		w.cond.Wait()
+	}
+
+	return nil
+}
+
+// Close releases all waiters by zeroing the counter.
+func (w *WaitCounter) Close() {
+	w.init()
+
+	w.counter.Store(0)
+
+	w.mu.Lock()
+	w.cond.Broadcast()
+	w.mu.Unlock()
+}
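Reviewer note: a small usage sketch of WaitCounter as the fault handler drives it. Assumes the context, time, and utils imports; the timings are hypothetical.

    // Hypothetical illustration, not part of this diff.
    var wc utils.WaitCounter

    wc.Add() // a write fault starts being served
    go func() {
        defer wc.Done() // ...and finishes
        time.Sleep(10 * time.Millisecond)
    }()

    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    if err := wc.Wait(ctx); err != nil {
        // the context expired before all in-flight writes finished
        return err
    }
    // Safe point: the dirty bitset now reflects every completed write.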