Skip to content

Commit 2e19d8e

Browse files
committed
Add support for "SHA256CTR" filters
1 parent b5469a2 commit 2e19d8e

File tree

3 files changed

+76
-18
lines changed

3 files changed

+76
-18
lines changed

filtercascade/__init__.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
log = logging.getLogger(__name__)
1616

1717

18+
def byte_length(bit_length):
19+
return (bit_length + 7) // 8
20+
21+
1822
class InvertedLogicException(Exception):
1923
def __init__(self, *, depth, exclude_count, include_len):
2024
self.message = (
@@ -54,12 +58,14 @@ def __init__(
5458
nHashFuncs,
5559
level,
5660
hashAlg=fileformats.HashAlgorithm.MURMUR3,
61+
hashOffset=0,
5762
salt=None,
5863
):
5964
self.nHashFuncs = nHashFuncs
6065
self.size = size
6166
self.level = level
6267
self.hashAlg = fileformats.HashAlgorithm(hashAlg)
68+
self.hashOffset = hashOffset
6369
self.salt = salt
6470

6571
self.bitarray = bitarray.bitarray(self.size, endian="little")
@@ -91,14 +97,33 @@ def hash(self, *, hash_no, key):
9197
m = hashlib.sha256()
9298
if self.salt:
9399
m.update(self.salt)
94-
m.update(fileformats.bloomer_sha256_hash_struct.pack(hash_no, self.level))
100+
m.update(
101+
fileformats.bloomer_sha256_hash_struct.pack(hash_no, self.level)
102+
)
95103
m.update(key)
96104
h = (
97105
int.from_bytes(m.digest()[:4], byteorder="little", signed=False)
98106
% self.size
99107
)
100108
return h
101109

110+
if self.hashAlg == fileformats.HashAlgorithm.SHA256CTR:
111+
b = []
112+
bytes_needed = byte_length(self.size.bit_length())
113+
offset = self.hashOffset + hash_no * bytes_needed
114+
while len(b) < bytes_needed:
115+
m = hashlib.sha256()
116+
m.update(fileformats.bloomer_sha256ctr_hash_struct.pack(offset // 32))
117+
m.update(self.salt)
118+
m.update(key)
119+
digest = m.digest()
120+
i = offset % 32
121+
x = digest[i : i + bytes_needed - len(b)]
122+
b.extend(x)
123+
offset += len(x)
124+
h = int.from_bytes(b, byteorder="little", signed=False) % self.size
125+
return h
126+
102127
raise Exception(f"Unknown hash algorithm: {self.hashAlg}")
103128

104129
def add(self, key):
@@ -136,13 +161,19 @@ def filter_with_characteristics(
136161
elements,
137162
falsePositiveRate,
138163
hashAlg=fileformats.HashAlgorithm.MURMUR3,
164+
hashOffset=0,
139165
salt=None,
140166
level=1,
141167
):
142168
nHashFuncs = Bloomer.calc_n_hashes(falsePositiveRate)
143169
size = Bloomer.calc_size(nHashFuncs, elements, falsePositiveRate)
144170
return Bloomer(
145-
size=size, nHashFuncs=nHashFuncs, level=level, hashAlg=hashAlg, salt=salt
171+
size=size,
172+
nHashFuncs=nHashFuncs,
173+
level=level,
174+
hashAlg=hashAlg,
175+
hashOffset=hashOffset,
176+
salt=salt,
146177
)
147178

148179
@classmethod
@@ -161,7 +192,7 @@ def calc_size(cls, nHashFuncs, elements, falsePositiveRate):
161192
min_bits = math.ceil(1.44 * elements * math.log2(1 / falsePositiveRate))
162193
assert min_bits > 0, "Always must have a positive number of bits"
163194
# Ensure the result is divisible by 8 for full bytes
164-
return 8 * math.ceil(min_bits / 8)
195+
return 8 * byte_length(min_bits)
165196

166197
@classmethod
167198
def from_buf(cls, buf, salt=None):
@@ -206,10 +237,10 @@ def __init__(
206237
invertedLogic=None,
207238
):
208239
"""
209-
Construct a FilterCascade.
210-
error_rates: If not supplied, defaults will be calculated
211-
invertedLogic: If not supplied (or left as None), it will be auto-
212-
detected.
240+
Construct a FilterCascade.
241+
error_rates: If not supplied, defaults will be calculated
242+
invertedLogic: If not supplied (or left as None), it will be auto-
243+
detected.
213244
"""
214245
self.filters = filters or []
215246
self.growth_factor = growth_factor
@@ -250,10 +281,10 @@ def set_crlite_error_rates(self, *, include_len, exclude_len):
250281

251282
def initialize(self, *, include, exclude):
252283
"""
253-
Arg "exclude" is potentially larger than main memory, so it should
254-
be assumed to be passed as a lazy-loading iterator. If it isn't,
255-
that's fine. The "include" arg must fit in memory and should be
256-
assumed to be a set.
284+
Arg "exclude" is potentially larger than main memory, so it should
285+
be assumed to be passed as a lazy-loading iterator. If it isn't,
286+
that's fine. The "include" arg must fit in memory and should be
287+
assumed to be a set.
257288
"""
258289
try:
259290
iter(exclude)
@@ -286,6 +317,13 @@ def initialize(self, *, include, exclude):
286317
er = self.error_rates[depth - 1]
287318

288319
if depth > len(self.filters):
320+
if len(self.filters) == 0:
321+
hashOffset = 0
322+
else:
323+
prev = self.filters[-1]
324+
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
325+
prev.size.bit_length()
326+
)
289327
self.filters.append(
290328
Bloomer.filter_with_characteristics(
291329
elements=max(
@@ -296,10 +334,15 @@ def initialize(self, *, include, exclude):
296334
falsePositiveRate=er,
297335
level=depth,
298336
hashAlg=self.defaultHashAlg,
337+
hashOffset=hashOffset,
299338
)
300339
)
301340
else:
302341
# Filter already created for this layer. Check size and resize if needed.
342+
prev = self.filters[depth - 1]
343+
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
344+
prev.size.bit_length()
345+
)
303346
required_size = Bloomer.calc_size(
304347
self.filters[depth - 1].nHashFuncs, include_len, er
305348
)
@@ -310,6 +353,7 @@ def initialize(self, *, include, exclude):
310353
falsePositiveRate=er,
311354
level=depth,
312355
hashAlg=self.defaultHashAlg,
356+
hashOffset=hashOffset,
313357
)
314358
log.info(
315359
f"Resized filter at {depth}-depth layer to {self.filters[depth - 1].size}"

filtercascade/fileformats.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
class HashAlgorithm(IntEnum):
88
MURMUR3 = 1
99
SHA256 = 2
10+
SHA256CTR = 3
1011

1112

1213
# The header for each Bloom filter level
@@ -24,6 +25,11 @@ class HashAlgorithm(IntEnum):
2425
# byte 4: layer number of this bloom filter, as an unsigned char
2526
bloomer_sha256_hash_struct = struct.Struct(b"<IB")
2627

28+
# This struct packs a single counter in 4 bytes for the SHA256CTR mode
29+
# Little endian (<)
30+
# bytes 0-3: hash iteration counter, as an unsigned int
31+
bloomer_sha256ctr_hash_struct = struct.Struct(b"<I")
32+
2733
# The version struct is a simple 2-byte short indicating version number
2834
# Little endian (<)
2935
# bytes 0-1: The version number of this filter, as an unsigned short

filtercascade/test_filtercascade.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def assertBloomerEqual(self, b1, b2):
5959
self.assertEqual(b1.size, b2.size)
6060
self.assertEqual(b1.level, b2.level)
6161
self.assertEqual(b1.hashAlg, b2.hashAlg)
62-
self.assertEqual(b1.bitarray.length(), b2.bitarray.length())
62+
self.assertEqual(len(b1.bitarray), len(b2.bitarray))
6363
self.assertEqual(b1.bitarray, b2.bitarray)
6464

6565
def assertFilterCascadeEqual(self, f1, f2):
@@ -375,7 +375,8 @@ class TestFilterCascadeSalts(unittest.TestCase):
375375
def test_non_byte_salt(self):
376376
with self.assertRaises(ValueError):
377377
filtercascade.FilterCascade(
378-
defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, salt=64,
378+
defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
379+
salt=64,
379380
)
380381

381382
def test_murmur_with_salt(self):
@@ -470,8 +471,8 @@ def test_increased_false_positive_rate_in_deeper_layer(self):
470471

471472

472473
class TestFilterCascadeAlgorithms(unittest.TestCase):
473-
def verify_minimum_sets(self, *, hashAlg):
474-
fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg)
474+
def verify_minimum_sets(self, *, hashAlg, salt):
475+
fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg, salt=salt)
475476

