Skip to content

Commit cec377d

Browse files
committed
Add support for "SHA256CTR" filters
1 parent b5469a2 commit cec377d

File tree

3 files changed

+77
-17
lines changed

3 files changed

+77
-17
lines changed

filtercascade/__init__.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
log = logging.getLogger(__name__)
1616

1717

18+
def byte_length(bit_length):
19+
return (bit_length + 7) // 8
20+
21+
1822
class InvertedLogicException(Exception):
1923
def __init__(self, *, depth, exclude_count, include_len):
2024
self.message = (
@@ -54,12 +58,14 @@ def __init__(
5458
nHashFuncs,
5559
level,
5660
hashAlg=fileformats.HashAlgorithm.MURMUR3,
61+
hashOffset=0,
5762
salt=None,
5863
):
5964
self.nHashFuncs = nHashFuncs
6065
self.size = size
6166
self.level = level
6267
self.hashAlg = fileformats.HashAlgorithm(hashAlg)
68+
self.hashOffset = hashOffset
6369
self.salt = salt
6470

6571
self.bitarray = bitarray.bitarray(self.size, endian="little")
@@ -99,6 +105,23 @@ def hash(self, *, hash_no, key):
99105
)
100106
return h
101107

108+
if self.hashAlg == fileformats.HashAlgorithm.SHA256CTR:
109+
b = []
110+
bytes_needed = byte_length(self.size.bit_length())
111+
offset = self.hashOffset + hash_no * bytes_needed
112+
while len(b) < bytes_needed:
113+
m = hashlib.sha256()
114+
m.update(fileformats.bloomer_sha256ctr_hash_struct.pack(offset // 32))
115+
m.update(self.salt)
116+
m.update(key)
117+
digest = m.digest()
118+
i = offset % 32
119+
x = digest[i : i + bytes_needed - len(b)]
120+
b.extend(x)
121+
offset += len(x)
122+
h = int.from_bytes(b, byteorder="little", signed=False) % self.size
123+
return h
124+
102125
raise Exception(f"Unknown hash algorithm: {self.hashAlg}")
103126

104127
def add(self, key):
@@ -136,13 +159,19 @@ def filter_with_characteristics(
136159
elements,
137160
falsePositiveRate,
138161
hashAlg=fileformats.HashAlgorithm.MURMUR3,
162+
hashOffset=0,
139163
salt=None,
140164
level=1,
141165
):
142166
nHashFuncs = Bloomer.calc_n_hashes(falsePositiveRate)
143167
size = Bloomer.calc_size(nHashFuncs, elements, falsePositiveRate)
144168
return Bloomer(
145-
size=size, nHashFuncs=nHashFuncs, level=level, hashAlg=hashAlg, salt=salt
169+
size=size,
170+
nHashFuncs=nHashFuncs,
171+
level=level,
172+
hashAlg=hashAlg,
173+
hashOffset=hashOffset,
174+
salt=salt,
146175
)
147176

148177
@classmethod
@@ -161,7 +190,7 @@ def calc_size(cls, nHashFuncs, elements, falsePositiveRate):
161190
min_bits = math.ceil(1.44 * elements * math.log2(1 / falsePositiveRate))
162191
assert min_bits > 0, "Always must have a positive number of bits"
163192
# Ensure the result is divisible by 8 for full bytes
164-
return 8 * math.ceil(min_bits / 8)
193+
return 8 * byte_length(min_bits)
165194

166195
@classmethod
167196
def from_buf(cls, buf, salt=None):
@@ -206,10 +235,10 @@ def __init__(
206235
invertedLogic=None,
207236
):
208237
"""
209-
Construct a FilterCascade.
210-
error_rates: If not supplied, defaults will be calculated
211-
invertedLogic: If not supplied (or left as None), it will be auto-
212-
detected.
238+
Construct a FilterCascade.
239+
error_rates: If not supplied, defaults will be calculated
240+
invertedLogic: If not supplied (or left as None), it will be auto-
241+
detected.
213242
"""
214243
self.filters = filters or []
215244
self.growth_factor = growth_factor
@@ -250,10 +279,10 @@ def set_crlite_error_rates(self, *, include_len, exclude_len):
250279

251280
def initialize(self, *, include, exclude):
252281
"""
253-
Arg "exclude" is potentially larger than main memory, so it should
254-
be assumed to be passed as a lazy-loading iterator. If it isn't,
255-
that's fine. The "include" arg must fit in memory and should be
256-
assumed to be a set.
282+
Arg "exclude" is potentially larger than main memory, so it should
283+
be assumed to be passed as a lazy-loading iterator. If it isn't,
284+
that's fine. The "include" arg must fit in memory and should be
285+
assumed to be a set.
257286
"""
258287
try:
259288
iter(exclude)
@@ -286,6 +315,13 @@ def initialize(self, *, include, exclude):
286315
er = self.error_rates[depth - 1]
287316

288317
if depth > len(self.filters):
318+
if len(self.filters) == 0:
319+
hashOffset = 0
320+
else:
321+
prev = self.filters[-1]
322+
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
323+
prev.size.bit_length()
324+
)
289325
self.filters.append(
290326
Bloomer.filter_with_characteristics(
291327
elements=max(
@@ -296,10 +332,15 @@ def initialize(self, *, include, exclude):
296332
falsePositiveRate=er,
297333
level=depth,
298334
hashAlg=self.defaultHashAlg,
335+
hashOffset=hashOffset,
299336
)
300337
)
301338
else:
302339
# Filter already created for this layer. Check size and resize if needed.
340+
prev = self.filters[depth - 1]
341+
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
342+
prev.size.bit_length()
343+
)
303344
required_size = Bloomer.calc_size(
304345
self.filters[depth - 1].nHashFuncs, include_len, er
305346
)
@@ -310,6 +351,7 @@ def initialize(self, *, include, exclude):
310351
falsePositiveRate=er,
311352
level=depth,
312353
hashAlg=self.defaultHashAlg,
354+
hashOffset=hashOffset,
313355
)
314356
log.info(
315357
f"Resized filter at {depth}-depth layer to {self.filters[depth - 1].size}"

filtercascade/fileformats.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
class HashAlgorithm(IntEnum):
88
MURMUR3 = 1
99
SHA256 = 2
10+
SHA256CTR = 3
1011

1112

1213
# The header for each Bloom filter level
@@ -24,6 +25,11 @@ class HashAlgorithm(IntEnum):
2425
# byte 4: layer number of this bloom filter, as an unsigned char
2526
bloomer_sha256_hash_struct = struct.Struct(b"<IB")
2627

28+
# This struct packs a single counter in 4 bytes for the SHA256CTR mode
29+
# Little endian (<)
30+
# byte 0-3: hash iteration counter, as an unsigned int
31+
bloomer_sha256ctr_hash_struct = struct.Struct(b"<I")
32+
2733
# The version struct is a simple 2-byte short indicating version number
2834
# Little endian (<)
2935
# bytes 0-1: The version number of this filter, as an unsigned short

filtercascade/test_filtercascade.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def assertBloomerEqual(self, b1, b2):
5959
self.assertEqual(b1.size, b2.size)
6060
self.assertEqual(b1.level, b2.level)
6161
self.assertEqual(b1.hashAlg, b2.hashAlg)
62-
self.assertEqual(b1.bitarray.length(), b2.bitarray.length())
62+
self.assertEqual(len(b1.bitarray), len(b2.bitarray))
6363
self.assertEqual(b1.bitarray, b2.bitarray)
6464

6565
def assertFilterCascadeEqual(self, f1, f2):
@@ -375,7 +375,8 @@ class TestFilterCascadeSalts(unittest.TestCase):
375375
def test_non_byte_salt(self):
376376
with self.assertRaises(ValueError):
377377
filtercascade.FilterCascade(
378-
defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, salt=64,
378+
defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
379+
salt=64,
379380
)
380381

381382
def test_murmur_with_salt(self):
@@ -470,8 +471,8 @@ def test_increased_false_positive_rate_in_deeper_layer(self):
470471

471472

472473
class TestFilterCascadeAlgorithms(unittest.TestCase):
473-
def verify_minimum_sets(self, *, hashAlg):
474-
fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg)
474+
def verify_minimum_sets(self, *, hashAlg, salt):
475+
fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg, salt=salt)
475476

