Skip to content

Commit 9fb7b56

Browse files
committed
feat(integer): add KVStore
The KVStore is a Hash Table, with homomorphic capabilities The keys are meant to be clear integers, values are meant to be Radix/SignedRadix The ServerKey now has functions to be able to do operations that modify an existing key,value pair using an encrypted key.
1 parent 24feeb8 commit 9fb7b56

File tree

6 files changed

+874
-15
lines changed

6 files changed

+874
-15
lines changed

tfhe/src/integer/server_key/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub(crate) mod crt;
77
mod crt_parallel;
88
pub(crate) mod radix;
99
pub(crate) mod radix_parallel;
10+
pub use radix_parallel::kv_store::KVStore;
1011

1112
use super::backward_compatibility::server_key::{CompressedServerKeyVersions, ServerKeyVersions};
1213
use crate::conformance::ParameterSetConformant;
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
use crate::integer::block_decomposition::{Decomposable, DecomposableInto};
2+
use crate::integer::prelude::ServerKeyDefaultCMux;
3+
use crate::integer::{BooleanBlock, IntegerRadixCiphertext, ServerKey};
4+
use crate::prelude::CastInto;
5+
use rayon::prelude::*;
6+
use std::collections::HashMap;
7+
use std::hash::Hash;
8+
use std::num::NonZeroUsize;
9+
10+
pub struct KVStore<Key, Ct> {
11+
data: HashMap<Key, Ct>,
12+
block_count: Option<NonZeroUsize>,
13+
}
14+
15+
impl<Key, Ct> KVStore<Key, Ct> {
16+
pub fn new() -> Self {
17+
Self {
18+
data: HashMap::new(),
19+
block_count: None,
20+
}
21+
}
22+
23+
pub fn get(&self, key: &Key) -> Option<&Ct>
24+
where
25+
Key: Eq + Hash,
26+
{
27+
self.data.get(key)
28+
}
29+
30+
/// Inserts the value for the key
31+
///
32+
/// Returns the previous value stored for the key if there was any
33+
///
34+
/// # Notes
35+
///
36+
/// If the value does not contain blocks, nothing is inserted and None is returned
37+
///
38+
/// # Panics
39+
///
40+
/// Panics if the number of blocks of the value is not the same as all other
41+
/// values stored
42+
pub fn insert(&mut self, key: Key, value: Ct) -> Option<Ct>
43+
where
44+
Key: PartialEq + Ord + Eq + Hash,
45+
Ct: IntegerRadixCiphertext,
46+
{
47+
let n_blocks = value.blocks().len();
48+
if n_blocks == 0 {
49+
return None;
50+
}
51+
52+
let n = self
53+
.block_count
54+
.get_or_insert_with(|| NonZeroUsize::new(n_blocks).unwrap());
55+
56+
assert_eq!(
57+
n.get(),
58+
n_blocks,
59+
"All ciphertexts must have the same number of blocks"
60+
);
61+
self.data.insert(key, value)
62+
}
63+
64+
pub fn len(&self) -> usize {
65+
self.data.len()
66+
}
67+
68+
pub fn is_empty(&self) -> bool {
69+
self.data.is_empty()
70+
}
71+
72+
pub fn iter(&self) -> impl Iterator<Item = (&Key, &Ct)>
73+
where
74+
Key: Eq + Hash + Sync,
75+
Ct: Send,
76+
{
77+
self.data.iter()
78+
}
79+
80+
fn par_iter_keys(&self) -> impl ParallelIterator<Item = &Key>
81+
where
82+
Key: Send + Sync + Hash + Eq,
83+
Ct: Send + Sync,
84+
{
85+
self.data.par_iter().map(|(k, _)| k)
86+
}
87+
}
88+
89+
impl<Key, Ct> Default for KVStore<Key, Ct>
90+
where
91+
Self: Sized,
92+
{
93+
fn default() -> Self {
94+
Self::new()
95+
}
96+
}
97+
98+
impl ServerKey {
99+
/// Internal function used to perform a binary operation
100+
/// on an entry.
101+
///
102+
/// `encrypted_key`: The key of the slot
103+
/// `func`: function that receives to arguments:
104+
/// * A boolean block that encrypts `true` if the corresponding key is the same as the
105+
/// `encrypted_key`
106+
/// * a `& mut` to the ciphertext which stores the value
107+
fn kv_store_binary_op_to_slot<Key, Ct, F>(
108+
&self,
109+
map: &mut KVStore<Key, Ct>,
110+
encrypted_key: &Ct,
111+
func: F,
112+
) where
113+
Ct: IntegerRadixCiphertext,
114+
Key: Decomposable + CastInto<usize> + Hash + Eq,
115+
F: Fn(&BooleanBlock, &mut Ct) + Sync + Send,
116+
{
117+
let kv_vec: Vec<(&Key, &mut Ct)> = map.data.iter_mut().collect();
118+
119+
// For each clear key, get a boolean ciphertext that tells if it's
120+
// equal to the encrypted key
121+
let selectors =
122+
self.compute_equality_selectors(encrypted_key, kv_vec.par_iter().map(|(k, _v)| **k));
123+
124+
kv_vec
125+
.into_par_iter()
126+
.zip(selectors.par_iter())
127+
.for_each(|((_k, current_ct), selector)| func(selector, current_ct));
128+
}
129+
130+
/// Performs an addition on an entry of the store
131+
///
132+
/// `map[encrypted_key] += value`
133+
///
134+
/// This finds the value that corresponds to the given `encrypted_key `
135+
/// and adds `value` to it.
136+
pub fn kv_store_add_to_slot<Key, Ct>(
137+
&self,
138+
map: &mut KVStore<Key, Ct>,
139+
encrypted_key: &Ct,
140+
value: &Ct,
141+
) where
142+
Ct: IntegerRadixCiphertext,
143+
Key: Decomposable + CastInto<usize> + Hash + Eq,
144+
{
145+
self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| {
146+
let mut ct_to_add = value.clone();
147+
self.zero_out_if_condition_is_false(&mut ct_to_add, &selector.0);
148+
self.add_assign_parallelized(v, &ct_to_add);
149+
});
150+
}
151+
152+
/// Performs an addition by a clear on an entry of the store
153+
///
154+
/// `map[encrypted_key] += value`
155+
///
156+
/// This finds the value that corresponds to the given `encrypted_key `
157+
/// and adds `value` to it.
158+
pub fn kv_store_scalar_add_to_slot<Key, Ct, Clear>(
159+
&self,
160+
map: &mut KVStore<Key, Ct>,
161+
encrypted_key: &Ct,
162+
value: Clear,
163+
) where
164+
Ct: IntegerRadixCiphertext,
165+
Key: Decomposable + CastInto<usize> + Hash + Eq,
166+
Clear: DecomposableInto<u64>,
167+
{
168+
self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| {
169+
let ct_to_add =
170+
self.scalar_cmux_parallelized(selector, value, Clear::ZERO, v.blocks().len());
171+
self.add_assign_parallelized(v, &ct_to_add);
172+
});
173+
}
174+
175+
/// Performs a subtraction on an entry of the store
176+
///
177+
/// `map[encrypted_key] -= value`
178+
///
179+
/// This finds the value that corresponds to the given `encrypted_key`,
180+
/// and subtracts `value` to it.
181+
pub fn kv_store_sub_to_slot<Key, Ct>(
182+
&self,
183+
map: &mut KVStore<Key, Ct>,
184+
encrypted_key: &Ct,
185+
value: &Ct,
186+
) where
187+
Ct: IntegerRadixCiphertext,
188+
Key: Decomposable + CastInto<usize> + Hash + Eq,
189+
{
190+
self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| {
191+
let mut ct_to_sub = value.clone();
192+
self.zero_out_if_condition_is_false(&mut ct_to_sub, &selector.0);
193+
self.sub_assign_parallelized(v, &ct_to_sub);
194+
});
195+
}
196+
197+
/// Performs a multiplication on an entry of the store
198+
///
199+
/// `map[encrypted_key] *= value`
200+
///
201+
/// This finds the value that corresponds to the given `encrypted_key`,
202+
/// and multiplies it by `value`.
203+
pub fn kv_store_mul_to_slot<Key, Ct>(
204+
&self,
205+
map: &mut KVStore<Key, Ct>,
206+
encrypted_key: &Ct,
207+
value: &Ct,
208+
) where
209+
Ct: IntegerRadixCiphertext,
210+
Key: Decomposable + CastInto<usize> + Hash + Eq,
211+
Self: for<'a> ServerKeyDefaultCMux<u64, &'a Ct, Output = Ct>,
212+
{
213+
self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| {
214+
let selector = self.boolean_bitnot(selector);
215+
let ct_to_mul = self.if_then_else_parallelized(&selector, 1u64, value);
216+
self.mul_assign_parallelized(v, &ct_to_mul);
217+
});
218+
}
219+
220+
/// Implementation of the get function that additionally returns the Vec of selectors
221+
/// so it can be reused to avoid re-computing it.
222+
fn kv_store_get_impl<Key, Ct>(
223+
&self,
224+
map: &KVStore<Key, Ct>,
225+
encrypted_key: &Ct,
226+
) -> (Ct, BooleanBlock, Vec<BooleanBlock>)
227+
where
228+
Ct: IntegerRadixCiphertext,
229+
Key: Decomposable + CastInto<usize> + Hash + Eq,
230+
{
231+
let selectors =
232+
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
233+
234+
let (result, check_block) = rayon::join(
235+
|| {
236+
let kv_vec: Vec<(&Key, &Ct)> = map.data.iter().collect();
237+
let one_hot = kv_vec
238+
.into_par_iter()
239+
.zip(selectors.par_iter())
240+
.map(|((_, v), s)| {
241+
let mut result = v.clone();
242+
self.zero_out_if_condition_is_false(&mut result, &s.0);
243+
result
244+
})
245+
.collect::<Vec<_>>();
246+
247+
self.aggregate_one_hot_vector(one_hot)
248+
},
249+
|| {
250+
let selectors = selectors.iter().map(|s| s.0.clone()).collect::<Vec<_>>();
251+
BooleanBlock::new_unchecked(self.is_at_least_one_comparisons_block_true(selectors))
252+
},
253+
);
254+
255+
(result, check_block, selectors)
256+
}
257+
258+
/// Returns the value at the given key
259+
///
260+
/// `return map[encrypted_key]`
261+
///
262+
/// This finds the value that corresponds to the given `encrypted_key`,
263+
/// and returns it.
264+
/// It also returns a boolean block that encrypts `true` if an entry for
265+
/// the `encrypted_key` was found.
266+
///
267+
/// If the key was not found, the returned value is an encryption of zero
268+
pub fn kv_store_get<Key, Ct>(
269+
&self,
270+
map: &KVStore<Key, Ct>,
271+
encrypted_key: &Ct,
272+
) -> (Ct, BooleanBlock)
273+
where
274+
Ct: IntegerRadixCiphertext,
275+
Key: Decomposable + CastInto<usize> + Hash + Eq,
276+
{
277+
let (result, check_block, _selectors) = self.kv_store_get_impl(map, encrypted_key);
278+
(result, check_block)
279+
}
280+
281+
/// Updates the value at the given key by the given value
282+
///
283+
/// `map[encrypted_key] = new_value`
284+
///
285+
/// This finds the value that corresponds to the given `encrypted_key`,
286+
/// then updates the value stored with the `new_value`.
287+
///
288+
/// Returns a boolean block that encrypts `true` if an entry for
289+
/// the `encrypted_key` was found, and thus the update was done
290+
pub fn kv_store_update<Key, Ct>(
291+
&self,
292+
map: &mut KVStore<Key, Ct>,
293+
encrypted_key: &Ct,
294+
new_value: &Ct,
295+
) -> BooleanBlock
296+
where
297+
Ct: IntegerRadixCiphertext,
298+
Key: Decomposable + CastInto<usize> + Hash + Eq,
299+
{
300+
let selectors =
301+
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
302+
303+
rayon::join(
304+
|| {
305+
let kv_vec: Vec<(&Key, &mut Ct)> = map.data.iter_mut().collect();
306+
kv_vec
307+
.into_par_iter()
308+
.zip(selectors.par_iter())
309+
.for_each(|((_, old_value), s)| {
310+
*old_value = self.if_then_else_parallelized(s, new_value, old_value);
311+
});
312+
},
313+
|| {
314+
let selectors = selectors.iter().map(|s| s.0.clone()).collect::<Vec<_>>();
315+
BooleanBlock::new_unchecked(self.is_at_least_one_comparisons_block_true(selectors))
316+
},
317+
)
318+
.1
319+
}
320+
321+
/// Updates the value at the given key by applying a function
322+
///
323+
/// `map[encrypted_key] = func(map[encrypted_value])`
324+
///
325+
/// This finds the value that corresponds to the given `encrypted_key`, then
326+
/// calls `func` then updates the value stored with the one returned by the `func`.
327+
///
328+
/// Returns the new value and a boolean block that encrypts `true` if an entry for
329+
/// the `encrypted_key` was found.
330+
pub fn kv_store_map<Key, Ct, F>(
331+
&self,
332+
map: &mut KVStore<Key, Ct>,
333+
encrypted_key: &Ct,
334+
func: F,
335+
) -> (Ct, BooleanBlock)
336+
where
337+
Ct: IntegerRadixCiphertext,
338+
Key: Decomposable + CastInto<usize> + Hash + Eq,
339+
F: Fn(Ct) -> Ct,
340+
{
341+
let (result, check_block, selectors) = self.kv_store_get_impl(map, encrypted_key);
342+
let new_value = func(result);
343+
344+
let kv_vec: Vec<(&Key, &mut Ct)> = map.data.iter_mut().collect();
345+
kv_vec
346+
.into_par_iter()
347+
.zip(selectors.par_iter())
348+
.for_each(|((_, old_value), s)| {
349+
*old_value = self.if_then_else_parallelized(s, &new_value, old_value);
350+
});
351+
352+
(new_value, check_block)
353+
}
354+
}

tfhe/src/integer/server_key/radix_parallel/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ mod sum;
2424

2525
mod count_zeros_ones;
2626
pub(crate) mod ilog2;
27+
pub(crate) mod kv_store;
2728
mod reverse_bits;
2829
mod scalar_dot_prod;
2930
mod slice;

tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) mod test_comparison;
88
mod test_count_zeros_ones;
99
pub(crate) mod test_div_mod;
1010
pub(crate) mod test_ilog2;
11+
pub(crate) mod test_kv_store;
1112
pub(crate) mod test_mul;
1213
pub(crate) mod test_neg;
1314
pub(crate) mod test_rotate;

0 commit comments

Comments
 (0)