// QuickSelectFunc returns the element that would be stored at index pivot if
// the window arr[lo..hi] (both bounds inclusive) were sorted according to
// compare. It uses the Quickselect algorithm.
//
// compare must return a negative value when a orders before b, zero when the
// two are equivalent and a positive value otherwise.
//
// pivot is an absolute index into arr and is expected to satisfy
// lo <= pivot <= hi; for a pivot outside the window the returned value is not
// a meaningful order statistic.
//
// The slice is modified in place and is only partially ordered on return:
// inside the window, no element left of pivot orders after arr[pivot] and no
// element right of pivot orders before it. Elements outside the window are not
// touched. Out-of-range indexes cause a panic.
func QuickSelectFunc[T any](arr []T, lo int, hi int, pivot int, compare func(a, b T) int) T {
	for lo < hi {
		j := partitionFunc(arr, lo, hi, compare)
		switch {
		case j == pivot:
			// The partitioning element landed exactly on the wanted rank.
			return arr[pivot]
		case j > pivot:
			// Wanted rank lies in the left part.
			hi = j - 1
		default:
			// Wanted rank lies in the right part.
			lo = j + 1
		}
	}
	return arr[pivot]
}

// partitionFunc rearranges arr[lo..hi] (inclusive, lo <= hi) around the
// partitioning element v = arr[lo] and returns the final index j of v.
// Afterwards no element of arr[lo..j-1] orders after v and no element of
// arr[j+1..hi] orders before v.
//
// Both scans are bounded by the window, so a one-element window is handled
// without reading arr[lo+1].
func partitionFunc[T any](arr []T, lo int, hi int, compare func(a, b T) int) int {
	v := arr[lo]
	i, j := lo, hi+1
	for {
		// Move i right to the next element that does not order before v,
		// stopping at hi at the latest.
		i++
		for i < hi && compare(arr[i], v) < 0 {
			i++
		}
		// Move j left to the next element that does not order after v.
		// arr[lo] holds v itself, so lo is a natural stopping point.
		j--
		for j > lo && compare(v, arr[j]) < 0 {
			j--
		}
		if i >= j {
			break
		}
		arr[i], arr[j] = arr[j], arr[i]
	}
	// Put the partitioning element into its final position.
	arr[lo], arr[j] = arr[j], arr[lo]
	return j
}
lo: 0, + hi: 1, + pivot: 1, + expected: 100, + }, + { + name: "find median", + arr: []testEntry{{hash: 3}, {hash: 1}, {hash: 4}, {hash: 1}, {hash: 5}, {hash: 9}, {hash: 2}, {hash: 6}}, + lo: 0, + hi: 7, + pivot: 4, + expected: 4, + }, + { + name: "find minimum", + arr: []testEntry{{hash: 3}, {hash: 1}, {hash: 4}, {hash: 1}, {hash: 5}, {hash: 9}, {hash: 2}, {hash: 6}}, + lo: 0, + hi: 7, + pivot: 0, + expected: 1, + }, + { + name: "find maximum", + arr: []testEntry{{hash: 3}, {hash: 1}, {hash: 4}, {hash: 1}, {hash: 5}, {hash: 9}, {hash: 2}, {hash: 6}}, + lo: 0, + hi: 7, + pivot: 7, + expected: 9, + }, + { + name: "single element", + arr: []testEntry{{hash: 42}}, + lo: 0, + hi: 0, + pivot: 0, + expected: 42, + }, + { + name: "two elements descending", + arr: []testEntry{{hash: 5}, {hash: 3}}, + lo: 0, + hi: 1, + pivot: 0, + expected: 3, + }, + { + name: "with summary data", + arr: []testEntry{{hash: 30, summary: "a"}, {hash: 10, summary: "b"}, {hash: 20, summary: "c"}}, + lo: 0, + hi: 2, + pivot: 1, + expected: 20, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + arrCopy := make([]testEntry, len(tc.arr)) + copy(arrCopy, tc.arr) + + result := QuickSelectFunc(arrCopy, tc.lo, tc.hi, tc.pivot, func(a, b testEntry) int { + if a.hash < b.hash { + return -1 + } else if a.hash > b.hash { + return 1 + } + return 0 + }) + + assert.Equal(t, tc.expected, result.hash, "want: %v\ngot: %v", tc.expected, result.hash) + }) + } +} diff --git a/tuple/hashtable.go b/tuple/hashtable.go new file mode 100644 index 0000000..e0b5fef --- /dev/null +++ b/tuple/hashtable.go @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tuple + +import ( + "errors" + "fmt" + "math" + + "github.com/apache/datasketches-go/internal" + "github.com/apache/datasketches-go/theta" +) + +const ( + resizeThreshold = 0.5 + rebuildThreshold = 15.0 / 16.0 +) + +const ( + strideHashBits = 7 + strideMask = (1 << strideHashBits) - 1 +) + +var ( + ErrKeyNotFound = errors.New("key not found") + ErrKeyNotFoundAndNoEmptySlots = errors.New("key not found and no empty slots") + // ErrZeroHashValue is used to indicate that the hash value is zero. + // Zero is a reserved value for empty slots in the hash table. 
+ ErrZeroHashValue = errors.New("zero hash value") +) + +type entry[S Summary] struct { + Hash uint64 + Summary S +} + +func (e *entry[S]) reset() { + if e.Hash != 0 { + e.Summary.Reset() + } + e.Hash = 0 +} + +type hashtable[S Summary] struct { + entries []entry[S] + entryLessFunc func(a, b entry[S]) int + theta uint64 + seed uint64 + numEntries uint32 + p float32 + lgCurSize uint8 + lgNomSize uint8 + rf theta.ResizeFactor + isEmpty bool +} + +func newHashtable[S Summary](lgCurSize, lgNomSize uint8, rf theta.ResizeFactor, p float32, theta, seed uint64, isEmpty bool) *hashtable[S] { + sketch := &hashtable[S]{ + isEmpty: isEmpty, + lgCurSize: lgCurSize, + lgNomSize: lgNomSize, + rf: rf, + p: p, + numEntries: 0, + theta: theta, + seed: seed, + entries: nil, + entryLessFunc: func(a, b entry[S]) int { + if a.Hash < b.Hash { + return -1 + } else if a.Hash > b.Hash { + return 1 + } + return 0 + }, + } + + if lgCurSize > 0 { + size := 1 << lgCurSize + sketch.entries = make([]entry[S], size) + } + + return sketch +} + +// HashStringAndScreen computes the hash of string and checks if it passes theta threshold +func (t *hashtable[S]) HashStringAndScreen(data string) (uint64, error) { + t.isEmpty = false + h1, _ := internal.HashCharSliceMurmur3([]byte(data), 0, len(data), t.seed) + hash := h1 >> 1 + if hash >= t.theta { + return 0, fmt.Errorf("hash %d is greater than or equal to theta %d", hash, t.theta) + } + if hash == 0 { + return 0, ErrZeroHashValue + } + return hash, nil +} + +// HashInt32AndScreen computes the hash of int32 and checks if it passes theta threshold +func (t *hashtable[S]) HashInt32AndScreen(data int32) (uint64, error) { + t.isEmpty = false + h1, _ := internal.HashInt32SliceMurmur3([]int32{data}, 0, 1, t.seed) + hash := h1 >> 1 + if hash >= t.theta { + return 0, fmt.Errorf("hash %d is greater than or equal to theta %d", hash, t.theta) + } + if hash == 0 { + return 0, ErrZeroHashValue + } + return hash, nil +} + +// HashInt64AndScreen computes the hash of 
int64 and checks if it passes theta threshold +func (t *hashtable[S]) HashInt64AndScreen(data int64) (uint64, error) { + t.isEmpty = false + h1, _ := internal.HashInt64SliceMurmur3([]int64{data}, 0, 1, t.seed) + hash := h1 >> 1 + if hash >= t.theta { + return 0, fmt.Errorf("hash %d is greater than or equal to theta %d", hash, t.theta) + } + if hash == 0 { + return 0, ErrZeroHashValue + } + return hash, nil +} + +// HashBytesAndScreen computes the hash of bytes and checks if it passes theta threshold +func (t *hashtable[S]) HashBytesAndScreen(data []byte) (uint64, error) { + t.isEmpty = false + h1, _ := internal.HashByteArrMurmur3(data, 0, len(data), t.seed) + hash := h1 >> 1 + if hash >= t.theta { + return 0, fmt.Errorf("hash %d is greater than or equal to theta %d", hash, t.theta) + } + if hash == 0 { + return 0, ErrZeroHashValue + } + return hash, nil +} + +// Find searches for an entry in the hash table and returns the index if found, +// or an error if not found +func (t *hashtable[S]) Find(key uint64) (int, error) { + return find(t.entries, t.lgCurSize, key) +} + +func find[S Summary](entries []entry[S], lgSize uint8, key uint64) (int, error) { + size := uint32(1 << lgSize) + mask := size - 1 + stride := computeStride(key, lgSize) + index := uint32(key) & mask + + loopIndex := index + for { + probe := entries[index] + if probe.Hash == 0 { + return int(index), ErrKeyNotFound + } else if probe.Hash == key { + return int(index), nil + } + + index = (index + stride) & mask + if index == loopIndex { + return 0, ErrKeyNotFoundAndNoEmptySlots + } + } +} + +// computeStride computes the stride for probing +func computeStride(key uint64, lgSize uint8) uint32 { + // odd and independent of the index assuming lg_size lowest bits of the key were used for the index + return (2 * uint32((key>>lgSize)&strideMask)) + 1 +} + +// Insert inserts an entry at the given index +func (t *hashtable[S]) Insert(index int, entry entry[S]) { + t.entries[index] = entry + t.numEntries++ + + 
if t.numEntries > computeCapacity(t.lgCurSize, t.lgNomSize) { + if t.lgCurSize <= t.lgNomSize { + t.resize() + } else { + t.rebuild() + } + } +} + +func computeCapacity(lgCurSize, lgNomSize uint8) uint32 { + var fraction float64 + if lgCurSize <= lgNomSize { + fraction = resizeThreshold + } else { + fraction = rebuildThreshold + } + return uint32(math.Floor(fraction * float64(uint32(1)< uint32(1<= 0 && sketch.entries[index] == insertedEntry { + foundCount++ + } + } + + assert.Greater(t, foundCount, 0, "Some entries should still be accessible after rebuild") + }) +} + +func TestHashtable_Trim(t *testing.T) { + t.Run("rebuild", func(t *testing.T) { + lgNomSize := uint8(3) + lgCurSize := uint8(5) + sketch := newHashtable[*float64Summary](lgCurSize, lgNomSize, theta.ResizeX2, 1.0, theta.MaxTheta, theta.DefaultSeed, true) + + // Insert entries exceeding nominal size + numToInsert := 20 + for i := 0; i < numToInsert; i++ { + e := entry[*float64Summary]{ + Hash: uint64(i + 5000), + } + index, err := sketch.Find(e.Hash) + if err == nil { + continue + } + + sketch.entries[index] = e + sketch.numEntries++ + } + + initialNumEntries := sketch.numEntries + nominalSize := uint32(1 << lgNomSize) + + assert.Greater(t, initialNumEntries, nominalSize, "numEntries should exceed nominal size before Trim") + + sketch.Trim() + + assert.Equal(t, nominalSize, sketch.numEntries, "After Trim, numEntries should equal nominal size") + assert.Less(t, sketch.theta, theta.MaxTheta, "Theta should decrease after Trim") + }) + + t.Run("no op", func(t *testing.T) { + lgNomSize := uint8(4) + lgCurSize := uint8(4) + sketch := newHashtable[*float64Summary](lgCurSize, lgNomSize, theta.ResizeX2, 1.0, theta.MaxTheta, theta.DefaultSeed, true) + + // Insert fewer entries than the nominal size + numToInsert := 5 + for i := 0; i < numToInsert; i++ { + e := entry[*float64Summary]{ + Hash: uint64(i + 6000), + } + index, err := sketch.Find(e.Hash) + if err == nil { + continue + } + + sketch.entries[index] = e + 
sketch.numEntries++ + } + + initialNumEntries := sketch.numEntries + initialTheta := sketch.theta + nominalSize := uint32(1 << lgNomSize) + + assert.Less(t, initialNumEntries, nominalSize, "numEntries should be less than nominal size") + + sketch.Trim() + + assert.Equal(t, initialNumEntries, sketch.numEntries, "numEntries should not change when less than nominal size") + assert.Equal(t, initialTheta, sketch.theta, "Theta should not change when entries <= nominal size") + }) +} + +func TestHashtable_Reset(t *testing.T) { + sketch := newHashtable[*float64Summary](4, 4, theta.ResizeX1, 0.5, theta.MaxTheta, theta.DefaultSeed, false) + + sketch.entries[0] = entry[*float64Summary]{ + Hash: uint64(100), + } + sketch.entries[5] = entry[*float64Summary]{ + Hash: uint64(200), + } + sketch.numEntries = 2 + sketch.isEmpty = false + + sketch.Reset() + + assert.True(t, sketch.isEmpty) + assert.Zero(t, sketch.numEntries) + // Verify all entries are zero + for i, e := range sketch.entries { + assert.Zero(t, e, "entry at index %d should be zero after reset", i) + } + + expectedTheta := startingThetaFromP(sketch.p) + assert.Equal(t, expectedTheta, sketch.theta, "theta should be %d after reset", expectedTheta) +} diff --git a/tuple/sketch.go b/tuple/sketch.go new file mode 100644 index 0000000..822fe62 --- /dev/null +++ b/tuple/sketch.go @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tuple + +import ( + "iter" +) + +// Summary is the base interface for all summary types used in tuple sketches. +// A summary holds aggregate data associated with each retained hash key. +type Summary interface { + // Reset clears the content of the summary, restoring it to its initial state. + Reset() + // Clone creates and returns a deep copy of the current Summary instance. + Clone() Summary +} + +// Sketch is the base interface for tuple sketches. +// It extends Theta sketch to associate arbitrary summaries with each retained key. +type Sketch[S Summary] interface { + // IsEmpty reports whether this sketch represents an empty set. + // Note: this is not the same as having no retained hashes. + IsEmpty() bool + + // Estimate returns the estimated distinct count of the input stream. + Estimate() float64 + + // LowerBoundFromSubset returns the approximate lower error bound for + // the given number of standard deviations over a subset of retained hashes. + // numStdDevs specifies the confidence level (1, 2, or 3) corresponding to + // approximately 67%, 95%, or 99% confidence intervals. + // numSubsetEntries specifies number of items from {0, 1, ..., get_num_retained()} + // over which to estimate the bound. + LowerBoundFromSubset(numStdDevs uint8, numSubsetEntries uint32) (float64, error) + + // LowerBound returns the approximate lower error bound for the given + // number of standard deviations. numStdDevs should be 1, 2, or 3 for + // approximately 67%, 95%, or 99% confidence intervals. 
+ LowerBound(numStdDevs uint8) (float64, error) + + // UpperBoundFromSubset returns the approximate upper error bound for + // the given number of standard deviations over a subset of retained hashes. + // numStdDevs specifies the confidence level (1, 2, or 3) corresponding to + // approximately 67%, 95%, or 99% confidence intervals. + // numSubsetEntries specifies number of items from {0, 1, ..., get_num_retained()} + // over which to estimate the bound. + UpperBoundFromSubset(numStdDevs uint8, numSubsetEntries uint32) (float64, error) + + // UpperBound returns the approximate upper error bound for the given + // number of standard deviations. numStdDevs should be 1, 2, or 3 for + // approximately 67%, 95%, or 99% confidence intervals. + UpperBound(numStdDevs uint8) (float64, error) + + // IsEstimationMode reports whether the sketch is in estimation mode, + // as opposed to exact mode. + IsEstimationMode() bool + + // Theta returns theta as a fraction from 0 to 1, representing the + // effective sampling rate. + Theta() float64 + + // Theta64 returns theta as a positive integer between 0 and math.MaxUint64. + Theta64() uint64 + + // NumRetained returns the number of hashes retained in the sketch. + NumRetained() uint32 + + // SeedHash returns the hash of the seed used to hash the input. + SeedHash() (uint16, error) + + // IsOrdered reports whether retained hashes are sorted by hash value. + IsOrdered() bool + + // String returns a human-readable summary of this sketch. + // If printItems is true, the output includes all retained hashes. + String(shouldPrintItems bool) string + + // All returns an iterator over all hash-summary pairs in the sketch. + All() iter.Seq2[uint64, S] +} diff --git a/tuple/testing.go b/tuple/testing.go new file mode 100644 index 0000000..ff05207 --- /dev/null +++ b/tuple/testing.go @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tuple + +type int32Summary struct { + value int32 +} + +func (s *int32Summary) Reset() { + s.value = 0 +} + +func (s *int32Summary) Clone() Summary { + return &int32Summary{ + value: s.value, + } +} + +func (s *int32Summary) Update(value int32) { + s.value += value +} + +func newInt32Summary() *int32Summary { + return &int32Summary{} +} + +type float64Summary struct { + value float64 +} + +func (s *float64Summary) Reset() { + s.value = 0 +} + +func (s *float64Summary) Clone() Summary { + return &float64Summary{value: s.value} +} + +func (s *float64Summary) Update(value float64) { + s.value += value +} + +func newFloat64Summary() *float64Summary { + return &float64Summary{} +}