Commit b2624d1

chore(gpu): refactor the indexing logic for the LWE expand
1 parent 9fb7b56 commit b2624d1

7 files changed: +219 −177 lines changed

backends/tfhe-cuda-backend/cuda/include/zk/zk.h

Lines changed: 0 additions & 7 deletions
@@ -6,13 +6,6 @@
 #include <stdint.h>
 
 extern "C" {
-
-void cuda_lwe_expand_64(void *const stream, uint32_t gpu_index,
-                        void *lwe_array_out, const void *lwe_compact_array_in,
-                        uint32_t lwe_dimension, uint32_t num_lwe,
-                        const uint32_t *lwe_compact_input_indexes,
-                        const uint32_t *output_body_id_per_compact_list);
-
 uint64_t scratch_cuda_expand_without_verification_64(
     void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
     int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size,

backends/tfhe-cuda-backend/cuda/include/zk/zk_utilities.h

Lines changed: 116 additions & 61 deletions
@@ -5,6 +5,96 @@
 #include "integer/integer.cuh"
 #include <cstdint>
 
+////////////////////////////////////
+// Helper structures used in expand
+template <typename Torus> struct lwe_mask {
+  Torus *mask;
+
+  lwe_mask(Torus *mask) : mask{mask} {}
+};
+
+template <typename Torus> struct compact_lwe_body {
+  Torus *body;
+  uint64_t monomial_degree;
+
+  /* Body id is the index of the body in the compact ciphertext list.
+   * It's used to compute the rotation.
+   */
+  compact_lwe_body(Torus *body, const uint64_t body_id)
+      : body{body}, monomial_degree{body_id} {}
+};
+
+template <typename Torus> struct compact_lwe_list {
+  Torus *ptr;
+  uint32_t lwe_dimension;
+  uint32_t total_num_lwes;
+
+  compact_lwe_list(Torus *ptr, uint32_t lwe_dimension, uint32_t total_num_lwes)
+      : ptr{ptr}, lwe_dimension{lwe_dimension}, total_num_lwes{total_num_lwes} {
+  }
+
+  lwe_mask<Torus> get_mask() { return lwe_mask(ptr); }
+
+  // Returns the index-th body
+  compact_lwe_body<Torus> get_body(uint32_t index) {
+    if (index >= total_num_lwes) {
+      PANIC("index out of range in compact_lwe_list::get_body");
+    }
+
+    return compact_lwe_body(&ptr[lwe_dimension + index], uint64_t(index));
+  }
+};
+
+template <typename Torus> struct flattened_compact_lwe_lists {
+  Torus *d_ptr;
+  Torus **d_ptr_to_compact_list;
+  const uint32_t *h_num_lwes_per_compact_list;
+  uint32_t num_compact_lists;
+  uint32_t lwe_dimension;
+  uint32_t total_num_lwes;
+
+  flattened_compact_lwe_lists(Torus *d_ptr,
+                              const uint32_t *h_num_lwes_per_compact_list,
+                              uint32_t num_compact_lists,
+                              uint32_t lwe_dimension)
+      : d_ptr(d_ptr), h_num_lwes_per_compact_list(h_num_lwes_per_compact_list),
+        num_compact_lists(num_compact_lists), lwe_dimension(lwe_dimension) {
+    d_ptr_to_compact_list =
+        static_cast<Torus **>(malloc(num_compact_lists * sizeof(Torus **)));
+    total_num_lwes = 0;
+    auto curr_list = d_ptr;
+    for (auto i = 0; i < num_compact_lists; ++i) {
+      total_num_lwes += h_num_lwes_per_compact_list[i];
+      d_ptr_to_compact_list[i] = curr_list;
+      curr_list += lwe_dimension + h_num_lwes_per_compact_list[i];
+    }
+  }
+
+  compact_lwe_list<Torus> get_device_compact_list(uint32_t compact_list_index) {
+    if (compact_list_index >= num_compact_lists) {
+      PANIC("index out of range in flattened_compact_lwe_lists::get");
+    }
+
+    return compact_lwe_list(d_ptr_to_compact_list[compact_list_index],
+                            lwe_dimension,
+                            h_num_lwes_per_compact_list[compact_list_index]);
+  }
+};
+
+/*
+ * An expand_job tells the expand kernel exactly which input mask and body to use
+ * and what rotation to apply
+ */
+template <typename Torus> struct expand_job {
+  lwe_mask<Torus> mask_to_use;
+  compact_lwe_body<Torus> body_to_use;
+
+  expand_job(lwe_mask<Torus> mask_to_use, compact_lwe_body<Torus> body_to_use)
+      : mask_to_use{mask_to_use}, body_to_use{body_to_use} {}
+};
+
+////////////////////////////////////
+
 template <typename Torus> struct zk_expand_mem {
   int_radix_params computing_params;
   int_radix_params casting_params;
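The structures added above are host-side views over device memory: each compact list is laid out as one shared mask of lwe_dimension elements followed by its bodies, and flattened_compact_lwe_lists records where each list starts inside the flattened allocation. A minimal sketch of how they compose (the d_flattened device pointer and the sizes are illustrative; the returned pointers are device addresses that are only carried around, never dereferenced, on the host):

// Two compact lists of 3 and 2 LWEs over a flattened device buffer laid out
// as [mask_0 | bodies_0 | mask_1 | bodies_1].
uint32_t h_num_lwes_per_compact_list[2] = {3, 2};
flattened_compact_lwe_lists<uint64_t> lists(
    d_flattened, h_num_lwes_per_compact_list,
    /*num_compact_lists=*/2, /*lwe_dimension=*/2048);

// lists.total_num_lwes == 5; the second list starts at d_flattened + 2048 + 3.
auto list_1 = lists.get_device_compact_list(1);
auto mask_1 = list_1.get_mask();  // the mask shared by every LWE in list 1
auto body_0 = list_1.get_body(0); // device pointer to body 0, monomial_degree == 0
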
@@ -17,11 +107,12 @@ template <typename Torus> struct zk_expand_mem {
   Torus *tmp_expanded_lwes;
   Torus *tmp_ksed_small_to_big_expanded_lwes;
 
-  uint32_t *d_lwe_compact_input_indexes;
-
-  uint32_t *d_body_id_per_compact_list;
   bool gpu_memory_allocated;
 
+  uint32_t *num_lwes_per_compact_list;
+  expand_job<Torus> *d_expand_jobs;
+  expand_job<Torus> *h_expand_jobs;
+
   zk_expand_mem(cudaStream_t const *streams, uint32_t const *gpu_indexes,
                 uint32_t gpu_count, int_radix_params computing_params,
                 int_radix_params casting_params, KS_TYPE casting_key_type,
@@ -33,9 +124,17 @@ template <typename Torus> struct zk_expand_mem {
       casting_key_type(casting_key_type) {
 
     gpu_memory_allocated = allocate_gpu_memory;
+
+    // We copy num_lwes_per_compact_list so we get protection against
+    // num_lwes_per_compact_list being freed while this buffer is still in use
+    this->num_lwes_per_compact_list =
+        (uint32_t *)malloc(num_compact_lists * sizeof(uint32_t));
+    memcpy(this->num_lwes_per_compact_list, num_lwes_per_compact_list,
+           num_compact_lists * sizeof(uint32_t));
+
     num_lwes = 0;
     for (int i = 0; i < num_compact_lists; i++) {
-      num_lwes += num_lwes_per_compact_list[i];
+      num_lwes += this->num_lwes_per_compact_list[i];
     }
 
     if (computing_params.carry_modulus != computing_params.message_modulus) {
@@ -121,49 +220,14 @@ template <typename Torus> struct zk_expand_mem {
         malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
     auto h_lut_indexes = static_cast<Torus *>(
         malloc(num_packed_msgs * num_lwes * sizeof(Torus)));
-    auto h_body_id_per_compact_list =
-        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-    auto h_lwe_compact_input_indexes =
-        static_cast<uint32_t *>(malloc(num_lwes * sizeof(uint32_t)));
-
-    d_body_id_per_compact_list =
-        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-            size_tracker, allocate_gpu_memory));
-    d_lwe_compact_input_indexes =
-        static_cast<uint32_t *>(cuda_malloc_with_size_tracking_async(
-            num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
+
+    d_expand_jobs =
+        static_cast<expand_job<Torus> *>(cuda_malloc_with_size_tracking_async(
+            num_lwes * sizeof(expand_job<Torus>), streams[0], gpu_indexes[0],
             size_tracker, allocate_gpu_memory));
 
-    auto compact_list_id = 0;
-    auto idx = 0;
-    auto count = 0;
-    // During flattening, all num_lwes LWEs from all compact lists are stored
-    // sequentially on a Torus array. h_lwe_compact_input_indexes stores the
-    // index of the first LWE related to the compact list that contains the i-th
-    // LWE
-    for (int i = 0; i < num_lwes; i++) {
-      h_lwe_compact_input_indexes[i] = idx;
-      count++;
-      if (count == num_lwes_per_compact_list[compact_list_id]) {
-        compact_list_id++;
-        idx += casting_params.big_lwe_dimension + count;
-        count = 0;
-      }
-    }
-
-    // Stores the index of the i-th LWE (within each compact list) related to
-    // the k-th compact list.
-    auto offset = 0;
-    for (int k = 0; k < num_compact_lists; k++) {
-      auto num_lwes_in_kth_compact_list = num_lwes_per_compact_list[k];
-      uint32_t body_count = 0;
-      for (int i = 0; i < num_lwes_in_kth_compact_list; i++) {
-        h_body_id_per_compact_list[i + offset] = body_count;
-        body_count++;
-      }
-      offset += num_lwes_in_kth_compact_list;
-    }
+    h_expand_jobs = static_cast<expand_job<Torus> *>(
+        malloc(num_lwes * sizeof(expand_job<Torus>)));
 
     /*
      * Each LWE contains encrypted data in both carry and message spaces
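With the two index arrays gone, a single expand_job per output LWE carries everything the kernel needs: a pointer to its list's shared mask and a pointer to the body whose index doubles as the rotation degree. The code that fills h_expand_jobs is not part of this hunk; a hypothetical filling loop inside the constructor, assuming a flattened_compact_lwe_lists instance named flattened_inputs, could look like:

uint32_t job_id = 0;
for (uint32_t k = 0; k < num_compact_lists; k++) {
  auto list = flattened_inputs.get_device_compact_list(k);
  for (uint32_t i = 0; i < list.total_num_lwes; i++) {
    // Pair the k-th list's mask with its i-th body; the body index is the
    // monomial degree the kernel will rotate the mask by.
    h_expand_jobs[job_id++] = expand_job<Torus>(list.get_mask(), list.get_body(i));
  }
}
// The jobs are then pushed to d_expand_jobs (e.g. with
// cuda_memcpy_with_size_tracking_async_to_gpu) before lwe_expand is launched.
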
@@ -198,9 +262,9 @@ template <typename Torus> struct zk_expand_mem {
      * num_packed_msgs to use the sanitization LUT (which ensures output is
      * exactly 0 or 1).
      */
-    offset = 0;
+    auto offset = 0;
     for (int k = 0; k < num_compact_lists; k++) {
-      auto num_lwes_in_kth = num_lwes_per_compact_list[k];
+      auto num_lwes_in_kth = this->num_lwes_per_compact_list[k];
       for (int i = 0; i < num_packed_msgs * num_lwes_in_kth; i++) {
         auto lwe_index = i + num_packed_msgs * offset;
         auto lwe_index_in_list = i % num_lwes_in_kth;
@@ -220,17 +284,9 @@ template <typename Torus> struct zk_expand_mem {
         streams[0], gpu_indexes[0], h_indexes_in, h_indexes_out);
     auto lut_indexes = message_and_carry_extract_luts->get_lut_indexes(0, 0);
 
-    cuda_memcpy_with_size_tracking_async_to_gpu(
-        d_lwe_compact_input_indexes, h_lwe_compact_input_indexes,
-        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-        allocate_gpu_memory);
     cuda_memcpy_with_size_tracking_async_to_gpu(
         lut_indexes, h_lut_indexes, num_packed_msgs * num_lwes * sizeof(Torus),
         streams[0], gpu_indexes[0], allocate_gpu_memory);
-    cuda_memcpy_with_size_tracking_async_to_gpu(
-        d_body_id_per_compact_list, h_body_id_per_compact_list,
-        num_lwes * sizeof(uint32_t), streams[0], gpu_indexes[0],
-        allocate_gpu_memory);
 
     auto active_gpu_count = get_active_gpu_count(2 * num_lwes, gpu_count);
     message_and_carry_extract_luts->broadcast_lut(streams, gpu_indexes,
@@ -253,8 +309,6 @@ template <typename Torus> struct zk_expand_mem {
     free(h_indexes_in);
     free(h_indexes_out);
     free(h_lut_indexes);
-    free(h_body_id_per_compact_list);
-    free(h_lwe_compact_input_indexes);
   }
 
   void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -263,15 +317,16 @@ template <typename Torus> struct zk_expand_mem {
     message_and_carry_extract_luts->release(streams, gpu_indexes, gpu_count);
     delete message_and_carry_extract_luts;
 
-    cuda_drop_with_size_tracking_async(d_body_id_per_compact_list, streams[0],
-                                       gpu_indexes[0], gpu_memory_allocated);
-    cuda_drop_with_size_tracking_async(d_lwe_compact_input_indexes, streams[0],
-                                       gpu_indexes[0], gpu_memory_allocated);
     cuda_drop_with_size_tracking_async(tmp_expanded_lwes, streams[0],
                                        gpu_indexes[0], gpu_memory_allocated);
     cuda_drop_with_size_tracking_async(tmp_ksed_small_to_big_expanded_lwes,
                                        streams[0], gpu_indexes[0],
                                        gpu_memory_allocated);
+    cuda_drop_with_size_tracking_async(d_expand_jobs, streams[0],
+                                       gpu_indexes[0], gpu_memory_allocated);
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+    free(num_lwes_per_compact_list);
+    free(h_expand_jobs);
   }
 };
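Note the ordering in release: the device-side job buffer is dropped stream-asynchronously, then the stream is synchronized before the host-side buffers are freed, since asynchronous copies issued from h_expand_jobs may still be in flight. The same pattern expressed with the plain CUDA runtime (a generic sketch, not the backend's size-tracking wrappers):

cudaMemcpyAsync(d_buf, h_buf, bytes, cudaMemcpyHostToDevice, stream);
// ... kernels enqueued on the stream consume d_buf ...
cudaFreeAsync(d_buf, stream);   // stream-ordered device free
cudaStreamSynchronize(stream);  // ensure no pending operation still reads h_buf
free(h_buf);                    // only now is the host staging buffer safe to free
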

backends/tfhe-cuda-backend/cuda/src/zk/expand.cu

Lines changed: 0 additions & 65 deletions
This file was deleted.

backends/tfhe-cuda-backend/cuda/src/zk/expand.cuh

Lines changed: 12 additions & 19 deletions
@@ -5,39 +5,35 @@
 #include "polynomial/functions.cuh"
 #include "polynomial/polynomial_math.cuh"
 #include "zk/zk.h"
+#include "zk/zk_utilities.h"
 #include <cstdint>
 
-#include "utils/helper.cuh"
-
 // Expand a LweCompactCiphertextList into a LweCiphertextList
 // - Each x-block computes one output ciphertext
 template <typename Torus, class params>
-__global__ void lwe_expand(Torus const *lwe_compact_array_in,
-                           Torus *lwe_array_out,
-                           const uint32_t *lwe_compact_input_indexes,
-                           const uint32_t *output_body_id_per_compact_list) {
+__global__ void lwe_expand(const expand_job<Torus> *jobs,
+                           Torus *lwe_array_out) {
   const auto lwe_output_id = blockIdx.x;
   const auto lwe_dimension = params::degree;
 
-  const auto body_id = output_body_id_per_compact_list[lwe_output_id];
+  const auto job = jobs[lwe_output_id];
 
-  const auto input_mask =
-      &lwe_compact_array_in[lwe_compact_input_indexes[lwe_output_id]];
-  const auto input_body = &input_mask[lwe_dimension + body_id];
+  const lwe_mask<Torus> input_mask = job.mask_to_use;
+  const compact_lwe_body<Torus> input_body = job.body_to_use;
 
   auto output_mask = &lwe_array_out[(lwe_dimension + 1) * lwe_output_id];
   auto output_body = &output_mask[lwe_dimension];
 
   // We rotate the input mask by i to calculate the mask related to the i-th
   // output
-  const auto monomial_degree = body_id;
+  const auto monomial_degree = input_body.monomial_degree;
   polynomial_accumulate_monic_monomial_mul<Torus>(
-      output_mask, input_mask, monomial_degree, threadIdx.x, lwe_dimension,
+      output_mask, input_mask.mask, monomial_degree, threadIdx.x, lwe_dimension,
       params::opt, true);
 
   // The output body is just copied
   if (threadIdx.x == 0)
-    *output_body = *input_body;
+    *output_body = *input_body.body;
 }
 
 template <typename Torus> bool is_power_of_2(Torus value) {
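Each output mask is the compact list's shared mask multiplied by the monic monomial X^monomial_degree in the negacyclic ring Z[X]/(X^N + 1), which is what polynomial_accumulate_monic_monomial_mul computes across the threads of the block. A single-threaded CPU reference of that rotation (a sketch; the device routine in polynomial_math.cuh splits the work over threadIdx.x and may differ in details such as accumulation):

template <typename Torus>
void rotate_mask_by_monomial(Torus *out, const Torus *mask, uint64_t degree,
                             uint32_t lwe_dimension) {
  for (uint32_t i = 0; i < lwe_dimension; i++) {
    uint64_t j = (i + degree) % (2 * lwe_dimension);
    if (j < lwe_dimension)
      out[j] = mask[i];                   // plain coefficient rotation
    else
      out[j - lwe_dimension] = -mask[i];  // wrap-around picks up a sign flip
  }
}
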
@@ -46,9 +42,7 @@ template <typename Torus> bool is_power_of_2(Torus value) {
 
 template <typename Torus, class params>
 void host_lwe_expand(cudaStream_t stream, int gpu_index, Torus *lwe_array_out,
-                     const Torus *lwe_compact_array_in, uint32_t num_lwes,
-                     const uint32_t *lwe_compact_input_indexes,
-                     const uint32_t *output_body_id_per_compact_list) {
+                     const expand_job<Torus> *d_jobs, uint32_t num_lwes) {
   // Set the GPU device
   cudaSetDevice(gpu_index);
 
@@ -63,9 +57,8 @@ void host_lwe_expand(cudaStream_t stream, int gpu_index, Torus *lwe_array_out,
     PANIC("Error: lwe_dimension must be a power of 2");
 
   // Launch the `lwe_expand` kernel
-  lwe_expand<Torus, params><<<num_blocks, threads_per_block, 0, stream>>>(
-      lwe_compact_array_in, lwe_array_out, lwe_compact_input_indexes,
-      output_body_id_per_compact_list);
+  lwe_expand<Torus, params>
+      <<<num_blocks, threads_per_block, 0, stream>>>(d_jobs, lwe_array_out);
   check_cuda_error(cudaGetLastError());
 }
 #endif // EXPAND_CUH
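A hypothetical call site for the new signature (the params class below is illustrative only; the kernel reads its static degree and opt members, and the real caller dispatches on the casting LWE dimension):

struct expand_params {
  static constexpr int degree = 2048; // casting LWE dimension (illustrative)
  static constexpr int opt = 4;       // coefficients handled per thread (illustrative)
};

host_lwe_expand<uint64_t, expand_params>(stream, gpu_index, d_lwe_array_out,
                                         d_expand_jobs, num_lwes);
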
