Skip to content

Commit ef26e7a

Browse files
committed
fix(gpu): fix overflow sub and comparison issues
1 parent 0da79f9 commit ef26e7a

File tree

2 files changed: +164 additions, −22 deletions

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4342,6 +4342,11 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
43424342
CudaRadixCiphertextFFI *q2; // single block
43434343
CudaRadixCiphertextFFI *q3; // single block
43444344

4345+
Torus **first_indexes_for_overflow_sub;
4346+
Torus **second_indexes_for_overflow_sub;
4347+
Torus **scalars_for_overflow_sub;
4348+
uint32_t max_indexes_to_erase;
4349+
43454350
// allocate and initialize if needed, temporary arrays used to calculate
43464351
// cuda integer div_rem_2_2 operation
43474352
void init_temporary_buffers(cudaStream_t const *streams,
@@ -4610,6 +4615,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
46104615
overflow_sub_mem_3 = new int_borrow_prop_memory<Torus>(
46114616
streams, gpu_indexes, gpu_count, params, num_blocks, compute_overflow,
46124617
allocate_gpu_memory, size_tracker);
4618+
uint32_t group_size = overflow_sub_mem_1->group_size;
4619+
bool use_seq = overflow_sub_mem_1->prop_simu_group_carries_mem
4620+
->use_sequential_algorithm_to_resolve_group_carries;
4621+
create_indexes_for_overflow_sub(streams, gpu_indexes, num_blocks,
4622+
group_size, use_seq, allocate_gpu_memory,
4623+
size_tracker);
46134624
comparison_buffer_1 = new int_comparison_buffer<Torus>(
46144625
streams, gpu_indexes, gpu_count, COMPARISON_TYPE::EQ, params,
46154626
num_blocks, false, allocate_gpu_memory, size_tracker);
@@ -4665,6 +4676,102 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
46654676
}
46664677
}
46674678

