Skip to content

Commit 91c8de9

Browse files
committed
chore(gpu): add a benchmark for 128-bit multi-bit noise squashing
- Also, remove the LUT indexes concept from the 128-bit multi-bit PBS. The entire backend assumes it does not exist (just as it does not exist for the classical PBS), so keeping it here would be error-prone.
1 parent 1d35e27 commit 91c8de9

File tree

13 files changed

+260
-214
lines changed

13 files changed

+260
-214
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,8 +1066,18 @@ template <typename InputTorus> struct int_noise_squashing_lut {
10661066
release_radix_ciphertext_async(streams[0], gpu_indexes[0],
10671067
tmp_lwe_before_ks, gpu_memory_allocated);
10681068
for (int i = 0; i < pbs_buffer.size(); i++) {
1069-
cleanup_cuda_programmable_bootstrap_128(streams[i], gpu_indexes[i],
1070-
&pbs_buffer[i]);
1069+
switch (params.pbs_type) {
1070+
case MULTI_BIT:
1071+
cleanup_cuda_multi_bit_programmable_bootstrap_128(
1072+
streams[i], gpu_indexes[i], &pbs_buffer[i]);
1073+
break;
1074+
case CLASSICAL:
1075+
cleanup_cuda_programmable_bootstrap_128(streams[i], gpu_indexes[i],
1076+
&pbs_buffer[i]);
1077+
break;
1078+
default:
1079+
PANIC("Cuda error (PBS): unknown PBS type. ")
1080+
}
10711081
cuda_synchronize_stream(streams[i], gpu_indexes[i]);
10721082
}
10731083
if (lwe_aligned_gather_vec.size() > 0) {

backends/tfhe-cuda-backend/cuda/include/pbs/programmable_bootstrap_multibit.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,11 @@ uint64_t scratch_cuda_multi_bit_programmable_bootstrap_128_vector_64(
4747
void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
4848
void *stream, uint32_t gpu_index, void *lwe_array_out,
4949
void const *lwe_output_indexes, void const *lut_vector,
50-
void const *lut_vector_indexes, void const *lwe_array_in,
51-
void const *lwe_input_indexes, void const *bootstrapping_key,
52-
int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension,
53-
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
54-
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
55-
uint32_t lut_stride);
50+
void const *lwe_array_in, void const *lwe_input_indexes,
51+
void const *bootstrapping_key, int8_t *mem_ptr, uint32_t lwe_dimension,
52+
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
53+
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
54+
uint32_t num_many_lut, uint32_t lut_stride);
5655

5756
void cleanup_cuda_multi_bit_programmable_bootstrap_128(void *stream,
5857
const uint32_t gpu_index,

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -347,18 +347,12 @@ void execute_pbs_async(
347347
auto current_lwe_input_indexes =
348348
get_variant_element(lwe_input_indexes, i);
349349

350-
int gpu_offset =
351-
get_gpu_offset(input_lwe_ciphertext_count, i, gpu_count);
352-
auto d_lut_vector_indexes =
353-
lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);
354-
355350
cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
356351
streams[i], gpu_indexes[i], current_lwe_array_out,
357-
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
358-
current_lwe_array_in, current_lwe_input_indexes,
359-
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
360-
polynomial_size, grouping_factor, base_log, level_count,
361-
num_inputs_on_gpu, num_many_lut, lut_stride);
352+
current_lwe_output_indexes, lut_vec[i], current_lwe_array_in,
353+
current_lwe_input_indexes, bootstrapping_keys[i], pbs_buffer[i],
354+
lwe_dimension, glwe_dimension, polynomial_size, grouping_factor,
355+
base_log, level_count, num_inputs_on_gpu, num_many_lut, lut_stride);
362356
}
363357
break;
364358
case CLASSICAL:

backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap_multibit_128.cu

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ template <typename InputTorus>
120120
void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
121121
void *stream, uint32_t gpu_index, __uint128_t *lwe_array_out,
122122
InputTorus const *lwe_output_indexes, __uint128_t const *lut_vector,
123-
InputTorus const *lut_vector_indexes, InputTorus const *lwe_array_in,
124-
InputTorus const *lwe_input_indexes, __uint128_t const *bootstrapping_key,
123+
InputTorus const *lwe_array_in, InputTorus const *lwe_input_indexes,
124+
__uint128_t const *bootstrapping_key,
125125
pbs_buffer_128<InputTorus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
126126
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
127127
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
@@ -131,45 +131,45 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
131131
case 256:
132132
host_multi_bit_programmable_bootstrap_128<InputTorus, AmortizedDegree<256>>(
133133
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
134-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
135-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
136-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
137-
num_samples, num_many_lut, lut_stride);
134+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
135+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
136+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
137+
num_many_lut, lut_stride);
138138
break;
139139
case 512:
140140
host_multi_bit_programmable_bootstrap_128<InputTorus, AmortizedDegree<512>>(
141141
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
142-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
143-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
144-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
145-
num_samples, num_many_lut, lut_stride);
142+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
143+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
144+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
145+
num_many_lut, lut_stride);
146146
break;
147147
case 1024:
148148
host_multi_bit_programmable_bootstrap_128<InputTorus,
149149
AmortizedDegree<1024>>(
150150
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
151-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
152-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
153-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
154-
num_samples, num_many_lut, lut_stride);
151+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
152+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
153+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
154+
num_many_lut, lut_stride);
155155
break;
156156
case 2048:
157157
host_multi_bit_programmable_bootstrap_128<InputTorus,
158158
AmortizedDegree<2048>>(
159159
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
160-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
161-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
162-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
163-
num_samples, num_many_lut, lut_stride);
160+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
161+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
162+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
163+
num_many_lut, lut_stride);
164164
break;
165165
case 4096:
166166
host_multi_bit_programmable_bootstrap_128<InputTorus,
167167
AmortizedDegree<4096>>(
168168
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
169-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
170-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
171-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
172-
num_samples, num_many_lut, lut_stride);
169+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
170+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
171+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
172+
num_many_lut, lut_stride);
173173
break;
174174
default:
175175
PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported "
@@ -182,8 +182,8 @@ template <typename InputTorus>
182182
void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
183183
void *stream, uint32_t gpu_index, __uint128_t *lwe_array_out,
184184
InputTorus const *lwe_output_indexes, __uint128_t const *lut_vector,
185-
InputTorus const *lut_vector_indexes, InputTorus const *lwe_array_in,
186-
InputTorus const *lwe_input_indexes, __uint128_t const *bootstrapping_key,
185+
InputTorus const *lwe_array_in, InputTorus const *lwe_input_indexes,
186+
__uint128_t const *bootstrapping_key,
187187
pbs_buffer_128<InputTorus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
188188
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
189189
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
@@ -194,46 +194,46 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
194194
host_cg_multi_bit_programmable_bootstrap_128<InputTorus,
195195
AmortizedDegree<256>>(
196196
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
197-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
198-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
199-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
200-
num_samples, num_many_lut, lut_stride);
197+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
198+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
199+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
200+
num_many_lut, lut_stride);
201201
break;
202202
case 512:
203203
host_cg_multi_bit_programmable_bootstrap_128<InputTorus,
204204
AmortizedDegree<512>>(
205205
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
206-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
207-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
208-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
209-
num_samples, num_many_lut, lut_stride);
206+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
207+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
208+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
209+
num_many_lut, lut_stride);
210210
break;
211211
case 1024:
212212
host_cg_multi_bit_programmable_bootstrap_128<InputTorus,
213213
AmortizedDegree<1024>>(
214214
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
215-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
216-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
217-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
218-
num_samples, num_many_lut, lut_stride);
215+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
216+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
217+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
218+
num_many_lut, lut_stride);
219219
break;
220220
case 2048:
221221
host_cg_multi_bit_programmable_bootstrap_128<InputTorus,
222222
AmortizedDegree<2048>>(
223223
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
224-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
225-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
226-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
227-
num_samples, num_many_lut, lut_stride);
224+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
225+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
226+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
227+
num_many_lut, lut_stride);
228228
break;
229229
case 4096:
230230
host_cg_multi_bit_programmable_bootstrap_128<InputTorus,
231231
AmortizedDegree<4096>>(
232232
static_cast<cudaStream_t>(stream), gpu_index, lwe_array_out,
233-
lwe_output_indexes, lut_vector, lut_vector_indexes, lwe_array_in,
234-
lwe_input_indexes, bootstrapping_key, pbs_buffer, glwe_dimension,
235-
lwe_dimension, polynomial_size, grouping_factor, base_log, level_count,
236-
num_samples, num_many_lut, lut_stride);
233+
lwe_output_indexes, lut_vector, lwe_array_in, lwe_input_indexes,
234+
bootstrapping_key, pbs_buffer, glwe_dimension, lwe_dimension,
235+
polynomial_size, grouping_factor, base_log, level_count, num_samples,
236+
num_many_lut, lut_stride);
237237
break;
238238
default:
239239
PANIC("Cuda error (multi-bit PBS): unsupported polynomial size. Supported "
@@ -245,12 +245,11 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
245245
void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
246246
void *stream, uint32_t gpu_index, void *lwe_array_out,
247247
void const *lwe_output_indexes, void const *lut_vector,
248-
void const *lut_vector_indexes, void const *lwe_array_in,
249-
void const *lwe_input_indexes, void const *bootstrapping_key,
250-
int8_t *mem_ptr, uint32_t lwe_dimension, uint32_t glwe_dimension,
251-
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
252-
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
253-
uint32_t lut_stride) {
248+
void const *lwe_array_in, void const *lwe_input_indexes,
249+
void const *bootstrapping_key, int8_t *mem_ptr, uint32_t lwe_dimension,
250+
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
251+
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
252+
uint32_t num_many_lut, uint32_t lut_stride) {
254253

255254
if (base_log > 64)
256255
PANIC("Cuda error (multi-bit PBS): base log should be <= 64")
@@ -263,7 +262,6 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
263262
uint64_t>(stream, gpu_index, static_cast<__uint128_t *>(lwe_array_out),
264263
static_cast<const uint64_t *>(lwe_output_indexes),
265264
static_cast<const __uint128_t *>(lut_vector),
266-
static_cast<const uint64_t *>(lut_vector_indexes),
267265
static_cast<const uint64_t *>(lwe_array_in),
268266
static_cast<const uint64_t *>(lwe_input_indexes),
269267
static_cast<const __uint128_t *>(bootstrapping_key), buffer,
@@ -276,7 +274,6 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_128(
276274
stream, gpu_index, static_cast<__uint128_t *>(lwe_array_out),
277275
static_cast<const uint64_t *>(lwe_output_indexes),
278276
static_cast<const __uint128_t *>(lut_vector),
279-
static_cast<const uint64_t *>(lut_vector_indexes),
280277
static_cast<const uint64_t *>(lwe_array_in),
281278
static_cast<const uint64_t *>(lwe_input_indexes),
282279
static_cast<const __uint128_t *>(bootstrapping_key), buffer,

0 commit comments

Comments
 (0)