 #include "integer/integer.cuh"
 #include <cstdint>

+////////////////////////////////////
+// Helper structures used in expand
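+// Non-owning view over an LWE mask; in the expand path this is the single
+// mask shared by all bodies of a compact list.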
+template <typename Torus> struct lwe_mask {
+  Torus *mask;
+
+  lwe_mask(Torus *mask) : mask{mask} {}
+};
+
+template <typename Torus> struct compact_lwe_body {
+  Torus *body;
+  uint64_t monomial_degree;
+
+  /* Body id is the index of the body in the compact ciphertext list.
+   * It's used to compute the rotation.
+   */
+  compact_lwe_body(Torus *body, const uint64_t body_id)
+      : body{body}, monomial_degree{body_id} {}
+};
+
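+// Non-owning view over one compact LWE list: the first lwe_dimension
+// coefficients hold the shared mask, followed by total_num_lwes bodies.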
+template <typename Torus> struct compact_lwe_list {
+  Torus *ptr;
+  uint32_t lwe_dimension;
+  uint32_t total_num_lwes;
+
+  compact_lwe_list(Torus *ptr, uint32_t lwe_dimension, uint32_t total_num_lwes)
+      : ptr{ptr}, lwe_dimension{lwe_dimension}, total_num_lwes{total_num_lwes} {
+  }
+
+  lwe_mask<Torus> get_mask() { return lwe_mask(ptr); }
+
+  // Returns the index-th body
+  compact_lwe_body<Torus> get_body(uint32_t index) {
+    if (index >= total_num_lwes) {
+      PANIC("index out of range in compact_lwe_list::get_body");
+    }
+
+    return compact_lwe_body(&ptr[lwe_dimension + index], uint64_t(index));
+  }
+};
+
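+// All compact LWE lists flattened back-to-back in a single device buffer.
+// The constructor walks this layout on the host to record a device pointer
+// to the start of each list and the total number of LWEs across all lists.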
+template <typename Torus> struct flattened_compact_lwe_lists {
+  Torus *d_ptr;
+  Torus **d_ptr_to_compact_list;
+  const uint32_t *h_num_lwes_per_compact_list;
+  uint32_t num_compact_lists;
+  uint32_t lwe_dimension;
+  uint32_t total_num_lwes;
+
+  flattened_compact_lwe_lists(Torus *d_ptr,
+                              const uint32_t *h_num_lwes_per_compact_list,
+                              uint32_t num_compact_lists,
+                              uint32_t lwe_dimension)
+      : d_ptr(d_ptr), h_num_lwes_per_compact_list(h_num_lwes_per_compact_list),
+        num_compact_lists(num_compact_lists), lwe_dimension(lwe_dimension) {
+    d_ptr_to_compact_list =
+        static_cast<Torus **>(malloc(num_compact_lists * sizeof(Torus **)));
+    total_num_lwes = 0;
+    auto curr_list = d_ptr;
+    for (auto i = 0; i < num_compact_lists; ++i) {
+      total_num_lwes += h_num_lwes_per_compact_list[i];
+      d_ptr_to_compact_list[i] = curr_list;
+      curr_list += lwe_dimension + h_num_lwes_per_compact_list[i];
+    }
+  }
+
+  compact_lwe_list<Torus> get_device_compact_list(uint32_t compact_list_index) {
+    if (compact_list_index >= num_compact_lists) {
+      PANIC("index out of range in flattened_compact_lwe_lists::get");
+    }
+
+    return compact_lwe_list(d_ptr_to_compact_list[compact_list_index],
+                            lwe_dimension,
+                            h_num_lwes_per_compact_list[compact_list_index]);
+  }
+};
+
+/*
+ * An expand_job tells the expand kernel exactly which input mask and body to
+ * use and what rotation to apply.
+ */
+template <typename Torus> struct expand_job {
+  lwe_mask<Torus> mask_to_use;
+  compact_lwe_body<Torus> body_to_use;
+
+  expand_job(lwe_mask<Torus> mask_to_use, compact_lwe_body<Torus> body_to_use)
+      : mask_to_use{mask_to_use}, body_to_use{body_to_use} {}
+};
+
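+// For illustration only: given a flattened_compact_lwe_lists (called `lists`
+// here, a hypothetical caller-side variable) and a host buffer of num_lwes
+// expand_job entries (h_expand_jobs below), the jobs could be built roughly
+// as follows:
+//
+//   uint32_t job = 0;
+//   for (uint32_t k = 0; k < lists.num_compact_lists; ++k) {
+//     auto list = lists.get_device_compact_list(k);
+//     for (uint32_t i = 0; i < list.total_num_lwes; ++i)
+//       h_expand_jobs[job++] =
+//           expand_job<Torus>(list.get_mask(), list.get_body(i));
+//   }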
+////////////////////////////////////
+
 template <typename Torus> struct zk_expand_mem {
   int_radix_params computing_params;
   int_radix_params casting_params;
@@ -17,11 +107,12 @@ template <typename Torus> struct zk_expand_mem {
   Torus *tmp_expanded_lwes;
   Torus *tmp_ksed_small_to_big_expanded_lwes;

-  uint32_t *d_lwe_compact_input_indexes;
-
-  uint32_t *d_body_id_per_compact_list;
   bool gpu_memory_allocated;

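+  // Owned copy of the caller's per-list LWE counts, plus device and host
+  // buffers holding one expand_job per LWE to expand.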
+  uint32_t *num_lwes_per_compact_list;
+  expand_job<Torus> *d_expand_jobs;
+  expand_job<Torus> *h_expand_jobs;
+
   zk_expand_mem(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                 uint32_t gpu_count, int_radix_params computing_params,
                 int_radix_params casting_params, KS_TYPE casting_key_type,
@@ -33,9 +124,17 @@ template <typename Torus> struct zk_expand_mem {
         casting_key_type(casting_key_type) {

     gpu_memory_allocated = allocate_gpu_memory;
+
+    // We copy num_lwes_per_compact_list so the counts stay valid even if the
+    // caller frees its array while this buffer is still in use
+    this->num_lwes_per_compact_list =
+        (uint32_t *)malloc(num_compact_lists * sizeof(uint32_t));
+    memcpy(this->num_lwes_per_compact_list, num_lwes_per_compact_list,
+           num_compact_lists * sizeof(uint32_t));
+
     num_lwes = 0;
     for (int i = 0; i < num_compact_lists; i++) {
-      num_lwes += num_lwes_per_compact_list[i];
+      num_lwes += this->num_lwes_per_compact_list[i];
     }

     if (computing_params.carry_modulus != computing_params.message_modulus) {
@@ -121,49 +220,14 @@ template <typename Torus> struct zk_expand_mem {
         malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
     auto h_lut_indexes = static_cast<Torus *>(
         malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
-    auto h_body_id_per_compact_list =
-        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-    auto h_lwe_compact_input_indexes =
-        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-
-    d_body_id_per_compact_list =
-        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-            size_tracker, allocate_gpu_memory));
-    d_lwe_compact_input_indexes =
-        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
+
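+    // One expand_job per LWE: d_expand_jobs lives on the GPU for the expand
+    // kernel, h_expand_jobs is the host-side staging buffer (allocated below;
+    // it is not filled in this constructor).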
+    d_expand_jobs =
+        static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
+            num_lwes * sizeof(expand_job<Torus>), streams[0], gpu_indexes[0],
             size_tracker, allocate_gpu_memory));

-    auto compact_list_id = 0;
-    auto idx = 0;
-    auto count = 0;
-    // During flattening, all num_lwes LWEs from all compact lists are stored
-    // sequentially on a Torus array. h_lwe_compact_input_indexes stores the
-    // index of the first LWE related to the compact list that contains the i-th
-    // LWE
-    for (int i = 0; i < num_lwes; i++) {
-      h_lwe_compact_input_indexes[i] = idx;
-      count++;
-      if (count == num_lwes_per_compact_list[compact_list_id]) {
-        compact_list_id++;
-        idx += casting_params.big_lwe_dimension + count;
-        count = 0;
-      }
-    }
-
-    // Stores the index of the i-th LWE (within each compact list) related to
-    // the k-th compact list.
-    auto offset = 0;
-    for (int k = 0; k < num_compact_lists; k++) {
-      auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
-      uint32_t body_count = 0;
-      for (int i = 0; i < num_lwes_in_kth_compact_list; i++) {
-        h_body_id_per_compact_list[i + offset] = body_count;
-        body_count++;
-      }
-      offset += num_lwes_in_kth_compact_list;
-    }
+    h_expand_jobs = static_cast<expand_job<Torus> *>(
+        malloc(num_lwes * sizeof(expand_job<Torus>)));

     /*
      * Each LWE contains encrypted data in both carry and message spaces
@@ -198,9 +262,9 @@ template <typename Torus> struct zk_expand_mem {
      * num_packed_msgs to use the sanitization LUT (which ensures output is
      * exactly 0 or 1).
      */
-    offset = 0;
+    auto offset = 0;
     for (int k = 0; k < num_compact_lists; k++) {
-      auto num_lwes_in_kth = num_lwes_per_compact_list[k];
+      auto num_lwes_in_kth = this->num_lwes_per_compact_list[k];
       for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
         auto lwe_index = i + num_packed_msgs * offset;
         auto lwe_index_in_list = i % num_lwes_in_kth;
@@ -220,17 +284,9 @@ template <typename Torus> struct zk_expand_mem {
         streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
     auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);

-    cuda_memcpy_with_size_tracking_async_to_gpu(
-        d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
-        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-        allocate_gpu_memory);
     cuda_memcpy_with_size_tracking_async_to_gpu(
         lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
         streams[0], gpu_indexes[0], allocate_gpu_memory);
-    cuda_memcpy_with_size_tracking_async_to_gpu(
-        d_body_id_per_compact_list, h_body_id_per_compact_list,
-        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-        allocate_gpu_memory);

     auto active_gpu_count = get_active_gpu_count(2 * num_lwes, gpu_count);
     message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes,
@@ -253,8 +309,6 @@ template <typename Torus> struct zk_expand_mem {
     free(h_indexes_in);
     free(h_indexes_out);
     free(h_lut_indexes);
-    free(h_body_id_per_compact_list);
-    free(h_lwe_compact_input_indexes);
   }

   void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -263,15 +317,16 @@ template <typename Torus> struct zk_expand_mem {
     message_and_carry_extract_luts->release(streams, gpu_indexes, gpu_count);
     delete message_and_carry_extract_luts;

-    cuda_drop_with_size_tracking_async(d_body_id_per_compact_list, streams[0],
-                                       gpu_indexes[0], gpu_memory_allocated);
-    cuda_drop_with_size_tracking_async(d_lwe_compact_input_indexes, streams[0],
-                                       gpu_indexes[0], gpu_memory_allocated);
     cuda_drop_with_size_tracking_async(tmp_expanded_lwes, streams[0],
                                        gpu_indexes[0], gpu_memory_allocated);
     cuda_drop_with_size_tracking_async(tmp_ksed_small_to_big_expanded_lwes,
                                        streams[0], gpu_indexes[0],
                                        gpu_memory_allocated);
+    cuda_drop_with_size_tracking_async(d_expand_jobs, streams[0],
+                                       gpu_indexes[0], gpu_memory_allocated);
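+    // Synchronize so no queued asynchronous work still references these
+    // buffers before the host allocations are freed.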
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+    free(num_lwes_per_compact_list);
+    free(h_expand_jobs);
   }
 };