4679+
// Precomputes, for every possible radix block count in [1, num_blocks], the
// device-side LUT-index arrays and scalar arrays consumed by the overflowing
// subtraction used inside the div_rem_2_2 algorithm. Slot (nb - 1) of each
// member array holds the data to use when operating on nb radix blocks, so
// the hot path only has to pick a precomputed pointer instead of rebuilding
// and re-uploading indexes on every call.
//
// Parameters:
//   streams / gpu_indexes - CUDA streams and device ids; only streams[0] /
//                           gpu_indexes[0] are used (all uploads are async
//                           on that stream).
//   num_blocks            - maximum radix block count to precompute for;
//                           also recorded in max_indexes_to_erase so
//                           release() knows how many slots to free.
//   group_size            - carry-propagation grouping width used to derive
//                           each block's LUT index (presumably matches
//                           overflow_sub_mem_*->group_size — set by caller).
//   use_seq               - true when the sequential algorithm resolves
//                           group carries; changes the LUT index / scalar
//                           chosen for the last block of each non-first
//                           grouping.
//   allocate_gpu_memory   - when false, allocations/copies only account
//                           sizes in size_tracker without touching the GPU.
//   size_tracker          - accumulates the bytes of GPU memory requested.
//
// NOTE(review): the exact meaning of the index constants (e.g. the
// 2 * group_size sentinel for the single-block case) follows the LUT layout
// of int_borrow_prop_memory — confirm against that struct if it changes.
void create_indexes_for_overflow_sub(cudaStream_t const *streams,
                                     uint32_t const *gpu_indexes,
                                     uint32_t num_blocks, uint32_t group_size,
                                     bool use_seq, bool allocate_gpu_memory,
                                     uint64_t &size_tracker) {
  max_indexes_to_erase = num_blocks;

  // One device pointer per possible block count.
  first_indexes_for_overflow_sub =
      (Torus **)malloc(num_blocks * sizeof(Torus *));
  second_indexes_for_overflow_sub =
      (Torus **)malloc(num_blocks * sizeof(Torus *));
  scalars_for_overflow_sub = (Torus **)malloc(num_blocks * sizeof(Torus *));

  // Host staging buffers, reused for every block count and freed at the end.
  // Sized for the largest case (num_blocks entries).
  Torus *h_lut_indexes = (Torus *)malloc(num_blocks * sizeof(Torus));
  Torus *h_scalar = (Torus *)malloc(num_blocks * sizeof(Torus));

  // Extra indexes for the luts in the first step.
  // Use unsigned counters: the original `int` counters were compared against
  // unsigned bounds (sign-compare hazard).
  for (uint32_t nb = 1; nb <= num_blocks; nb++) {
    first_indexes_for_overflow_sub[nb - 1] =
        (Torus *)cuda_malloc_with_size_tracking_async(
            nb * sizeof(Torus), streams[0], gpu_indexes[0], size_tracker,
            allocate_gpu_memory);
    for (uint32_t index = 0; index < nb; index++) {
      uint32_t grouping_index = index / group_size;
      bool is_in_first_grouping = (grouping_index == 0);
      uint32_t index_in_grouping = index % group_size;
      bool is_last_index = (index == (nb - 1));
      if (is_last_index) {
        if (nb == 1) {
          h_lut_indexes[index] = 2 * group_size;
        } else {
          h_lut_indexes[index] = 2;
        }
      } else if (is_in_first_grouping) {
        h_lut_indexes[index] = index_in_grouping;
      } else {
        h_lut_indexes[index] = index_in_grouping + group_size;
      }
    }
    // Async upload on streams[0]; safe because h_lut_indexes is only
    // overwritten by the next iteration's writes on the same host thread
    // after this copy is enqueued — TODO confirm the helper synchronizes or
    // stages the host buffer before returning.
    cuda_memcpy_with_size_tracking_async_to_gpu(
        first_indexes_for_overflow_sub[nb - 1], h_lut_indexes,
        nb * sizeof(Torus), streams[0], gpu_indexes[0], allocate_gpu_memory);
  }
  // Extra indexes for the luts in the second step.
  for (uint32_t nb = 1; nb <= num_blocks; nb++) {
    second_indexes_for_overflow_sub[nb - 1] =
        (Torus *)cuda_malloc_with_size_tracking_async(
            nb * sizeof(Torus), streams[0], gpu_indexes[0], size_tracker,
            allocate_gpu_memory);
    scalars_for_overflow_sub[nb - 1] =
        (Torus *)cuda_malloc_with_size_tracking_async(
            nb * sizeof(Torus), streams[0], gpu_indexes[0], size_tracker,
            allocate_gpu_memory);

    for (uint32_t index = 0; index < nb; index++) {
      uint32_t grouping_index = index / group_size;
      bool is_in_first_grouping = (grouping_index == 0);
      uint32_t index_in_grouping = index % group_size;

      if (is_in_first_grouping) {
        h_lut_indexes[index] = index_in_grouping;
      } else if (index_in_grouping == (group_size - 1)) {
        // Last block of a non-first grouping: the sequential algorithm needs
        // a distinct LUT per position inside the grouping chain.
        if (use_seq) {
          uint32_t inner_index = (grouping_index - 1) % (group_size - 1);
          h_lut_indexes[index] = inner_index + 2 * group_size;
        } else {
          h_lut_indexes[index] = 2 * group_size;
        }
      } else {
        h_lut_indexes[index] = index_in_grouping + group_size;
      }

      bool may_have_its_padding_bit_set =
          !is_in_first_grouping && (index_in_grouping == group_size - 1);

      if (may_have_its_padding_bit_set) {
        if (use_seq) {
          // Shift in Torus width: the original `1 << ...` shifted a plain
          // int before widening, which overflows for shift amounts >= 31.
          h_scalar[index] = (Torus)1
                            << ((grouping_index - 1) % (group_size - 1));
        } else {
          h_scalar[index] = 1;
        }
      } else {
        h_scalar[index] = 0;
      }
    }
    cuda_memcpy_with_size_tracking_async_to_gpu(
        second_indexes_for_overflow_sub[nb - 1], h_lut_indexes,
        nb * sizeof(Torus), streams[0], gpu_indexes[0], allocate_gpu_memory);
    cuda_memcpy_with_size_tracking_async_to_gpu(
        scalars_for_overflow_sub[nb - 1], h_scalar, nb * sizeof(Torus),
        streams[0], gpu_indexes[0], allocate_gpu_memory);
  }
  free(h_lut_indexes);
  free(h_scalar);
}
4774+
46684775
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
46694776
uint32_t gpu_count) {
46704777
// release and delete integer ops memory objects
@@ -4793,6 +4900,21 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
47934900
delete q1;
47944901
delete q2;
47954902
delete q3;
4903+
4904+
for (int i = 0; i < max_indexes_to_erase; i++) {
4905+
cuda_drop_with_size_tracking_async(first_indexes_for_overflow_sub[i],
4906+
streams[0], gpu_indexes[0],
4907+
gpu_memory_allocated);
4908+
cuda_drop_with_size_tracking_async(second_indexes_for_overflow_sub[i],
4909+
streams[0], gpu_indexes[0],
4910+
gpu_memory_allocated);
4911+
cuda_drop_with_size_tracking_async(scalars_for_overflow_sub[i],
4912+
streams[0], gpu_indexes[0],
4913+
gpu_memory_allocated);
4914+
}
4915+
free(first_indexes_for_overflow_sub);
4916+
free(second_indexes_for_overflow_sub);
4917+
free(scalars_for_overflow_sub);
47964918
}
47974919
};
47984920

backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,32 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
113113
copy_radix_ciphertext_slice_async<Torus>(
114114
streams[0], gpu_indexes[0], mem_ptr->rem, 0, slice_len, remainder,
115115
block_index, num_blocks);
116-
uint32_t compute_borrow = 1;
116+
uint32_t compute_overflow = 1;
117117
uint32_t uses_input_borrow = 0;
118+
auto first_indexes =
119+
mem_ptr->first_indexes_for_overflow_sub[mem_ptr->rem->num_radix_blocks -
120+
1];
121+
auto second_indexes =
122+
mem_ptr
123+
->second_indexes_for_overflow_sub[mem_ptr->rem->num_radix_blocks -
124+
1];
125+
auto scalar_indexes =
126+
mem_ptr->scalars_for_overflow_sub[mem_ptr->rem->num_radix_blocks - 1];
118127
auto sub_result_f = [&](cudaStream_t const *streams,
119128
uint32_t const *gpu_indexes, uint32_t gpu_count,
120129
CudaRadixCiphertextFFI *sub_result,
121130
CudaRadixCiphertextFFI *sub_overflowed,
122131
int_borrow_prop_memory<Torus> *overflow_sub_mem,
123132
CudaRadixCiphertextFFI *low) {
124133
sub_result->num_radix_blocks = low->num_radix_blocks;
134+
overflow_sub_mem->update_lut_indexes(streams, gpu_indexes, first_indexes,
135+
second_indexes, scalar_indexes,
136+
mem_ptr->rem->num_radix_blocks);
125137
host_integer_overflowing_sub<uint64_t>(
126138
streams, gpu_indexes, gpu_count, sub_result, mem_ptr->rem, low,
127139
sub_overflowed, (const CudaRadixCiphertextFFI *)nullptr,
128-
overflow_sub_mem, bsks, ksks, ms_noise_reduction_key, compute_borrow,
129-
uses_input_borrow);
140+
overflow_sub_mem, bsks, ksks, ms_noise_reduction_key,
141+
compute_overflow, uses_input_borrow);
130142
};
131143

132144
auto cmp_f = [&](cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -139,25 +151,33 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
139151
uint32_t slice_start = num_blocks - block_index;
140152
uint32_t slice_end = d->num_radix_blocks;
141153
as_radix_ciphertext_slice<Torus>(d_msb, d, slice_start, slice_end);
142-
host_compare_blocks_with_zero<Torus>(
143-
streams, gpu_indexes, gpu_count, comparison_blocks, d_msb,
144-
comparison_buffer, bsks, ksks, ms_noise_reduction_key,
145-
d_msb->num_radix_blocks, comparison_buffer->is_zero_lut);
146-
are_all_comparisons_block_true(
147-
streams, gpu_indexes, gpu_count, out_boolean_block, comparison_blocks,
148-
comparison_buffer, bsks, ksks, ms_noise_reduction_key,
149-
comparison_blocks->num_radix_blocks);
150-
151-
host_negation<Torus>(
152-
streams[0], gpu_indexes[0], (Torus *)out_boolean_block->ptr,
153-
(Torus *)out_boolean_block->ptr, radix_params.big_lwe_dimension, 1);
154-
// we calculate encoding because this block works only for message_modulus
155-
// = 4 and carry_modulus = 4.
156-
const Torus encoded_scalar = 1ULL << (sizeof(Torus) * 8 - 5);
157-
host_addition_plaintext_scalar<Torus>(
158-
streams[0], gpu_indexes[0], (Torus *)out_boolean_block->ptr,
159-
(Torus *)out_boolean_block->ptr, encoded_scalar,
160-
radix_params.big_lwe_dimension, 1);
154+
comparison_blocks->num_radix_blocks = d_msb->num_radix_blocks;
155+
if (d_msb->num_radix_blocks == 0) {
156+
cuda_memset_async((Torus *)out_boolean_block->ptr, 0,
157+
sizeof(Torus) *
158+
(out_boolean_block->lwe_dimension + 1),
159+
streams[0], gpu_indexes[0]);
160+
} else {
161+
host_compare_blocks_with_zero<Torus>(
162+
streams, gpu_indexes, gpu_count, comparison_blocks, d_msb,
163+
comparison_buffer, bsks, ksks, ms_noise_reduction_key,
164+
d_msb->num_radix_blocks, comparison_buffer->is_zero_lut);
165+
are_all_comparisons_block_true(
166+
streams, gpu_indexes, gpu_count, out_boolean_block,
167+
comparison_blocks, comparison_buffer, bsks, ksks,
168+
ms_noise_reduction_key, comparison_blocks->num_radix_blocks);
169+
170+
host_negation<Torus>(
171+
streams[0], gpu_indexes[0], (Torus *)out_boolean_block->ptr,
172+
(Torus *)out_boolean_block->ptr, radix_params.big_lwe_dimension, 1);
173+
// we calculate encoding because this block works only for
174+
// message_modulus = 4 and carry_modulus = 4.
175+
const Torus encoded_scalar = 1ULL << (sizeof(Torus) * 8 - 5);
176+
host_addition_plaintext_scalar<Torus>(
177+
streams[0], gpu_indexes[0], (Torus *)out_boolean_block->ptr,
178+
(Torus *)out_boolean_block->ptr, encoded_scalar,
179+
radix_params.big_lwe_dimension, 1);
180+
}
161181
delete d_msb;
162182
};
163183

0 commit comments

Comments
 (0)