diff --git a/.github/workflows/pr-tests.yml b/.github/workflows/pr-tests.yml index b72c17302a..3bf04ed6d7 100644 --- a/.github/workflows/pr-tests.yml +++ b/.github/workflows/pr-tests.yml @@ -33,6 +33,16 @@ jobs: with: job_key: unit-tests-${{ matrix.package }} + - name: Enable unprivileged uffd mode + run: | + echo 1 | sudo tee /proc/sys/vm/unprivileged_userfaultfd + + - name: Enable hugepages + run: | + sudo mkdir -p /mnt/hugepages + sudo mount -t hugetlbfs none /mnt/hugepages + echo 128 | sudo tee /proc/sys/vm/nr_hugepages + - name: Run tests working-directory: ${{ matrix.package }} run: go test -v ${{ matrix.test_path }} diff --git a/.gitignore b/.gitignore index fb64d4183b..5dc8c13c76 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ tests/periodic-test/build-template/e2b.toml .air go.work.sum .infisical.json +.vscode/mise-tools/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 9c63bbb7fc..2167d3d351 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -89,5 +89,11 @@ "editor.codeActionsOnSave": { "source.organizeImports": "explicit" } + }, + "go.goroot": "${workspaceFolder}/.vscode/mise-tools/goRoot", + "go.alternateTools": { + "go": "${workspaceFolder}/.vscode/mise-tools/go", + "dlv": "${workspaceFolder}/.vscode/mise-tools/dlv", + "gopls": "${workspaceFolder}/.vscode/mise-tools/gopls" } } \ No newline at end of file diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go index b0caf19411..c69c1488ab 100644 --- a/packages/orchestrator/internal/sandbox/block/tracker.go +++ b/packages/orchestrator/internal/sandbox/block/tracker.go @@ -1,66 +1,70 @@ package block import ( - "context" - "fmt" "sync" - "sync/atomic" "github.com/bits-and-blooms/bitset" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) -type TrackedSliceDevice struct { - data ReadonlyDevice - blockSize int64 +type Tracker struct { + b *bitset.BitSet + mu sync.RWMutex - nilTracking atomic.Bool - dirty *bitset.BitSet - dirtyMu sync.Mutex - empty []byte + blockSize int64 } -func NewTrackedSliceDevice(blockSize int64, device ReadonlyDevice) (*TrackedSliceDevice, error) { - return &TrackedSliceDevice{ - data: device, - empty: make([]byte, blockSize), +func NewTracker(blockSize int64) *Tracker { + return &Tracker{ + // The bitset resizes automatically based on the maximum set bit. + b: bitset.New(0), blockSize: blockSize, - }, nil + } } -func (t *TrackedSliceDevice) Disable() error { - size, err := t.data.Size() - if err != nil { - return fmt.Errorf("failed to get device size: %w", err) - } +func (t *Tracker) Has(off int64) bool { + t.mu.RLock() + defer t.mu.RUnlock() - t.dirty = bitset.New(uint(header.TotalBlocks(size, t.blockSize))) - // We are starting with all being dirty. 
- t.dirty.FlipRange(0, t.dirty.Len()) + return t.b.Test(uint(header.BlockIdx(off, t.blockSize))) +} + +func (t *Tracker) Add(off int64) bool { + t.mu.Lock() + defer t.mu.Unlock() - t.nilTracking.Store(true) + if t.b.Test(uint(header.BlockIdx(off, t.blockSize))) { + return false + } - return nil + t.b.Set(uint(header.BlockIdx(off, t.blockSize))) + + return true } -func (t *TrackedSliceDevice) Slice(ctx context.Context, off int64, length int64) ([]byte, error) { - if t.nilTracking.Load() { - t.dirtyMu.Lock() - t.dirty.Clear(uint(header.BlockIdx(off, t.blockSize))) - t.dirtyMu.Unlock() +func (t *Tracker) Reset() { + t.mu.Lock() + defer t.mu.Unlock() - return t.empty, nil - } + t.b.ClearAll() +} - return t.data.Slice(ctx, off, length) +// BitSet returns a clone of the bitset and the block size. +func (t *Tracker) BitSet() *bitset.BitSet { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.b.Clone() } -// Return which bytes were not read since Disable. -// This effectively returns the bytes that have been requested after paused vm and are not dirty. -func (t *TrackedSliceDevice) Dirty() *bitset.BitSet { - t.dirtyMu.Lock() - defer t.dirtyMu.Unlock() +func (t *Tracker) BlockSize() int64 { + return t.blockSize +} - return t.dirty.Clone() +func (t *Tracker) Clone() *Tracker { + return &Tracker{ + b: t.BitSet(), + blockSize: t.BlockSize(), + } } diff --git a/packages/orchestrator/internal/sandbox/block/tracker_test.go b/packages/orchestrator/internal/sandbox/block/tracker_test.go new file mode 100644 index 0000000000..75d1f58fd2 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/tracker_test.go @@ -0,0 +1,109 @@ +package block + +import ( + "testing" +) + +func TestTracker_AddAndHas(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Initially should not be marked + if tr.Has(offset) { + t.Errorf("Expected offset %d not to be marked initially", offset) + } + + // After adding, should be marked + tr.Add(offset) + if !tr.Has(offset) { + t.Errorf("Expected offset %d to be marked after Add", offset) + } +} + +func TestTracker_Reset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Add offset and verify it's marked + tr.Add(offset) + if !tr.Has(offset) { + t.Errorf("Expected offset %d to be marked after Add", offset) + } + + // After reset, should not be marked + tr.Reset() + if tr.Has(offset) { + t.Errorf("Expected offset %d to be cleared after Reset", offset) + } +} + +func TestTracker_MultipleOffsets(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Verify all offsets are marked + for _, o := range offsets { + if !tr.Has(o) { + t.Errorf("Expected offset %d to be marked", o) + } + } +} + +func TestTracker_ResetClearsAll(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Reset should clear all + tr.Reset() + + // Verify all offsets are cleared + for _, o := range offsets { + if tr.Has(o) { + t.Errorf("Expected offset %d to be cleared after Reset", o) + } + } +} + +func TestTracker_MisalignedOffset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + // Test with misaligned offset + misalignedOffset := int64(123) + 
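A minimal standalone sketch of the block-index arithmetic these Tracker tests rely on, assuming header.BlockIdx is plain integer division of the byte offset by the block size (its definition is not part of this diff):

package main

import "fmt"

const blockSize = 4096

// blockIdx mirrors the assumed behavior of header.BlockIdx: byte offset -> block index.
func blockIdx(off int64) int64 { return off / blockSize }

func main() {
	// Offsets 123 and 1000 land in block 0, offset 4096 in block 1, which is
	// why the misaligned-offset test expects Has(123) and Has(1000) to agree
	// while Has(4096) stays false.
	for _, off := range []int64{0, 123, 1000, 4096} {
		fmt.Printf("offset %5d -> block %d\n", off, blockIdx(off))
	}
}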
tr.Add(misalignedOffset) + + // Should be set for the block containing the offset—that is, block 0 (0..4095) + if !tr.Has(misalignedOffset) { + t.Errorf("Expected misaligned offset %d to be marked (should mark its containing block)", misalignedOffset) + } + + // Now check that any offset in the same block is also considered marked + anotherOffsetInSameBlock := int64(1000) + if !tr.Has(anotherOffsetInSameBlock) { + t.Errorf("Expected offset %d to be marked as in same block as %d", anotherOffsetInSameBlock, misalignedOffset) + } + + // But not for a different block + offsetInNextBlock := int64(pageSize) // convert to int64 to match Has signature + if tr.Has(offsetInNextBlock) { + t.Errorf("Did not expect offset %d to be marked", offsetInNextBlock) + } +} diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 0961268089..b325ee5637 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -671,10 +671,18 @@ func (s *Sandbox) Pause( return nil, fmt.Errorf("failed to pause VM: %w", err) } - if err := s.memory.Disable(); err != nil { + err = s.memory.Disable(ctx) + if err != nil { return nil, fmt.Errorf("failed to disable uffd: %w", err) } + dirty, err := s.memory.Dirty(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get dirty pages: %w", err) + } + + dirtyPages := dirty.BitSet() + // Snapfile is not closed as it's returned and cached for later use (like resume) snapfile := template.NewLocalFileLink(snapshotTemplateFiles.CacheSnapfilePath()) // Memfile is also closed on diff creation processing @@ -718,7 +726,7 @@ func (s *Sandbox) Pause( originalMemfile.Header(), &MemoryDiffCreator{ memfile: memfile, - dirtyPages: s.memory.Dirty(), + dirtyPages: dirtyPages, blockSize: originalMemfile.BlockSize(), doneHook: func(context.Context) error { return memfile.Close() @@ -929,7 +937,7 @@ func serveMemory( ctx, span := tracer.Start(ctx, "serve-memory") defer span.End() - fcUffd, err := uffd.New(memfile, socketPath, memfile.BlockSize()) + fcUffd, err := uffd.New(memfile, socketPath) if err != nil { return nil, fmt.Errorf("failed to create uffd: %w", err) } diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go b/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go deleted file mode 100644 index dbff8bbd7a..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/firecracker.go +++ /dev/null @@ -1,32 +0,0 @@ -package mapping - -import "fmt" - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - // This is actually in bytes. - // This field is deprecated in the newer version of the Firecracer with a new field `page_size`. 
- PageSize uintptr `json:"page_size_kib"` -} - -func (m *GuestRegionUffdMapping) relativeOffset(addr uintptr) int64 { - return int64(m.Offset + addr - m.BaseHostVirtAddr) -} - -type FcMappings []GuestRegionUffdMapping - -// Returns the relative offset and the page size of the mapped range for a given address -func (m FcMappings) GetRange(addr uintptr) (int64, int64, error) { - for _, m := range m { - if addr < m.BaseHostVirtAddr || m.BaseHostVirtAddr+m.Size <= addr { - // Outside of this mapping - continue - } - - return m.relativeOffset(addr), int64(m.PageSize), nil - } - - return 0, 0, fmt.Errorf("address %d not found in any mapping", addr) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go b/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go deleted file mode 100644 index ffa8c4a6fa..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/mapping/mapping.go +++ /dev/null @@ -1,5 +0,0 @@ -package mapping - -type Mappings interface { - GetRange(addr uintptr) (offset int64, pagesize int64, err error) -} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go new file mode 100644 index 0000000000..3902b7c6b8 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -0,0 +1,45 @@ +package memory + +import ( + "fmt" +) + +type Mapping struct { + Regions []Region +} + +func NewMapping(regions []Region) *Mapping { + return &Mapping{Regions: regions} +} + +// GetOffset returns the relative offset and the page size of the mapped range for a given address. +func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { + for _, r := range m.Regions { + if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { + return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + } + } + + return 0, 0, fmt.Errorf("address %d not found in any mapping", hostVirtAddr) +} + +// GetHostVirtAddr returns the host virtual address for a given offset. +func (m *Mapping) GetHostVirtAddr(off int64) (int64, uint64, error) { + r, err := m.getHostVirtRegion(off) + if err != nil { + return 0, 0, err + } + + return int64(r.shiftedHostVirtAddr(off)), uint64(r.PageSize), nil +} + +// getHostVirtRegion returns the region that contains the given offset. 
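A worked example of the translation GetOffset and GetHostVirtAddr perform, using the same hypothetical region values as mapping_test.go below (base 0x1000, size 0x2000, offset 0x5000); the end of the region is exclusive in both lookups:

package main

import "fmt"

func main() {
	const (
		baseHostVirtAddr = uintptr(0x1000)
		size             = uintptr(0x2000)
		offset           = uintptr(0x5000)
	)

	// GetOffset direction: relative offset is offset + (addr - base).
	addr := uintptr(0x1500)
	fileOff := int64(offset + addr - baseHostVirtAddr)
	fmt.Printf("addr %#x -> offset %#x\n", addr, fileOff) // 0x5500

	// GetHostVirtAddr direction: host address is base + (off - offset).
	back := baseHostVirtAddr + uintptr(fileOff) - offset
	fmt.Printf("offset %#x -> addr %#x\n", fileOff, back) // 0x1500

	// Anything at or past base+size (here 0x3000) is outside the region.
	fmt.Printf("region end (exclusive): %#x\n", baseHostVirtAddr+size)
}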
+func (m *Mapping) getHostVirtRegion(off int64) (*Region, error) { + for _, r := range m.Regions { + if off >= int64(r.Offset) && off < r.endOffset() { + return &r, nil + } + } + + return nil, fmt.Errorf("offset %d not found in any mapping", off) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go new file mode 100644 index 0000000000..ee93c360eb --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go @@ -0,0 +1,254 @@ +package memory + +import ( + "testing" +) + +func TestMapping_GetOffset(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + hostVirtAddr uintptr + expectedOffset int64 + expectedSize uint64 + expectError bool + }{ + { + name: "valid address in first region", + hostVirtAddr: 0x1500, + expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedSize: 4096, + expectError: false, + }, + { + name: "valid address in second region", + hostVirtAddr: 0x5500, + expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) + expectedSize: 4096, + expectError: false, + }, + { + name: "address before first region", + hostVirtAddr: 0x500, + expectError: true, + }, + { + name: "address after last region", + hostVirtAddr: 0x7000, + expectError: true, + }, + { + name: "address in gap between regions", + hostVirtAddr: 0x4000, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + offset, size, err := mapping.GetOffset(tt.hostVirtAddr) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if offset != tt.expectedOffset { + t.Errorf("Expected offset %d, got %d", tt.expectedOffset, offset) + } + + if size != tt.expectedSize { + t.Errorf("Expected size %d, got %d", tt.expectedSize, size) + } + }) + } +} + +func TestMapping_GetHostVirtAddr(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + offset int64 + expectedAddr int64 + expectedPageSize uint64 + expectError bool + }{ + { + name: "valid offset in first region", + offset: 0x5500, + expectedAddr: 0x1500, // 0x1000 + (0x5500 - 0x5000) + expectedPageSize: 4096, + expectError: false, + }, + { + name: "valid offset in second region", + offset: 0x8500, + expectedAddr: 0x5500, // 0x5000 + (0x8500 - 0x8000) + expectedPageSize: 4096, + expectError: false, + }, + { + name: "offset before first region", + offset: 0x4000, + expectError: true, + }, + { + name: "offset after last region", + offset: 0x10000, + expectError: true, + }, + { + name: "offset in gap between regions", + offset: 0x7000, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, pageSize, err := mapping.GetHostVirtAddr(tt.offset) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if addr != tt.expectedAddr { + t.Errorf("Expected address %d, got 
%d", tt.expectedAddr, addr) + } + + if pageSize != tt.expectedPageSize { + t.Errorf("Expected page size %d, got %d", tt.expectedPageSize, pageSize) + } + }) + } +} + +func TestMapping_EmptyRegions(t *testing.T) { + mapping := NewMapping([]Region{}) + + // Test GetOffset with empty regions + _, _, err := mapping.GetOffset(0x1000) + if err == nil { + t.Errorf("Expected error for empty regions, got none") + } + + // Test GetHostVirtAddr with empty regions + _, _, err = mapping.GetHostVirtAddr(0x1000) + if err == nil { + t.Errorf("Expected error for empty regions, got none") + } +} + +func TestMapping_OverlappingRegions(t *testing.T) { + // Test with overlapping regions (edge case) + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + { + BaseHostVirtAddr: 0x2000, // Overlaps with first region + Size: 0x1000, + Offset: 0x8000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + // The first matching region should be returned + offset, _, err := mapping.GetOffset(0x2500) // In overlap area + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Should get result from first region + expectedOffset := int64(0x5000 + (0x2500 - 0x1000)) // 0x6500 + if offset != expectedOffset { + t.Errorf("Expected offset %d, got %d", expectedOffset, offset) + } +} + +func TestMapping_BoundaryConditions(t *testing.T) { + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: 4096, + }, + } + mapping := NewMapping(regions) + + // Test exact start boundary + offset, _, err := mapping.GetOffset(0x1000) + if err != nil { + t.Errorf("Unexpected error at start boundary: %v", err) + } + expectedOffset := int64(0x5000) // 0x5000 + (0x1000 - 0x1000) + if offset != expectedOffset { + t.Errorf("Expected offset %d at start boundary, got %d", expectedOffset, offset) + } + + // Test just before end boundary (exclusive) + offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 + if err != nil { + t.Errorf("Unexpected error just before end boundary: %v", err) + } + expectedOffset = int64(0x5000 + (0x2FFF - 0x1000)) // 0x6FFF + if offset != expectedOffset { + t.Errorf("Expected offset %d just before end boundary, got %d", expectedOffset, offset) + } + + // Test exact end boundary (should fail - exclusive) + _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 + if err == nil { + t.Errorf("Expected error at end boundary (exclusive), got none") + } +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go new file mode 100644 index 0000000000..84e824cad4 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -0,0 +1,34 @@ +package memory + +// Region is a mapping of a region of memory of the guest to a region of memory on the host. +// The serialization is based on the Firecracker UFFD protocol communication. +type Region struct { + BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` + Size uintptr `json:"size"` + Offset uintptr `json:"offset"` + // This is actually in bytes. + // This field is deprecated in the newer version of the Firecracer with a new field `page_size`. + PageSize uintptr `json:"page_size_kib"` +} + +// endOffset returns the end offset of the region in bytes. +// The end offset is exclusive. +func (r *Region) endOffset() int64 { + return int64(r.Offset + r.Size) +} + +// endHostVirtAddr returns the end address of the region in host virtual address. 
+// The end address is exclusive. +func (r *Region) endHostVirtAddr() uintptr { + return r.BaseHostVirtAddr + r.Size +} + +// shiftedOffset returns the offset of the given address in the region. +func (r *Region) shiftedOffset(addr uintptr) int64 { + return int64(addr - r.BaseHostVirtAddr + r.Offset) +} + +// shiftedHostVirtAddr returns the host virtual address of the given offset in the region. +func (r *Region) shiftedHostVirtAddr(off int64) uintptr { + return uintptr(off) + r.BaseHostVirtAddr - r.Offset +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go index 4c65f5d977..84abf75a57 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go @@ -3,14 +3,14 @@ package uffd import ( "context" - "github.com/bits-and-blooms/bitset" - + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) type MemoryBackend interface { - Disable() error - Dirty() *bitset.BitSet + Dirty(ctx context.Context) (*block.Tracker, error) + // Disable switch the uffd to start serving empty pages. + Disable(ctx context.Context) error Start(ctx context.Context, sandboxId string) error Stop() error diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index 4d08459510..b03d8b257b 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -3,9 +3,7 @@ package uffd import ( "context" - "github.com/bits-and-blooms/bitset" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -13,7 +11,7 @@ type NoopMemory struct { size int64 blockSize int64 - dirty *bitset.BitSet + dirty *block.Tracker exit *utils.ErrorOnce } @@ -21,25 +19,20 @@ type NoopMemory struct { var _ MemoryBackend = (*NoopMemory)(nil) func NewNoopMemory(size, blockSize int64) *NoopMemory { - blocks := header.TotalBlocks(size, blockSize) - - dirty := bitset.New(uint(blocks)) - dirty.FlipRange(0, dirty.Len()) - return &NoopMemory{ size: size, blockSize: blockSize, - dirty: dirty, + dirty: block.NewTracker(blockSize), exit: utils.NewErrorOnce(), } } -func (m *NoopMemory) Disable() error { +func (m *NoopMemory) Disable(context.Context) error { return nil } -func (m *NoopMemory) Dirty() *bitset.BitSet { - return m.dirty +func (m *NoopMemory) Dirty(context.Context) (*block.Tracker, error) { + return m.dirty.Clone(), nil } func (m *NoopMemory) Start(context.Context, string) error { @@ -53,6 +46,7 @@ func (m *NoopMemory) Stop() error { func (m *NoopMemory) Ready() chan struct{} { ch := make(chan struct{}) ch <- struct{}{} + return ch } diff --git a/packages/orchestrator/internal/sandbox/uffd/serve.go b/packages/orchestrator/internal/sandbox/uffd/serve.go deleted file mode 100644 index c5866b52dd..0000000000 --- a/packages/orchestrator/internal/sandbox/uffd/serve.go +++ /dev/null @@ -1,200 +0,0 @@ -package uffd - -import ( - "context" - "errors" - "fmt" - "syscall" - "unsafe" - - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sys/unix" - - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - 
"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" -) - -var ErrUnexpectedEventType = errors.New("unexpected event type") - -type GuestRegionUffdMapping struct { - BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` - Size uintptr `json:"size"` - Offset uintptr `json:"offset"` - PageSize uintptr `json:"page_size_kib"` -} - -func Serve( - ctx context.Context, - uffd int, - mappings mapping.Mappings, - src block.Slicer, - fdExit *fdexit.FdExit, - logger *zap.Logger, -) error { - pollFds := []unix.PollFd{ - {Fd: int32(uffd), Events: unix.POLLIN}, - {Fd: fdExit.Reader(), Events: unix.POLLIN}, - } - - var eg errgroup.Group - - missingPagesBeingHandled := map[int64]struct{}{} - -outerLoop: - for { - if _, err := unix.Poll( - pollFds, - -1, - ); err != nil { - if err == unix.EINTR { - logger.Debug("uffd: interrupted polling, going back to polling") - - continue - } - - if err == unix.EAGAIN { - logger.Debug("uffd: eagain during polling, going back to polling") - - continue - } - - logger.Error("UFFD serve polling error", zap.Error(err)) - - return fmt.Errorf("failed polling: %w", err) - } - - exitFd := pollFds[1] - if exitFd.Revents&unix.POLLIN != 0 { - errMsg := eg.Wait() - if errMsg != nil { - logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) - - return fmt.Errorf("failed to handle uffd: %w", errMsg) - } - - return nil - } - - uffdFd := pollFds[0] - if uffdFd.Revents&unix.POLLIN == 0 { - // Uffd is not ready for reading as there is nothing to read on the fd. - // https://github.com/firecracker-microvm/firecracker/issues/5056 - // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 - // TODO: Check for all the errors - // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html - // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c - // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html - // It might be possible to just check for data != 0 in the syscall.Read loop - // but I don't feel confident about doing that. - logger.Debug("uffd: no data in fd, going back to polling") - - continue - } - - buf := make([]byte, unsafe.Sizeof(userfaultfd.UffdMsg{})) - - for { - n, err := syscall.Read(uffd, buf) - if err == syscall.EINTR { - logger.Debug("uffd: interrupted read, reading again") - - continue - } - - if err == nil { - // There is no error so we can proceed. - break - } - - if err == syscall.EAGAIN { - logger.Debug("uffd: eagain error, going back to polling", zap.Error(err), zap.Int("read_bytes", n)) - - // Continue polling the fd. 
- continue outerLoop - } - - logger.Error("uffd: read error", zap.Error(err)) - - return fmt.Errorf("failed to read: %w", err) - } - - msg := *(*userfaultfd.UffdMsg)(unsafe.Pointer(&buf[0])) - if userfaultfd.GetMsgEvent(&msg) != userfaultfd.UFFD_EVENT_PAGEFAULT { - logger.Error("UFFD serve unexpected event type", zap.Any("event_type", userfaultfd.GetMsgEvent(&msg))) - - return ErrUnexpectedEventType - } - - arg := userfaultfd.GetMsgArg(&msg) - pagefault := (*(*userfaultfd.UffdPagefault)(unsafe.Pointer(&arg[0]))) - - addr := userfaultfd.GetPagefaultAddress(&pagefault) - - offset, pagesize, err := mappings.GetRange(uintptr(addr)) - if err != nil { - logger.Error("UFFD serve get mapping error", zap.Error(err)) - - return fmt.Errorf("failed to map: %w", err) - } - - if _, ok := missingPagesBeingHandled[offset]; ok { - continue - } - - missingPagesBeingHandled[offset] = struct{}{} - - eg.Go(func() error { - defer func() { - if r := recover(); r != nil { - logger.Error("UFFD serve panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) - } - }() - - b, err := src.Slice(ctx, offset, pagesize) - if err != nil { - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(err, signalErr) - - logger.Error("UFFD serve slice error", zap.Error(joinedErr)) - - return fmt.Errorf("failed to read from source: %w", joinedErr) - } - - cpy := userfaultfd.NewUffdioCopy( - b, - addr&^userfaultfd.CULong(pagesize-1), - userfaultfd.CULong(pagesize), - 0, - 0, - ) - - if _, _, errno := syscall.Syscall( - syscall.SYS_IOCTL, - uintptr(uffd), - userfaultfd.UFFDIO_COPY, - uintptr(unsafe.Pointer(&cpy)), - ); errno != 0 { - if errno == unix.EEXIST { - logger.Debug("UFFD serve page already mapped", zap.Any("offset", offset), zap.Any("pagesize", pagesize)) - - // Page is already mapped - return nil - } - - signalErr := fdExit.SignalExit() - - joinedErr := errors.Join(errno, signalErr) - - logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) - - return fmt.Errorf("failed uffdio copy %w", joinedErr) - } - - return nil - }) - } -} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go new file mode 100644 index 0000000000..68298ea6ea --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go @@ -0,0 +1,20 @@ +package testutils + +// FirstDifferentByte returns the first byte index where a and b differ. +// It also returns the differing byte values (want, got). +// If slices are identical, it returns idx -1. +func FirstDifferentByte(a, b []byte) (idx int, want, got byte) { + smallerSize := min(len(a), len(b)) + + for i := range smallerSize { + if a[i] != b[i] { + return i, b[i], a[i] + } + } + + if len(a) != len(b) { + return smallerSize, 0, 0 + } + + return -1, 0, 0 +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go new file mode 100644 index 0000000000..d2e10ebe49 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/logger.go @@ -0,0 +1,41 @@ +package testutils + +import ( + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type testWriter struct { + t *testing.T +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + w.t.Log(string(p)) + + return len(p), nil +} + +// NewTestLogger creates a new zap logger that logs all zap logs to the test output. 
+func NewTestLogger(t *testing.T) *zap.Logger { + encoderCfg := zap.NewDevelopmentEncoderConfig() + encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder + encoderCfg.CallerKey = zapcore.OmitKey + encoderCfg.ConsoleSeparator = " " + encoderCfg.TimeKey = "" + encoderCfg.MessageKey = "message" + encoderCfg.LevelKey = "level" + encoderCfg.NameKey = "logger" + encoderCfg.StacktraceKey = "stacktrace" + encoderCfg.EncodeTime = zapcore.RFC3339NanoTimeEncoder + encoderCfg.EncodeCaller = zapcore.ShortCallerEncoder + encoderCfg.EncodeDuration = zapcore.StringDurationEncoder + + encoder := zapcore.NewConsoleEncoder(encoderCfg) + + testSyncer := zapcore.AddSync(&testWriter{t}) + core := zapcore.NewCore(encoder, testSyncer, zap.WarnLevel) + + return zap.New(core, zap.AddCaller()) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go new file mode 100644 index 0000000000..ba2cac99e1 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go @@ -0,0 +1,47 @@ +package testutils + +import ( + "context" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" +) + +// MemorySlicer exposes byte slice via the Slicer interface. +// This is used for testing purposes. +type MemorySlicer struct { + content []byte + pagesize int64 + + accessed *block.Tracker +} + +var _ block.Slicer = (*MemorySlicer)(nil) + +func newMemorySlicer(content []byte, pagesize int64) *MemorySlicer { + return &MemorySlicer{ + content: content, + pagesize: pagesize, + accessed: block.NewTracker(pagesize), + } +} + +func (s *MemorySlicer) Slice(_ context.Context, offset, size int64) ([]byte, error) { + for i := offset; i < offset+size; i += s.pagesize { + s.accessed.Add(i) + } + + return s.content[offset : offset+size], nil +} + +func (s *MemorySlicer) Size() (int64, error) { + return int64(len(s.content)), nil +} + +func (s *MemorySlicer) Content() []byte { + return s.content +} + +// Offsets returns offsets of the content that were accessed via the Slice method. 
+func (s *MemorySlicer) Accessed() *block.Tracker { + return s.accessed.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go b/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go new file mode 100644 index 0000000000..bb8837add6 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/page_mmap.go @@ -0,0 +1,44 @@ +package testutils + +import ( + "fmt" + "math" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func NewPageMmap(size, pagesize uint64) ([]byte, uintptr, func() error, error) { + if pagesize == header.PageSize { + return newMmap(size, header.PageSize, 0) + } + + if pagesize == header.HugepageSize { + return newMmap(size, header.HugepageSize, unix.MAP_HUGETLB|unix.MAP_HUGE_2MB) + } + + return nil, 0, nil, fmt.Errorf("unsupported page size: %d", pagesize) +} + +func newMmap(size, pagesize uint64, flags int) ([]byte, uintptr, func() error, error) { + l := int(math.Ceil(float64(size)/float64(pagesize)) * float64(pagesize)) + b, err := syscall.Mmap( + -1, + 0, + l, + syscall.PROT_READ|syscall.PROT_WRITE, + syscall.MAP_PRIVATE|syscall.MAP_ANONYMOUS|flags, + ) + if err != nil { + return nil, 0, nil, fmt.Errorf("failed to mmap: %w", err) + } + + closeMmap := func() error { + return syscall.Munmap(b) + } + + return b, uintptr(unsafe.Pointer(&b[0])), closeMmap, nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go b/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go new file mode 100644 index 0000000000..c8bfe8ed1e --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/random_data.go @@ -0,0 +1,17 @@ +package testutils + +import ( + "crypto/rand" +) + +func RandomPages(pagesize, numberOfPages uint64) *MemorySlicer { + size := pagesize * numberOfPages + + n := int(size) + buf := make([]byte, n) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + + return newMemorySlicer(buf, int64(pagesize)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/handler.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go similarity index 60% rename from packages/orchestrator/internal/sandbox/uffd/handler.go rename to packages/orchestrator/internal/sandbox/uffd/uffd.go index 947ff168cf..55b39a219e 100644 --- a/packages/orchestrator/internal/sandbox/uffd/handler.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -10,13 +10,13 @@ import ( "syscall" "time" - "github.com/bits-and-blooms/bitset" "go.opentelemetry.io/otel" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" - "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/mapping" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/userfaultfd" "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -26,29 +26,22 @@ var tracer = otel.Tracer("github.com/e2b-dev/infra/packages/orchestrator/interna const ( uffdMsgListenerTimeout = 10 * time.Second fdSize = 4 - mappingsSize = 1024 + regionMappingsSize = 1024 ) type Uffd struct { - exit *utils.ErrorOnce - readyCh chan struct{} - - fdExit *fdexit.FdExit - - lis *net.UnixListener - - memfile *block.TrackedSliceDevice + exit *utils.ErrorOnce + readyCh chan struct{} + fdExit 
*fdexit.FdExit + lis *net.UnixListener socketPath string + memfile block.ReadonlyDevice + handler utils.SetOnce[*userfaultfd.Userfaultfd] } var _ MemoryBackend = (*Uffd)(nil) -func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uffd, error) { - trackedMemfile, err := block.NewTrackedSliceDevice(blockSize, memfile) - if err != nil { - return nil, fmt.Errorf("failed to create tracked slice device: %w", err) - } - +func New(memfile block.ReadonlyDevice, socketPath string) (*Uffd, error) { fdExit, err := fdexit.New() if err != nil { return nil, fmt.Errorf("failed to create fd exit: %w", err) @@ -58,8 +51,8 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff exit: utils.NewErrorOnce(), readyCh: make(chan struct{}, 1), fdExit: fdExit, - memfile: trackedMemfile, socketPath: socketPath, + memfile: memfile, }, nil } @@ -106,19 +99,19 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { unixConn := conn.(*net.UnixConn) - mappingsBuf := make([]byte, mappingsSize) + regionMappingsBuf := make([]byte, regionMappingsSize) uffdBuf := make([]byte, syscall.CmsgSpace(fdSize)) - numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(mappingsBuf, uffdBuf) + numBytesMappings, numBytesFd, _, _, err := unixConn.ReadMsgUnix(regionMappingsBuf, uffdBuf) if err != nil { return fmt.Errorf("failed to read unix msg from connection: %w", err) } - mappingsBuf = mappingsBuf[:numBytesMappings] + regionMappingsBuf = regionMappingsBuf[:numBytesMappings] - var m mapping.FcMappings + var regions []memory.Region - err = json.Unmarshal(mappingsBuf, &m) + err = json.Unmarshal(regionMappingsBuf, ®ions) if err != nil { return fmt.Errorf("failed parsing memory mapping data: %w", err) } @@ -141,24 +134,47 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { return fmt.Errorf("expected 1 fd: found %d", len(fds)) } - uffd := fds[0] + m := memory.NewMapping(regions) + + uffd, err := userfaultfd.NewUserfaultfdFromFd( + uintptr(fds[0]), + u.memfile, + u.memfile.BlockSize(), + m, + zap.L().With(logger.WithSandboxID(sandboxId)), + ) + if err != nil { + return fmt.Errorf("failed to create uffd: %w", err) + } + + u.handler.SetValue(uffd) defer func() { - closeErr := syscall.Close(uffd) + closeErr := uffd.Close() if closeErr != nil { zap.L().Error("failed to close uffd", logger.WithSandboxID(sandboxId), zap.String("socket_path", u.socketPath), zap.Error(closeErr)) } }() + for _, region := range m.Regions { + // Register the WP. It is possible that the memory region was already registered (with missing pages in FC), but registering it again with bigger flag subset should merge these. 
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := uffd.Register( + region.BaseHostVirtAddr+region.Offset, + uint64(region.Size), + userfaultfd.UFFDIO_REGISTER_MODE_WP|userfaultfd.UFFDIO_REGISTER_MODE_MISSING, + ) + if err != nil { + return fmt.Errorf("failed to reregister memory region with write protection %d-%d", region.Offset, region.Offset+region.Size) + } + } + u.readyCh <- struct{}{} - err = Serve( + err = uffd.Serve( ctx, - uffd, - m, - u.memfile, u.fdExit, - zap.L().With(logger.WithSandboxID(sandboxId)), ) if err != nil { return fmt.Errorf("failed handling uffd: %w", err) @@ -175,18 +191,28 @@ func (u *Uffd) Ready() chan struct{} { return u.readyCh } -func (u *Uffd) Exit() *utils.ErrorOnce { - return u.exit -} +func (u *Uffd) Disable(ctx context.Context) error { + uffd, err := u.handler.WaitWithContext(ctx) + if err != nil { + return fmt.Errorf("failed to get uffd: %w", err) + } + + uffd.Disable() -func (u *Uffd) TrackAndReturnNil() error { - return u.lis.Close() + return nil } -func (u *Uffd) Disable() error { - return u.memfile.Disable() +func (u *Uffd) Exit() *utils.ErrorOnce { + return u.exit } -func (u *Uffd) Dirty() *bitset.BitSet { - return u.memfile.Dirty() +// Dirty waits for all the requests in flight to be finished and then returns clone of the dirty tracker. +// Call *after* pausing the firecracker process—to let the uffd process all the requests. +func (u *Uffd) Dirty(ctx context.Context) (*block.Tracker, error) { + uffd, err := u.handler.WaitWithContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get uffd: %w", err) + } + + return uffd.Dirty(ctx) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go index 63a3b3e71e..86c8633883 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/constants.go @@ -34,14 +34,13 @@ const ( UFFDIO_API = C.UFFDIO_API UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE - UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS - UFFD_FEATURE_WP_HUGETLBFS_SHMEM = C.UFFD_FEATURE_WP_HUGETLBFS_SHMEM + UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS ) type ( @@ -79,7 +78,7 @@ func NewUffdioRegister(start, length, mode CULong) UffdioRegister { func NewUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { return UffdioCopy{ src: CULong(uintptr(unsafe.Pointer(&b[0]))), - dst: address &^ (pagesize - 1), + dst: address, len: pagesize, mode: mode, copy: bytesCopied, @@ -104,14 +103,6 @@ func GetMsgArg(msg *UffdMsg) [24]byte { return msg.arg } -func GetPagefaultAddress(pagefault *UffdPagefault) CULong { - return pagefault.address -} - -func IsWritePageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 -} - -func IsWriteProtectPageFault(pagefault *UffdPagefault) bool { - return pagefault.flags&UFFD_PAGEFAULT_FLAG_WP != 0 +func GetPagefaultAddress(pagefault *UffdPagefault) uintptr { + return uintptr(pagefault.address) } diff --git 
a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid new file mode 100644 index 0000000000..4daa1238d3 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/diagram.mermaid @@ -0,0 +1,4 @@ +flowchart TD +A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page] +A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page] +E -- write (WP+WRITE flag) --> F(remove MODE_WP) --> C \ No newline at end of file diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go new file mode 100644 index 0000000000..ca7abe8121 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/serve.go @@ -0,0 +1,262 @@ +package userfaultfd + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + "go.uber.org/zap" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func (u *Userfaultfd) Serve( + ctx context.Context, + fdExit *fdexit.FdExit, +) error { + pollFds := []unix.PollFd{ + {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: fdExit.Reader(), Events: unix.POLLIN}, + } + +outerLoop: + for { + if _, err := unix.Poll( + pollFds, + -1, + ); err != nil { + if err == unix.EINTR { + u.logger.Debug("uffd: interrupted polling, going back to polling") + + continue + } + + if err == unix.EAGAIN { + u.logger.Debug("uffd: eagain during polling, going back to polling") + + continue + } + + u.logger.Error("UFFD serve polling error", zap.Error(err)) + + return fmt.Errorf("failed polling: %w", err) + } + + exitFd := pollFds[1] + if exitFd.Revents&unix.POLLIN != 0 { + errMsg := u.wg.Wait() + if errMsg != nil { + u.logger.Warn("UFFD fd exit error while waiting for goroutines to finish", zap.Error(errMsg)) + + return fmt.Errorf("failed to handle uffd: %w", errMsg) + } + + return nil + } + + uffdFd := pollFds[0] + if uffdFd.Revents&unix.POLLIN == 0 { + // Uffd is not ready for reading as there is nothing to read on the fd. + // https://github.com/firecracker-microvm/firecracker/issues/5056 + // https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c#L1149 + // TODO: Check for all the errors + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html + // - https://elixir.bootlin.com/linux/v6.8.12/source/fs/userfaultfd.c + // - https://man7.org/linux/man-pages/man2/userfaultfd.2.html + // It might be possible to just check for data != 0 in the syscall.Read loop + // but I don't feel confident about doing that. + u.logger.Debug("uffd: no data in fd, going back to polling") + + continue + } + + buf := make([]byte, unsafe.Sizeof(UffdMsg{})) + + for { + n, err := syscall.Read(int(u.fd), buf) + if err == syscall.EINTR { + u.logger.Debug("uffd: interrupted read, reading again") + + continue + } + + if err == nil { + // There is no error so we can proceed. + break + } + + if err == syscall.EAGAIN { + u.logger.Debug("uffd: eagain error, going back to polling", zap.Error(err), zap.Int("read_bytes", n)) + + // Continue polling the fd. 
+ continue outerLoop + } + + u.logger.Error("uffd: read error", zap.Error(err)) + + return fmt.Errorf("failed to read: %w", err) + } + + msg := *(*UffdMsg)(unsafe.Pointer(&buf[0])) + if GetMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT { + u.logger.Error("UFFD serve unexpected event type", zap.Any("event_type", GetMsgEvent(&msg))) + + return ErrUnexpectedEventType + } + + arg := GetMsgArg(&msg) + pagefault := (*(*UffdPagefault)(unsafe.Pointer(&arg[0]))) + flags := pagefault.flags + + addr := GetPagefaultAddress(&pagefault) + + offset, pagesize, err := u.ma.GetOffset(addr) + if err != nil { + u.logger.Error("UFFD serve get mapping error", zap.Error(err)) + + return fmt.Errorf("failed to map: %w", err) + } + + // Handle write to write protected page (WP+WRITE flag) + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 && flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { + u.handleWriteProtection(addr, offset, pagesize) + + continue + } + + // Handle write to missing page (WRITE flag) + if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { + u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize, true) + + continue + } + + // Handle read to missing page (MISSING flag) + if flags == 0 { + u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize, false) + + continue + } + + u.logger.Warn("UFFD serve unexpected event type", zap.Any("event_type", flags)) + } +} + +func (u *Userfaultfd) handleMissing( + ctx context.Context, + onFailure func() error, + addr uintptr, + offset int64, + pagesize uint64, + write bool, +) { + if write { + if !u.writeRequests.Add(offset) { + return + } + + u.writeRequestCounter.Add() + } else { + // TODO: We should be able to add the page to the missing map on the write handle as well. + if !u.missingRequests.Add(offset) { + return + } + } + + u.wg.Go(func() error { + defer func() { + if r := recover(); r != nil { + u.logger.Error("UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + } + }() + + defer func() { + if write { + u.writeRequestCounter.Done() + } + }() + + var b []byte + + if u.disabled.Load() { + b = header.EmptyHugePage[:pagesize] + } else { + sliceB, sliceErr := u.src.Slice(ctx, offset, int64(pagesize)) + if sliceErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(sliceErr, signalErr) + + u.logger.Error("UFFD serve slice error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to read from source: %w", joinedErr) + } + + b = sliceB + } + var copyMode CULong + + if !write { + copyMode = copyMode | UFFDIO_COPY_MODE_WP + } + + copyErr := u.copy(addr, b, pagesize, copyMode) + if errors.Is(copyErr, unix.EEXIST) { + u.logger.Debug("UFFD serve page already mapped", zap.Any("offset", addr), zap.Any("pagesize", pagesize)) + + // Page is already mapped + + return nil + } + + if copyErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(copyErr, signalErr) + + u.logger.Error("UFFD serve uffdio copy error", zap.Error(joinedErr)) + + return fmt.Errorf("failed uffdio copy %w", joinedErr) + } + + // We mark the page as dirty if it was a write to a page that was not already mapped. 
+ if write { + u.dirty.Add(offset) + } + + return nil + }) + +} + +func (u *Userfaultfd) handleWriteProtection(addr uintptr, offset int64, pagesize uint64) { + if !u.writeRequests.Add(offset) { + return + } + + u.writeRequestCounter.Add() + + u.wg.Go(func() error { + defer func() { + if r := recover(); r != nil { + u.logger.Error("UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) + } + }() + + defer u.writeRequestCounter.Done() + + wpErr := u.RemoveWriteProtection(addr, pagesize) + if wpErr != nil { + return fmt.Errorf("error removing write protection from page %d", addr) + } + + // We mark the page as dirty if it was a write to a page that was already mapped. + u.dirty.Add(offset) + + return nil + }) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go new file mode 100644 index 0000000000..c5f3684c67 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/syscalls.go @@ -0,0 +1,96 @@ +package userfaultfd + +import ( + "fmt" + "syscall" + "unsafe" + + "go.uber.org/zap" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newUserfaultfd(flags uintptr, src block.Slicer, pagesize int64, m *memory.Mapping, logger *zap.Logger) (*Userfaultfd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return nil, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return NewUserfaultfdFromFd(uffd, src, pagesize, m, logger) +} + +// features: UFFD_FEATURE_MISSING_HUGETLBFS +// This is already called by the FC +func (u *Userfaultfd) configureApi(pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := NewUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (u *Userfaultfd) Register(addr uintptr, size uint64, mode CULong) error { + register := NewUffdioRegister(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (u *Userfaultfd) writeProtect(addr uintptr, size uint64, mode CULong) error { + register := NewUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (u *Userfaultfd) RemoveWriteProtection(addr uintptr, size uint64) error { + return u.writeProtect(addr, size, 0) +} + +func (u *Userfaultfd) AddWriteProtection(addr 
uintptr, size uint64) error { + return u.writeProtect(addr, size, UFFDIO_WRITEPROTECT_MODE_WP) +} + +// mode: UFFDIO_COPY_MODE_WP +// When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page +func (u *Userfaultfd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error { + cpy := NewUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) + + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, u.fd, UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + return errno + } + + // Check if the copied size matches the requested pagesize + if uint64(cpy.copy) != pagesize { + return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) + } + + return nil +} + +func (u *Userfaultfd) Close() error { + return syscall.Close(int(u.fd)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go new file mode 100644 index 0000000000..5d14ec021d --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -0,0 +1,71 @@ +package userfaultfd + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +var ErrUnexpectedEventType = errors.New("unexpected event type") + +type Userfaultfd struct { + fd uintptr + + src block.Slicer + ma *memory.Mapping + dirty *block.Tracker + disabled atomic.Bool + + // The maps prevent serving pages multiple times (as we now add WP only once we don't have to remove entries from any map.) + // For normal sized pages with swap on, the behavior seems not to be properly described in docs + // and it's not clear if the missing can be legitimately triggered multiple times. + missingRequests *block.Tracker + writeRequests *block.Tracker + wpRequests *block.Tracker + + writeRequestCounter *utils.SettleCounter + wg errgroup.Group + + logger *zap.Logger +} + +// NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. 
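A quick, self-contained illustration of the page-base mask copy() uses (addr &^ (pagesize-1)); it assumes the page size is a power of two, which holds for the 4 KiB and 2 MiB sizes this package works with:

package main

import "fmt"

// pageBase rounds a faulting address down to the start of its page by
// clearing the low bits; UFFDIO_COPY needs the page start as its destination.
func pageBase(addr, pagesize uint64) uint64 {
	return addr &^ (pagesize - 1)
}

func main() {
	const pageSize = uint64(4096)
	const hugepageSize = uint64(2 << 20) // 2 MiB

	addr := uint64(0x7f001234)
	fmt.Printf("%#x -> %#x (4k page)\n", addr, pageBase(addr, pageSize))      // 0x7f001000
	fmt.Printf("%#x -> %#x (hugepage)\n", addr, pageBase(addr, hugepageSize)) // 0x7f000000
}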
+func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, pagesize int64, m *memory.Mapping, logger *zap.Logger) (*Userfaultfd, error) { + return &Userfaultfd{ + fd: fd, + src: src, + dirty: block.NewTracker(pagesize), + missingRequests: block.NewTracker(pagesize), + writeRequests: block.NewTracker(pagesize), + wpRequests: block.NewTracker(pagesize), + disabled: atomic.Bool{}, + ma: m, + writeRequestCounter: utils.NewZeroSettleCounter(), + logger: logger, + }, nil +} + +func (u *Userfaultfd) Disable() { + u.disabled.Store(true) +} + +func (u *Userfaultfd) Dirty(ctx context.Context) (*block.Tracker, error) { + err := u.writeRequestCounter.Wait(ctx) + if err != nil { + return nil, fmt.Errorf("failed to wait for write requests: %w", err) + } + + u.missingRequests.Reset() + u.writeRequests.Reset() + u.wpRequests.Reset() + + return u.dirty.Clone(), nil +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go new file mode 100644 index 0000000000..6313c46bff --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go @@ -0,0 +1,456 @@ +package userfaultfd + +import ( + "bytes" + "context" + "fmt" + "slices" + "syscall" + "testing" + + "github.com/bits-and-blooms/bitset" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/fdexit" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +type testConfig struct { + name string + // Page size of the memory area. + pagesize uint64 + // Number of pages in the memory area. + numberOfPages uint64 + // Operations to trigger on the memory area. + operations []operation +} + +type operationMode uint32 + +const ( + operationModeRead operationMode = 1 << iota + operationModeWrite +) + +type operation struct { + // Offset in bytes. Must be smaller than the (numberOfPages-1) * pagesize as it reads a page and it must be aligned to the pagesize from the testConfig. 
+ offset int64 + mode operationMode +} + +func TestUffdMissing(t *testing.T) { + tests := []testConfig{ + { + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at middle", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at last page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at start", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at middle", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at last page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := configureTest(t, tt) + + for _, operation := range tt.operations { + if operation.mode == operationModeRead { + err := h.executeRead(t.Context(), operation) + require.NoError(t, err) + } + } + + err := h.uffd.writeRequestCounter.Wait(t.Context()) + require.NoError(t, err) + + expectedReadOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + assert.Equal(t, expectedReadOffsets, h.getReadOffsets(), "checking which pages were faulted)") + }) + } +} + +func TestUffdWriteProtection(t *testing.T) { + tests := []testConfig{ + { + name: "standard 4k page, single write", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, single read then write on first page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, single write then read on first page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, single read then write on non-first page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, two writes on different pages", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 16 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single write", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single read then write on first page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + 
offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single read then write on non-first page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, single write then read on non-first page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, two writes on different pages", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := configureTest(t, tt) + + for _, operation := range tt.operations { + if operation.mode == operationModeRead { + err := h.executeRead(t.Context(), operation) + assert.NoError(t, err) + } + + if operation.mode == operationModeWrite { + err := h.executeWrite(t.Context(), operation) + assert.NoError(t, err) + } + } + + err := h.uffd.writeRequestCounter.Wait(t.Context()) + require.NoError(t, err) + + expectedWriteOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + assert.Equal(t, expectedWriteOffsets, h.getWriteOffsets(), "checking which pages were written to") + + expectedReadOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + assert.Equal(t, expectedReadOffsets, h.getReadOffsets(), "checking which pages were faulted)") + }) + } +} + +type testHandler struct { + memoryArea *[]byte + pagesize uint64 + data *testutils.MemorySlicer + memoryMap *memory.Mapping + uffd *Userfaultfd +} + +func (h *testHandler) getReadOffsets() []uint { + return utils.Map(slices.Collect(h.uffd.missingRequests.BitSet().Union(h.uffd.writeRequests.BitSet()).EachSet()), func(offset uint) uint { + return uint(header.BlockOffset(int64(offset), int64(h.pagesize))) + }) +} + +func (h *testHandler) getWriteOffsets() []uint { + return utils.Map(slices.Collect(h.uffd.dirty.BitSet().EachSet()), func(offset uint) uint { + return uint(header.BlockOffset(int64(offset), int64(h.pagesize))) + }) +} + +func (h *testHandler) executeRead(ctx context.Context, op operation) error { + readBytes := (*h.memoryArea)[op.offset : op.offset+int64(h.pagesize)] + + expectedBytes, err := h.data.Slice(ctx, op.offset, int64(h.pagesize)) + if err != nil { + return err + } + + if !bytes.Equal(readBytes, expectedBytes) { + idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) + + return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) + } + + return nil +} + +func (h *testHandler) executeWrite(ctx context.Context, op operation) error { + bytesToWrite, err := h.data.Slice(ctx, int64(op.offset), int64(h.pagesize)) + if err != nil { + return err + } + + n := copy((*h.memoryArea)[op.offset:op.offset+int64(h.pagesize)], bytesToWrite) + if n != int(h.pagesize) { + return fmt.Errorf("copy length mismatch: want %d, got %d", h.pagesize, n) + } + + err = h.uffd.writeRequestCounter.Wait(ctx) + if err != nil { + return fmt.Errorf("failed to wait for write requests finish: %w", err) + } + + if 
!h.uffd.dirty.Has(op.offset) { + return fmt.Errorf("dirty bit not set for page at offset %d, all dirty offsets: %v", op.offset, h.getWriteOffsets()) + } + + return nil +} + +func configureTest(t *testing.T, tt testConfig) *testHandler { + data := testutils.RandomPages(tt.pagesize, tt.numberOfPages) + + size, err := data.Size() + require.NoError(t, err) + + memoryArea, memoryStart, unmap, err := testutils.NewPageMmap(uint64(size), tt.pagesize) + require.NoError(t, err) + + t.Cleanup(func() { + unmap() + }) + + m := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: uintptr(memoryStart), + Size: uintptr(size), + Offset: uintptr(0), + PageSize: uintptr(tt.pagesize), + }, + }) + + logger := testutils.NewTestLogger(t) + + uffd, err := newUserfaultfd(syscall.O_CLOEXEC|syscall.O_NONBLOCK, data, int64(tt.pagesize), m, logger) + require.NoError(t, err) + + t.Cleanup(func() { + uffd.Close() + }) + + err = uffd.configureApi(tt.pagesize) + require.NoError(t, err) + + err = uffd.Register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING|UFFDIO_REGISTER_MODE_WP) + require.NoError(t, err) + + fdExit, err := fdexit.New() + require.NoError(t, err) + + t.Cleanup(func() { + fdExit.SignalExit() + fdExit.Close() + }) + + exitUffd := make(chan struct{}, 1) + + go func() { + err := uffd.Serve(t.Context(), fdExit) + assert.NoError(t, err) + + exitUffd <- struct{}{} + }() + + t.Cleanup(func() { + signalExitErr := fdExit.SignalExit() + require.NoError(t, signalExitErr) + + select { + case <-exitUffd: + case <-t.Context().Done(): + t.Log("context done before exit:", t.Context().Err()) + } + }) + + return &testHandler{ + memoryArea: &memoryArea, + memoryMap: m, + pagesize: tt.pagesize, + data: data, + uffd: uffd, + } +} + +// Get a bitset of the offsets of the operations for the given mode. +func getOperationsOffsets(ops []operation, m operationMode) []uint { + b := bitset.New(0) + + for _, operation := range ops { + if operation.mode&m != 0 { + b.Set(uint(operation.offset)) + } + } + + return slices.Collect(b.EachSet()) +} + +// TODO: Test write protection double registration (with missing) to simulate the FC situation diff --git a/packages/shared/pkg/utils/map_keys.go b/packages/shared/pkg/utils/map_keys.go new file mode 100644 index 0000000000..a2d4b410d2 --- /dev/null +++ b/packages/shared/pkg/utils/map_keys.go @@ -0,0 +1,10 @@ +package utils + +func MapKeys[K comparable](m map[K]struct{}) []K { + keys := make([]K, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + + return keys +} diff --git a/packages/shared/pkg/utils/settle_counter.go b/packages/shared/pkg/utils/settle_counter.go new file mode 100644 index 0000000000..26bfd83733 --- /dev/null +++ b/packages/shared/pkg/utils/settle_counter.go @@ -0,0 +1,68 @@ +package utils + +import ( + "context" + "sync" + "sync/atomic" +) + +type SettleCounter struct { + counter atomic.Int64 + cond sync.Cond + settleValue int64 +} + +// NewZeroSettleCounter creates a new SettleCounter that settles when the counter is zero. +func NewZeroSettleCounter() *SettleCounter { + return &SettleCounter{ + counter: atomic.Int64{}, + cond: *sync.NewCond(&sync.Mutex{}), + settleValue: 0, + } +} + +func (w *SettleCounter) add(delta int64) { + if w.counter.Add(delta) == w.settleValue { + w.cond.Broadcast() + } +} + +func (w *SettleCounter) Add() { + w.add(1) +} + +func (w *SettleCounter) Done() { + w.add(-1) +} + +// Wait waits for the counter to be the settle value. 
+func (w *SettleCounter) Wait(ctx context.Context) error {
+	// Ensure we can break out of this Wait when the context is done.
+	go func() {
+		<-ctx.Done()
+
+		w.cond.Broadcast()
+	}()
+
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	for w.counter.Load() != w.settleValue {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+
+		// Wait releases the lock while sleeping and reacquires it before returning.
+		w.cond.Wait()
+	}
+
+	return nil
+}
+
+// Close forces the counter to the settle value and wakes all waiters.
+func (w *SettleCounter) Close() {
+	w.counter.Store(w.settleValue)
+
+	w.cond.Broadcast()
+}
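
For reference, a minimal sketch of how the SettleCounter introduced above is meant to be used: producers bracket each unit of in-flight work with Add/Done, and Wait blocks until the counter settles back to zero (or the context is done). The page loop and the sleeping goroutine are hypothetical stand-ins, not the orchestrator's real uffd serve loop; only the SettleCounter API itself (NewZeroSettleCounter, Add, Done, Wait) comes from this diff.

// Hedged usage sketch; assumes only the SettleCounter API added in this change.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
)

func main() {
	c := utils.NewZeroSettleCounter()

	// Each unit of in-flight work increments the counter before it starts and
	// decrements it when it finishes, mirroring how fault handling could be
	// bracketed with Add/Done.
	for page := 0; page < 4; page++ {
		c.Add()

		go func(page int) {
			defer c.Done()

			time.Sleep(10 * time.Millisecond) // stand-in for resolving a fault
			fmt.Println("resolved page", page)
		}(page)
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Wait blocks until every Add has a matching Done (the counter is back at
	// its settle value of zero) or the context is done.
	if err := c.Wait(ctx); err != nil {
		fmt.Println("wait failed:", err)
		return
	}

	fmt.Println("all work settled")
}

Note that Add must happen before Wait is called for the work to be accounted for; Wait returns immediately when the counter is already at its settle value.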