diff --git a/include/binaryfusefilter.h b/include/binaryfusefilter.h index 14dbf4e..05d0a2c 100644 --- a/include/binaryfusefilter.h +++ b/include/binaryfusefilter.h @@ -279,6 +279,10 @@ static inline uint8_t binary_fuse_mod3(uint8_t x) { // many duplicated keys. static inline bool binary_fuse8_populate(uint64_t *keys, uint32_t size, binary_fuse8_t *filter) { + if (size != filter->Size) { + return false; + } + uint64_t rng_counter = 0x726b2b9d438b9d4d; filter->Seed = binary_fuse_rng_splitmix64(&rng_counter); uint64_t *reverseOrder = (uint64_t *)calloc((size + 1), sizeof(uint64_t)); @@ -569,6 +573,10 @@ static inline void binary_fuse16_free(binary_fuse16_t *filter) { // many duplicated keys. static inline bool binary_fuse16_populate(uint64_t *keys, uint32_t size, binary_fuse16_t *filter) { + if (size != filter->Size) { + return false; + } + uint64_t rng_counter = 0x726b2b9d438b9d4d; filter->Seed = binary_fuse_rng_splitmix64(&rng_counter); uint64_t *reverseOrder = (uint64_t *)calloc((size + 1), sizeof(uint64_t)); @@ -747,14 +755,14 @@ static inline bool binary_fuse16_populate(uint64_t *keys, uint32_t size, } static inline size_t binary_fuse16_serialization_bytes(binary_fuse16_t *filter) { - return sizeof(filter->Seed) + sizeof(filter->SegmentLength) + + return sizeof(filter->Seed) + sizeof(filter->Size) + sizeof(filter->SegmentLength) + sizeof(filter->SegmentLengthMask) + sizeof(filter->SegmentCount) + sizeof(filter->SegmentCountLength) + sizeof(filter->ArrayLength) + sizeof(uint16_t) * filter->ArrayLength; } static inline size_t binary_fuse8_serialization_bytes(const binary_fuse8_t *filter) { - return sizeof(filter->Seed) + sizeof(filter->SegmentLength) + + return sizeof(filter->Seed) + sizeof(filter->Size) + sizeof(filter->SegmentLength) + sizeof(filter->SegmentCount) + sizeof(filter->SegmentCountLength) + sizeof(filter->ArrayLength) + sizeof(uint8_t) * filter->ArrayLength; @@ -766,6 +774,8 @@ static inline size_t binary_fuse8_serialization_bytes(const binary_fuse8_t *filt static inline void binary_fuse16_serialize(const binary_fuse16_t *filter, char *buffer) { memcpy(buffer, &filter->Seed, sizeof(filter->Seed)); buffer += sizeof(filter->Seed); + memcpy(buffer, &filter->Size, sizeof(filter->Size)); + buffer += sizeof(filter->Size); memcpy(buffer, &filter->SegmentLength, sizeof(filter->SegmentLength)); buffer += sizeof(filter->SegmentLength); memcpy(buffer, &filter->SegmentCount, sizeof(filter->SegmentCount)); @@ -783,6 +793,8 @@ static inline void binary_fuse16_serialize(const binary_fuse16_t *filter, char * static inline void binary_fuse8_serialize(const binary_fuse8_t *filter, char *buffer) { memcpy(buffer, &filter->Seed, sizeof(filter->Seed)); buffer += sizeof(filter->Seed); + memcpy(buffer, &filter->Size, sizeof(filter->Size)); + buffer += sizeof(filter->Size); memcpy(buffer, &filter->SegmentLength, sizeof(filter->SegmentLength)); buffer += sizeof(filter->SegmentLength); memcpy(buffer, &filter->SegmentCount, sizeof(filter->SegmentCount)); @@ -802,6 +814,8 @@ static inline void binary_fuse8_serialize(const binary_fuse8_t *filter, char *bu static inline const char* binary_fuse16_deserialize_header(binary_fuse16_t* filter, const char* buffer) { memcpy(&filter->Seed, buffer, sizeof(filter->Seed)); buffer += sizeof(filter->Seed); + memcpy(&filter->Size, buffer, sizeof(filter->Size)); + buffer += sizeof(filter->Size); memcpy(&filter->SegmentLength, buffer, sizeof(filter->SegmentLength)); buffer += sizeof(filter->SegmentLength); filter->SegmentLengthMask = filter->SegmentLength - 1; @@ -837,6 +851,8 @@ static inline bool binary_fuse16_deserialize(binary_fuse16_t * filter, const cha static inline const char* binary_fuse8_deserialize_header(binary_fuse8_t* filter, const char* buffer) { memcpy(&filter->Seed, buffer, sizeof(filter->Seed)); buffer += sizeof(filter->Seed); + memcpy(&filter->Size, buffer, sizeof(filter->Size)); + buffer += sizeof(filter->Size); memcpy(&filter->SegmentLength, buffer, sizeof(filter->SegmentLength)); buffer += sizeof(filter->SegmentLength); filter->SegmentLengthMask = filter->SegmentLength - 1; diff --git a/tests/unit.c b/tests/unit.c index cd7ad0c..abe5725 100644 --- a/tests/unit.c +++ b/tests/unit.c @@ -191,6 +191,35 @@ bool testbufferedxor16(size_t size) { bool testbinaryfuse8(size_t size, size_t repeated_size) { printf("testing binary fuse8 with size %zu and %zu duplicates\n", size, repeated_size); binary_fuse8_t filter; + + // size serialization test + binary_fuse8_allocate((uint32_t)size, &filter); + uint64_t *big_set = (uint64_t *)malloc(sizeof(uint64_t) * size); + for (size_t i = 0; i < size; i++) { + big_set[i] = i; + } + if (!binary_fuse8_populate(big_set, (uint32_t)size, &filter)) { + return false; + } + free(big_set); + + size_t buffer_size = binary_fuse8_serialization_bytes(&filter); + char *buffer = (char *)malloc(buffer_size); + binary_fuse8_serialize(&filter, buffer); + binary_fuse8_free(&filter); + binary_fuse8_deserialize(&filter, buffer); + + if (filter.Size != size) { + printf("size not (de-)serialized correctly, found %d, expected %zu.", + filter.Size, size); + free(buffer); + binary_fuse8_free(&filter); + return false; + } + free(buffer); + binary_fuse8_free(&filter); + // end of size serialization test + return test(size, repeated_size, &filter, binary_fuse8_allocate_gen, binary_fuse8_free_gen, @@ -205,6 +234,36 @@ bool testbinaryfuse8(size_t size, size_t repeated_size) { bool testbinaryfuse16(size_t size, size_t repeated_size) { printf("testing binary fuse16 with size %zu and %zu duplicates\n", size, repeated_size); binary_fuse16_t filter; + + // size serialization test + binary_fuse16_allocate((uint32_t)size, &filter); + uint64_t *big_set = (uint64_t *)malloc(sizeof(uint64_t) * size); + for (size_t i = 0; i < size; i++) { + big_set[i] = i; + } + if (!binary_fuse16_populate(big_set, (uint32_t)size, &filter)) { + return false; + } + free(big_set); + + size_t buffer_size = binary_fuse16_serialization_bytes(&filter); + char *buffer = (char *)malloc(buffer_size); + binary_fuse16_serialize(&filter, buffer); + binary_fuse16_free(&filter); + binary_fuse16_deserialize(&filter, buffer); + + if (filter.Size != size) { + printf("size not (de-)serialized correctly, found %d, expected %zu.", + filter.Size, size); + free(buffer); + binary_fuse16_free(&filter); + return false; + } + + free(buffer); + binary_fuse16_free(&filter); + // end of size serialization test + return test(size, repeated_size, &filter, binary_fuse16_allocate_gen, binary_fuse16_free_gen,