#include "integer/integer.cuh"
#include <cstdint>

+ ////////////////////////////////////
+ // Helper structures used in expand
+ template <typename Torus> struct lwe_mask {
+   Torus *mask;
+
+   lwe_mask(Torus *mask) : mask{mask} {}
+ };
+
+ template <typename Torus> struct compact_lwe_body {
+   Torus *body;
+   uint64_t monomial_degree;
+
+   /* Body id is the index of the body in the compact ciphertext list.
+    * It's used to compute the rotation.
+    */
+   compact_lwe_body(Torus *body, const uint64_t body_id)
+       : body{body}, monomial_degree{body_id} {}
+ };
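
The comment above only says that the body id drives a rotation. As a point of reference, here is a minimal illustrative sketch, not part of this diff, of one expansion step under the assumption that the rotation is a negacyclic multiplication of the shared mask by X^monomial_degree; the function name and signature are hypothetical.

// Illustrative only: assumes monomial_degree < lwe_dimension and that the
// rotation is negacyclic (a wrap past the last coefficient flips the sign).
template <typename Torus>
__host__ __device__ void expand_one_lwe_sketch(Torus *lwe_out, const Torus *mask,
                                               const Torus *body,
                                               uint32_t lwe_dimension,
                                               uint64_t monomial_degree) {
  for (uint32_t i = 0; i < lwe_dimension; i++) {
    uint64_t j = i + monomial_degree;
    if (j < lwe_dimension)
      lwe_out[j] = mask[i]; // coefficient moved without sign change
    else
      lwe_out[j - lwe_dimension] = -mask[i]; // wrapped past X^N: sign flips
  }
  lwe_out[lwe_dimension] = *body; // the body becomes the last coefficient
}
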
+
+ template <typename Torus> struct compact_lwe_list {
+   Torus *ptr;
+   uint32_t lwe_dimension;
+   uint32_t total_num_lwes;
+
+   compact_lwe_list(Torus *ptr, uint32_t lwe_dimension, uint32_t total_num_lwes)
+       : ptr{ptr}, lwe_dimension{lwe_dimension}, total_num_lwes{total_num_lwes} {
+   }
+
+   lwe_mask<Torus> get_mask() { return lwe_mask(ptr); }
+
+   // Returns the index-th body
+   compact_lwe_body<Torus> get_body(uint32_t index) {
+     if (index >= total_num_lwes) {
+       PANIC("index out of range in compact_lwe_list::get_body");
+     }
+
+     return compact_lwe_body(&ptr[lwe_dimension + index], uint64_t(index));
+   }
+ };
+
+ template <typename Torus> struct flattened_compact_lwe_lists {
+   Torus *d_ptr;
+   Torus **d_ptr_to_compact_list;
+   const uint32_t *h_num_lwes_per_compact_list;
+   uint32_t num_compact_lists;
+   uint32_t lwe_dimension;
+   uint32_t total_num_lwes;
+
+   flattened_compact_lwe_lists(Torus *d_ptr,
+                               const uint32_t *h_num_lwes_per_compact_list,
+                               uint32_t num_compact_lists,
+                               uint32_t lwe_dimension)
+       : d_ptr(d_ptr), h_num_lwes_per_compact_list(h_num_lwes_per_compact_list),
+         num_compact_lists(num_compact_lists), lwe_dimension(lwe_dimension) {
+     d_ptr_to_compact_list =
+         static_cast<Torus **>(malloc(num_compact_lists * sizeof(Torus **)));
+     total_num_lwes = 0;
+     auto curr_list = d_ptr;
+     for (auto i = 0; i < num_compact_lists; ++i) {
+       total_num_lwes += h_num_lwes_per_compact_list[i];
+       d_ptr_to_compact_list[i] = curr_list;
+       curr_list += lwe_dimension + h_num_lwes_per_compact_list[i];
+     }
+   }
+
+   compact_lwe_list<Torus> get_device_compact_list(uint32_t compact_list_index) {
+     if (compact_list_index >= num_compact_lists) {
+       PANIC("index out of range in flattened_compact_lwe_lists::get");
+     }
+
+     return compact_lwe_list(d_ptr_to_compact_list[compact_list_index],
+                             lwe_dimension,
+                             h_num_lwes_per_compact_list[compact_list_index]);
+   }
+ };
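
To make the flattened layout concrete, here is a small hedged usage sketch (the buffer name and the sizes are hypothetical): each compact list occupies lwe_dimension mask coefficients followed by its bodies, and get_device_compact_list returns a view over the k-th segment.

// Illustrative only: two compact lists with lwe_dimension = 4 holding 3 and 2
// LWEs, flattened as [mask_0 (4) | bodies_0 (3) | mask_1 (4) | bodies_1 (2)].
void flattened_layout_sketch(uint64_t *d_flattened_buffer /* device pointer */) {
  const uint32_t h_num_lwes_per_compact_list[2] = {3, 2};
  flattened_compact_lwe_lists<uint64_t> lists(d_flattened_buffer,
                                              h_num_lwes_per_compact_list,
                                              /*num_compact_lists=*/2,
                                              /*lwe_dimension=*/4);

  // lists.total_num_lwes == 5; the second list starts 4 + 3 = 7 elements in.
  auto list_1 = lists.get_device_compact_list(1);
  auto mask_1 = list_1.get_mask();  // wraps d_flattened_buffer + 7
  auto body_1 = list_1.get_body(1); // second body of that list, monomial_degree == 1
  (void)mask_1;
  (void)body_1;
}
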
+
+ /*
+  * An expand_job tells the expand kernel exactly which input mask and body to
+  * use and what rotation to apply.
+  */
+ template <typename Torus> struct expand_job {
+   lwe_mask<Torus> mask_to_use;
+   compact_lwe_body<Torus> body_to_use;
+
+   expand_job(lwe_mask<Torus> mask_to_use, compact_lwe_body<Torus> body_to_use)
+       : mask_to_use{mask_to_use}, body_to_use{body_to_use} {}
+ };
+
+ ////////////////////////////////////
+
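
For orientation, a hedged sketch of how the host side could fill h_expand_jobs from these helpers, pairing each compact list's shared mask with every one of its bodies. This mirrors the per-list and per-body indexing that the removed loops further down used to compute; the helper function itself is hypothetical, not part of this diff.

// Illustrative only: one expand_job per output LWE, reusing the list's mask
// and taking the i-th body so that monomial_degree == i.
template <typename Torus>
void fill_expand_jobs_sketch(flattened_compact_lwe_lists<Torus> &lists,
                             expand_job<Torus> *h_expand_jobs) {
  uint32_t job = 0;
  for (uint32_t k = 0; k < lists.num_compact_lists; k++) {
    auto list = lists.get_device_compact_list(k);
    for (uint32_t i = 0; i < list.total_num_lwes; i++) {
      // All LWEs of list k share its mask; the body id selects the rotation.
      h_expand_jobs[job++] = expand_job<Torus>(list.get_mask(), list.get_body(i));
    }
  }
}
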
template <typename Torus> struct zk_expand_mem {
  int_radix_params computing_params;
  int_radix_params casting_params;
@@ -17,11 +107,12 @@ template <typename Torus> struct zk_expand_mem {
  Torus *tmp_expanded_lwes;
  Torus *tmp_ksed_small_to_big_expanded_lwes;

-   uint32_t *d_lwe_compact_input_indexes;
-
-   uint32_t *d_body_id_per_compact_list;
  bool gpu_memory_allocated;

+   uint32_t *num_lwes_per_compact_list;
+   expand_job<Torus> *d_expand_jobs;
+   expand_job<Torus> *h_expand_jobs;
+
  zk_expand_mem(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                uint32_t gpu_count, int_radix_params computing_params,
                int_radix_params casting_params, KS_TYPE casting_key_type,
@@ -33,9 +124,17 @@ template <typename Torus> struct zk_expand_mem {
        casting_key_type(casting_key_type) {

    gpu_memory_allocated = allocate_gpu_memory;
+
+     // Copy num_lwes_per_compact_list so this buffer stays valid even if the
+     // caller frees the original array while it is still in use here
+     this->num_lwes_per_compact_list =
+         (uint32_t *)malloc(num_compact_lists * sizeof(uint32_t));
+     memcpy(this->num_lwes_per_compact_list, num_lwes_per_compact_list,
+            num_compact_lists * sizeof(uint32_t));
+
    num_lwes = 0;
    for (int i = 0; i < num_compact_lists; i++) {
-       num_lwes += num_lwes_per_compact_list[i];
+       num_lwes += this->num_lwes_per_compact_list[i];
    }

    if (computing_params.carry_modulus != computing_params.message_modulus) {
@@ -121,49 +220,14 @@ template <typename Torus> struct zk_expand_mem {
              malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
      auto h_lut_indexes = static_cast<Torus *>(
          malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
-       auto h_body_id_per_compact_list =
-           static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-       auto h_lwe_compact_input_indexes =
-           static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-
-       d_body_id_per_compact_list =
-           static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-               num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-               size_tracker, allocate_gpu_memory));
-       d_lwe_compact_input_indexes =
-           static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-               num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
+
+       d_expand_jobs =
+           static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
+               num_lwes * sizeof(expand_job<Torus>), streams[0], gpu_indexes[0],
              size_tracker, allocate_gpu_memory));

-       auto compact_list_id = 0;
-       auto idx = 0;
-       auto count = 0;
-       // During flattening, all num_lwes LWEs from all compact lists are stored
-       // sequentially on a Torus array. h_lwe_compact_input_indexes stores the
-       // index of the first LWE related to the compact list that contains the i-th
-       // LWE
-       for (int i = 0; i < num_lwes; i++) {
-         h_lwe_compact_input_indexes[i] = idx;
-         count++;
-         if (count == num_lwes_per_compact_list[compact_list_id]) {
-           compact_list_id++;
-           idx += casting_params.big_lwe_dimension + count;
-           count = 0;
-         }
-       }
-
-       // Stores the index of the i-th LWE (within each compact list) related to
-       // the k-th compact list.
-       auto offset = 0;
-       for (int k = 0; k < num_compact_lists; k++) {
-         auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
-         uint32_t body_count = 0;
-         for (int i = 0; i < num_lwes_in_kth_compact_list; i++) {
-           h_body_id_per_compact_list[i + offset] = body_count;
-           body_count++;
-         }
-         offset += num_lwes_in_kth_compact_list;
-       }
+       h_expand_jobs = static_cast<expand_job<Torus> *>(
+           malloc(num_lwes * sizeof(expand_job<Torus>)));

      /*
       * Each LWE contains encrypted data in both carry and message spaces
@@ -198,9 +262,9 @@ template <typename Torus> struct zk_expand_mem {
       * num_packed_msgs to use the sanitization LUT (which ensures output is
       * exactly 0 or 1).
       */
-       offset = 0;
+       auto offset = 0;
      for (int k = 0; k < num_compact_lists; k++) {
-         auto num_lwes_in_kth = num_lwes_per_compact_list[k];
+         auto num_lwes_in_kth = this->num_lwes_per_compact_list[k];
        for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
          auto lwe_index = i + num_packed_msgs * offset;
          auto lwe_index_in_list = i % num_lwes_in_kth;
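
As a quick sanity check on the index arithmetic above, the small standalone snippet below evaluates the two visible formulas with hypothetical sizes (num_packed_msgs = 2, lists of 3 and 2 LWEs). It assumes, as in the removed loops earlier, that offset accumulates each list's LWE count after the inner loop.

// Illustrative only: prints the lut index slots the loop above would visit.
#include <cstdio>

int main() {
  const int num_packed_msgs = 2;
  const int num_lwes_per_compact_list[2] = {3, 2};
  int offset = 0;
  for (int k = 0; k < 2; k++) {
    int num_lwes_in_kth = num_lwes_per_compact_list[k];
    for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
      int lwe_index = i + num_packed_msgs * offset;
      int lwe_index_in_list = i % num_lwes_in_kth;
      // k = 0: lwe_index runs 0..5, lwe_index_in_list cycles 0,1,2,0,1,2
      // k = 1: lwe_index runs 6..9, lwe_index_in_list cycles 0,1,0,1
      printf("k=%d i=%d lwe_index=%d in_list=%d\n", k, i, lwe_index,
             lwe_index_in_list);
    }
    offset += num_lwes_in_kth; // assumption: per-list accumulation, as before
  }
  return 0;
}
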
@@ -220,17 +284,9 @@ template <typename Torus> struct zk_expand_mem {
          streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
      auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);

-       cuda_memcpy_with_size_tracking_async_to_gpu(
-           d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
-           num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-           allocate_gpu_memory);
      cuda_memcpy_with_size_tracking_async_to_gpu(
          lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
          streams[0], gpu_indexes[0], allocate_gpu_memory);
-       cuda_memcpy_with_size_tracking_async_to_gpu(
-           d_body_id_per_compact_list, h_body_id_per_compact_list,
-           num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-           allocate_gpu_memory);

      auto active_gpu_count = get_active_gpu_count(2 * num_lwes, gpu_count);
      message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes,
@@ -253,8 +309,6 @@ template <typename Torus> struct zk_expand_mem {
      free(h_indexes_in);
      free(h_indexes_out);
      free(h_lut_indexes);
-       free(h_body_id_per_compact_list);
-       free(h_lwe_compact_input_indexes);
    }

  void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -263,15 +317,16 @@ template <typename Torus> struct zk_expand_mem {
    message_and_carry_extract_luts->release(streams, gpu_indexes, gpu_count);
    delete message_and_carry_extract_luts;

-     cuda_drop_with_size_tracking_async(d_body_id_per_compact_list, streams[0],
-                                        gpu_indexes[0], gpu_memory_allocated);
-     cuda_drop_with_size_tracking_async(d_lwe_compact_input_indexes, streams[0],
-                                        gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(tmp_expanded_lwes, streams[0],
                                       gpu_indexes[0], gpu_memory_allocated);
    cuda_drop_with_size_tracking_async(tmp_ksed_small_to_big_expanded_lwes,
                                       streams[0], gpu_indexes[0],
                                       gpu_memory_allocated);
+     cuda_drop_with_size_tracking_async(d_expand_jobs, streams[0],
+                                        gpu_indexes[0], gpu_memory_allocated);
+     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+     free(num_lwes_per_compact_list);
+     free(h_expand_jobs);
  }
};