Changes from all commits
62 changes: 52 additions & 10 deletions filtercascade/__init__.py
@@ -15,6 +15,10 @@
log = logging.getLogger(__name__)


def byte_length(bit_length):
return (bit_length + 7) // 8


class InvertedLogicException(Exception):
def __init__(self, *, depth, exclude_count, include_len):
self.message = (
@@ -54,12 +58,14 @@ def __init__(
nHashFuncs,
level,
hashAlg=fileformats.HashAlgorithm.MURMUR3,
hashOffset=0,
salt=None,
):
self.nHashFuncs = nHashFuncs
self.size = size
self.level = level
self.hashAlg = fileformats.HashAlgorithm(hashAlg)
self.hashOffset = hashOffset
self.salt = salt

self.bitarray = bitarray.bitarray(self.size, endian="little")
@@ -99,6 +105,23 @@ def hash(self, *, hash_no, key):
)
return h

if self.hashAlg == fileformats.HashAlgorithm.SHA256CTR:
b = []
bytes_needed = byte_length(self.size.bit_length())
offset = self.hashOffset + hash_no * bytes_needed
while len(b) < bytes_needed:
m = hashlib.sha256()
m.update(fileformats.bloomer_sha256ctr_hash_struct.pack(offset // 32))
m.update(self.salt)
m.update(key)
digest = m.digest()
i = offset % 32
x = digest[i : i + bytes_needed - len(b)]
b.extend(x)
offset += len(x)
h = int.from_bytes(b, byteorder="little", signed=False) % self.size
return h

raise Exception(f"Unknown hash algorithm: {self.hashAlg}")

def add(self, key):
@@ -136,13 +159,19 @@ def filter_with_characteristics(
elements,
falsePositiveRate,
hashAlg=fileformats.HashAlgorithm.MURMUR3,
hashOffset=0,
salt=None,
level=1,
):
nHashFuncs = Bloomer.calc_n_hashes(falsePositiveRate)
size = Bloomer.calc_size(nHashFuncs, elements, falsePositiveRate)
return Bloomer(
-            size=size, nHashFuncs=nHashFuncs, level=level, hashAlg=hashAlg, salt=salt
+            size=size,
+            nHashFuncs=nHashFuncs,
+            level=level,
+            hashAlg=hashAlg,
+            hashOffset=hashOffset,
+            salt=salt,
)

@classmethod
@@ -161,7 +190,7 @@ def calc_size(cls, nHashFuncs, elements, falsePositiveRate):
min_bits = math.ceil(1.44 * elements * math.log2(1 / falsePositiveRate))
assert min_bits > 0, "Always must have a positive number of bits"
# Ensure the result is divisible by 8 for full bytes
-        return 8 * math.ceil(min_bits / 8)
+        return 8 * byte_length(min_bits)

@classmethod
def from_buf(cls, buf, salt=None):
@@ -206,10 +235,10 @@ def __init__(
invertedLogic=None,
):
"""
Construct a FilterCascade.
error_rates: If not supplied, defaults will be calculated
invertedLogic: If not supplied (or left as None), it will be auto-
detected.
Construct a FilterCascade.
error_rates: If not supplied, defaults will be calculated
invertedLogic: If not supplied (or left as None), it will be auto-
detected.
"""
self.filters = filters or []
self.growth_factor = growth_factor
@@ -250,10 +279,10 @@ def set_crlite_error_rates(self, *, include_len, exclude_len):

def initialize(self, *, include, exclude):
"""
Arg "exclude" is potentially larger than main memory, so it should
be assumed to be passed as a lazy-loading iterator. If it isn't,
that's fine. The "include" arg must fit in memory and should be
assumed to be a set.
Arg "exclude" is potentially larger than main memory, so it should
be assumed to be passed as a lazy-loading iterator. If it isn't,
that's fine. The "include" arg must fit in memory and should be
assumed to be a set.
"""
try:
iter(exclude)
@@ -286,6 +315,13 @@ def initialize(self, *, include, exclude):
er = self.error_rates[depth - 1]

if depth > len(self.filters):
if len(self.filters) == 0:
hashOffset = 0
else:
prev = self.filters[-1]
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
prev.size.bit_length()
)
self.filters.append(
Bloomer.filter_with_characteristics(
elements=max(
@@ -296,10 +332,15 @@
falsePositiveRate=er,
level=depth,
hashAlg=self.defaultHashAlg,
hashOffset=hashOffset,
)
)
else:
# Filter already created for this layer. Check size and resize if needed.
prev = self.filters[depth - 1]
hashOffset = prev.hashOffset + prev.nHashFuncs * byte_length(
prev.size.bit_length()
)
required_size = Bloomer.calc_size(
self.filters[depth - 1].nHashFuncs, include_len, er
)
@@ -310,6 +351,7 @@
falsePositiveRate=er,
level=depth,
hashAlg=self.defaultHashAlg,
hashOffset=hashOffset,
)
log.info(
f"Resized filter at {depth}-depth layer to {self.filters[depth - 1].size}"
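Review note: the new SHA256CTR mode treats salted SHA-256 as a single counter-mode byte stream shared by the whole cascade. Each hash function reads a fixed window of `byte_length(size.bit_length())` bytes starting at `hashOffset + hash_no * bytes_needed`, and each layer's `hashOffset` begins where the previous layer's windows end, so no two hash functions in the cascade read overlapping stream positions. A minimal standalone sketch of that read path, assuming the helper name `sha256ctr_hash` (illustrative, not part of the patch):

```python
import hashlib
import struct

# Mirrors fileformats.bloomer_sha256ctr_hash_struct: a little-endian
# unsigned int holding the digest-block counter.
CTR_STRUCT = struct.Struct(b"<I")


def byte_length(bit_length):
    return (bit_length + 7) // 8


def sha256ctr_hash(*, size, hash_no, hash_offset, salt, key):
    # Each hash function owns bytes_needed bytes of the stream,
    # starting at its own per-function offset.
    bytes_needed = byte_length(size.bit_length())
    offset = hash_offset + hash_no * bytes_needed
    out = bytearray()
    while len(out) < bytes_needed:
        m = hashlib.sha256()
        m.update(CTR_STRUCT.pack(offset // 32))  # which 32-byte digest block
        m.update(salt)
        m.update(key)
        digest = m.digest()
        i = offset % 32  # position inside that block
        chunk = digest[i : i + bytes_needed - len(out)]
        out.extend(chunk)
        offset += len(chunk)
    return int.from_bytes(out, byteorder="little", signed=False) % size


# Example: a 1024-bit layer needs 2 bytes per hash function, since
# byte_length((1024).bit_length()) == byte_length(11) == 2, so hash_no 0
# reads stream bytes [0, 2) and hash_no 1 reads bytes [2, 4).
h0 = sha256ctr_hash(size=1024, hash_no=0, hash_offset=0, salt=b"salt", key=b"key")
h1 = sha256ctr_hash(size=1024, hash_no=1, hash_offset=0, salt=b"salt", key=b"key")
assert 0 <= h0 < 1024 and 0 <= h1 < 1024
```

One design consequence worth noting: because the result is reduced modulo `size`, each window only needs `size.bit_length()` bits rounded up to whole bytes, so a window almost always falls within a single 32-byte digest and the cost stays at roughly one SHA-256 invocation per hash function.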
6 changes: 6 additions & 0 deletions filtercascade/fileformats.py
@@ -7,6 +7,7 @@
class HashAlgorithm(IntEnum):
MURMUR3 = 1
SHA256 = 2
SHA256CTR = 3


# The header for each Bloom filter level
@@ -24,6 +25,11 @@ class HashAlgorithm(IntEnum):
# byte 4: layer number of this bloom filter, as an unsigned char
bloomer_sha256_hash_struct = struct.Struct(b"<IB")

# This struct packs a single counter in 4 bytes for the SHA256CTR mode
# Little endian (<)
# bytes 0-3: hash iteration counter, as an unsigned int
bloomer_sha256ctr_hash_struct = struct.Struct(b"<I")

# The version struct is a simple 2-byte short indicating version number
# Little endian (<)
# bytes 0-1: The version number of this filter, as an unsigned short
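Review note: `bloomer_sha256ctr_hash_struct` is the only new wire-format piece, and as far as this patch shows it never appears in serialized filters; it only packs the digest-block counter that domain-separates successive SHA-256 invocations inside `Bloomer.hash()`. A quick round-trip under that reading:

```python
import struct

bloomer_sha256ctr_hash_struct = struct.Struct(b"<I")

# Pack the counter exactly as Bloomer.hash() does: offset // 32 selects
# which 32-byte SHA-256 output a given stream byte falls in.
packed = bloomer_sha256ctr_hash_struct.pack(70 // 32)
assert packed == b"\x02\x00\x00\x00"  # little-endian unsigned int
(counter,) = bloomer_sha256ctr_hash_struct.unpack(packed)
assert counter == 2
```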
26 changes: 19 additions & 7 deletions filtercascade/test_filtercascade.py
@@ -59,7 +59,7 @@ def assertBloomerEqual(self, b1, b2):
self.assertEqual(b1.size, b2.size)
self.assertEqual(b1.level, b2.level)
self.assertEqual(b1.hashAlg, b2.hashAlg)
-        self.assertEqual(b1.bitarray.length(), b2.bitarray.length())
+        self.assertEqual(len(b1.bitarray), len(b2.bitarray))
self.assertEqual(b1.bitarray, b2.bitarray)

def assertFilterCascadeEqual(self, f1, f2):
@@ -375,7 +375,8 @@ class TestFilterCascadeSalts(unittest.TestCase):
def test_non_byte_salt(self):
with self.assertRaises(ValueError):
filtercascade.FilterCascade(
-                defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, salt=64,
+                defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
+                salt=64,
)

def test_murmur_with_salt(self):
@@ -470,8 +471,8 @@ def test_increased_false_positive_rate_in_deeper_layer(self):


class TestFilterCascadeAlgorithms(unittest.TestCase):
-    def verify_minimum_sets(self, *, hashAlg):
-        fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg)
+    def verify_minimum_sets(self, *, hashAlg, salt):
+        fc = filtercascade.FilterCascade(defaultHashAlg=hashAlg, salt=salt)

iterator, small_set = get_serial_iterator_and_set(num_iterator=10, num_set=1)
fc.initialize(include=small_set, exclude=iterator)
@@ -481,19 +482,30 @@ def verify_minimum_sets(self, *, hashAlg):

f = MockFile()
fc.tofile(f)
-        self.assertEqual(len(f.data), 1030)
+        self.assertEqual(
+            len(f.data), 1030 + (len(salt) if isinstance(salt, bytes) else 0)
+        )

fc2 = filtercascade.FilterCascade.from_buf(f)
iterator2, small_set2 = get_serial_iterator_and_set(num_iterator=10, num_set=1)
fc2.verify(include=small_set2, exclude=iterator2)

def test_murmurhash3(self):
self.verify_minimum_sets(
-            hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3
+            hashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3,
+            salt=None,
)

def test_sha256(self):
-        self.verify_minimum_sets(hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256)
+        self.verify_minimum_sets(
+            hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256,
+            salt=None,
+        )

def test_sha256ctr(self):
self.verify_minimum_sets(
hashAlg=filtercascade.fileformats.HashAlgorithm.SHA256CTR, salt=b"salt"
)


if __name__ == "__main__":
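Review note: end to end, the new mode is driven the same way as the existing algorithms, with the caveat that SHA256CTR appears to require a bytes salt (`Bloomer.hash()` feeds `self.salt` straight into every digest, so `salt=None` would fail there). A minimal usage sketch mirroring `test_sha256ctr`; the include/exclude values are illustrative only:

```python
import filtercascade
import filtercascade.fileformats

fc = filtercascade.FilterCascade(
    defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256CTR,
    salt=b"salt",
)

include = {b"revoked-1"}                        # must fit in memory
exclude = (b"valid-%d" % i for i in range(10))  # may be a lazy iterator
fc.initialize(include=include, exclude=exclude)

# verify() re-checks both sets against the built cascade; the exclude
# iterator was consumed by initialize(), so rebuild it as the tests do.
fc.verify(include={b"revoked-1"}, exclude=(b"valid-%d" % i for i in range(10)))
```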