476477
iterator, small_set = get_serial_iterator_and_set(num_iterator=10, num_set=1)
477478
fc.initialize(include=small_set, exclude=iterator)
@@ -481,19 +482,26 @@ def verify_minimum_sets(self, *, hashAlg):
481482

482483
f = MockFile()
483484
fc.tofile(f)
484-
self.assertEqual(len(f.data), 1030)
485+
self.assertEqual(len(f.data), 1030 + (len(salt) if isinstance(salt, bytes) else 0))
485486

486487
fc2 = filtercascade.FilterCascade.from_buf(f)
487488
iterator2, small_set2 = get_serial_iterator_and_set(num_iterator=10, num_set=1)
488489
fc2.verify(include=small_set2, exclude=iterator2)
489490

490491
def test_murmurhash3(self):
491492
self.verify_minimum_sets(
492-
hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3
493+
hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3,
494+
salt=None,
493495
)
494496

495497
def test_sha256(self):
496-
self.verify_minimum_sets(hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256)
498+
self.verify_minimum_sets(
499+
hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
500+
salt=None,
501+
)
502+
503+
def test_sha256ctr(self):
504+
self.verify_minimum_sets(hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256CTR, salt=b"salt")
497505

498506

499507
if __name__ == "__main__":

0 commit comments

Comments
 (0)