476477
iterator, small_set = get_serial_iterator_and_set(num_iterator=10, num_set=1)
477478
fc.initialize(include=small_set, exclude=iterator)
@@ -481,19 +482,30 @@ def verify_minimum_sets(self, *, hashAlg):
481482

482483
f = MockFile()
483484
fc.tofile(f)
484-
self.assertEqual(len(f.data), 1030)
485+
self.assertEqual(
486+
len(f.data), 1030 + (len(salt) if isinstance(salt, bytes) else 0)
487+
)
485488

486489
fc2 = filtercascade.FilterCascade.from_buf(f)
487490
iterator2, small_set2 = get_serial_iterator_and_set(num_iterator=10, num_set=1)
488491
fc2.verify(include=small_set2, exclude=iterator2)
489492

490493
def test_murmurhash3(self):
491494
self.verify_minimum_sets(
492-
hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3
495+
hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3,
496+
salt=None,
493497
)
494498

495499
def test_sha256(self):
496-
self.verify_minimum_sets(hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256)
500+
self.verify_minimum_sets(
501+
hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
502+
salt=None,
503+
)
504+
505+
def test_sha256ctr(self):
506+
self.verify_minimum_sets(
507+
hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256CTR, salt=b"salt"
508+
)
497509

498510

499511
if __name__ == "__main__":

0 commit comments

Comments
 (0)