Skip to content

Commit 3e73e88

Browse files
committed
fix 2_2 mem bugs after separation
1 parent 24627d2 commit 3e73e88

File tree

2 files changed

+63
-50
lines changed

2 files changed

+63
-50
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4597,6 +4597,7 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
45974597
uint32_t num_blocks, bool allocate_gpu_memory,
45984598
uint64_t &size_tracker) {
45994599
gpu_memory_allocated = allocate_gpu_memory;
4600+
active_gpu_count = get_active_gpu_count(2 * num_blocks, gpu_count);
46004601
this->params = params;
46014602

46024603
uint32_t compute_overflow = 1;
@@ -4666,6 +4667,8 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
46664667

46674668
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
46684669
uint32_t gpu_count) {
4670+
4671+
printf("2_2 release\n");
46694672
// release and delete integer ops memory objects
46704673
overflow_sub_mem_1->release(streams, gpu_indexes, gpu_count);
46714674
overflow_sub_mem_2->release(streams, gpu_indexes, gpu_count);
@@ -4739,12 +4742,12 @@ template <typename Torus> struct unsigned_int_div_rem_2_2_memory {
47394742
gpu_memory_allocated);
47404743
release_radix_ciphertext_async(streams[0], gpu_indexes[0], sub_3_overflowed,
47414744
gpu_memory_allocated);
4742-
release_radix_ciphertext_async(streams[0], gpu_indexes[0], comparison_blocks_1,
4743-
gpu_memory_allocated);
4744-
release_radix_ciphertext_async(streams[0], gpu_indexes[0], comparison_blocks_2,
4745-
gpu_memory_allocated);
4746-
release_radix_ciphertext_async(streams[0], gpu_indexes[0], comparison_blocks_3,
4747-
gpu_memory_allocated);
4745+
release_radix_ciphertext_async(streams[0], gpu_indexes[0],
4746+
comparison_blocks_1, gpu_memory_allocated);
4747+
release_radix_ciphertext_async(streams[0], gpu_indexes[0],
4748+
comparison_blocks_2, gpu_memory_allocated);
4749+
release_radix_ciphertext_async(streams[0], gpu_indexes[0],
4750+
comparison_blocks_3, gpu_memory_allocated);
47484751
release_radix_ciphertext_async(streams[0], gpu_indexes[0], cmp_1,
47494752
gpu_memory_allocated);
47504753
release_radix_ciphertext_async(streams[0], gpu_indexes[0], cmp_2,
@@ -5097,14 +5100,17 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
50975100
int_radix_params params, uint32_t num_blocks,
50985101
bool allocate_gpu_memory,
50995102
uint64_t &size_tracker) {
5103+
gpu_memory_allocated = allocate_gpu_memory;
5104+
active_gpu_count = get_active_gpu_count(2 * num_blocks, gpu_count);
5105+
this->params = params;
5106+
51005107
if (params.message_modulus == 4 && params.carry_modulus == 4) {
51015108
div_rem_2_2_mem = new unsigned_int_div_rem_2_2_memory<Torus>(
51025109
streams, gpu_indexes, gpu_count, params, num_blocks,
51035110
allocate_gpu_memory, size_tracker);
5111+
return;
51045112
}
5105-
gpu_memory_allocated = allocate_gpu_memory;
5106-
active_gpu_count = get_active_gpu_count(2 * num_blocks, gpu_count);
5107-
this->params = params;
5113+
51085114
shift_mem_1 = new int_logical_scalar_shift_buffer<Torus>(
51095115
streams, gpu_indexes, gpu_count, SHIFT_OR_ROTATE_TYPE::LEFT_SHIFT,
51105116
params, 2 * num_blocks, allocate_gpu_memory, size_tracker);
@@ -5247,6 +5253,12 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
52475253

52485254
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
52495255
uint32_t gpu_count) {
5256+
5257+
if (params.message_modulus == 4 && params.carry_modulus == 4) {
5258+
div_rem_2_2_mem->release(streams, gpu_indexes, gpu_count);
5259+
delete div_rem_2_2_mem;
5260+
return;
5261+
}
52505262
uint32_t num_bits_in_message = 31 - __builtin_clz(params.message_modulus);
52515263

52525264
// release and delete other operation memory objects
@@ -5260,10 +5272,6 @@ template <typename Torus> struct unsigned_int_div_rem_memory {
52605272
delete overflow_sub_mem;
52615273
delete comparison_buffer;
52625274

5263-
if (params.message_modulus == 4 && params.carry_modulus == 4) {
5264-
div_rem_2_2_mem->release(streams, gpu_indexes, gpu_count);
5265-
delete div_rem_2_2_mem;
5266-
}
52675275
// drop temporary buffers
52685276
release_radix_ciphertext_async(streams[0], gpu_indexes[0], remainder1,
52695277
gpu_memory_allocated);

backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,28 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
5656
auto extend_2xd_f = [&](cudaStream_t const *streams,
5757
uint32_t const *gpu_indexes, uint32_t gpu_count) {
5858
// d2 is allocated with num_blocks + 1; so we extend with 1.
59-
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d2, divisor, streams,
60-
gpu_indexes);
59+
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d2, divisor,
60+
streams, gpu_indexes);
6161
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
62-
streams, gpu_indexes, gpu_count, mem_ptr->d2, 1, mem_ptr->shift_mem, bsks,
63-
ksks, ms_noise_reduction_key, mem_ptr->d2->num_radix_blocks);
62+
streams, gpu_indexes, gpu_count, mem_ptr->d2, 1, mem_ptr->shift_mem,
63+
bsks, ksks, ms_noise_reduction_key, mem_ptr->d2->num_radix_blocks);
6464
};
6565

6666
// Computes 3*d = 4*d - d using block shift and subtraction
6767
auto extend_3xd_f = [&](cudaStream_t const *streams,
6868
uint32_t const *gpu_indexes, uint32_t gpu_count) {
6969
// d1 is allocated with num_blocks + 1; so we extend with 1.
70-
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d1, divisor, streams,
71-
gpu_indexes);
72-
host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count, mem_ptr->d3,
73-
mem_ptr->d1, 1, mem_ptr->d1->num_radix_blocks);
74-
set_zero_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0], mem_ptr->d3,
75-
0, 1);
70+
host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d1, divisor,
71+
streams, gpu_indexes);
72+
host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count,
73+
mem_ptr->d3, mem_ptr->d1, 1,
74+
mem_ptr->d1->num_radix_blocks);
75+
set_zero_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
76+
mem_ptr->d3, 0, 1);
7677
host_sub_and_propagate_single_carry(
77-
streams, gpu_indexes, gpu_count, mem_ptr->d3, mem_ptr->d1, nullptr, nullptr,
78-
mem_ptr->sub_and_propagate_mem, bsks, ksks, ms_noise_reduction_key,
79-
outputFlag::FLAG_NONE, 0);
78+
streams, gpu_indexes, gpu_count, mem_ptr->d3, mem_ptr->d1, nullptr,
79+
nullptr, mem_ptr->sub_and_propagate_mem, bsks, ksks,
80+
ms_noise_reduction_key, outputFlag::FLAG_NONE, 0);
8081
// trim d1 by one msb block
8182
mem_ptr->d1->num_radix_blocks -= 1;
8283
};
@@ -100,15 +101,18 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
100101
mem_ptr->low2->num_radix_blocks = slice_len;
101102
mem_ptr->low3->num_radix_blocks = slice_len;
102103
mem_ptr->rem->num_radix_blocks = slice_len;
103-
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0], mem_ptr->low1,
104-
0, slice_len, mem_ptr->d1, 0, slice_len);
105-
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0], mem_ptr->low2,
106-
0, slice_len, mem_ptr->d2, 0, slice_len);
107-
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0], mem_ptr->low3,
108-
0, slice_len, mem_ptr->d3, 0, slice_len);
109-
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, 0,
110-
slice_len, remainder, block_index,
111-
num_blocks);
104+
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
105+
mem_ptr->low1, 0, slice_len,
106+
mem_ptr->d1, 0, slice_len);
107+
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
108+
mem_ptr->low2, 0, slice_len,
109+
mem_ptr->d2, 0, slice_len);
110+
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
111+
mem_ptr->low3, 0, slice_len,
112+
mem_ptr->d3, 0, slice_len);
113+
copy_radix_ciphertext_slice_async<Torus>(
114+
streams[0], gpu_indexes[0], mem_ptr->rem, 0, slice_len, remainder,
115+
block_index, num_blocks);
112116
uint32_t compute_borrow = 1;
113117
uint32_t uses_input_borrow = 0;
114118
auto sub_result_f = [&](cudaStream_t const *streams,
@@ -119,9 +123,10 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
119123
CudaRadixCiphertextFFI *low) {
120124
sub_result->num_radix_blocks = low->num_radix_blocks;
121125
host_integer_overflowing_sub<uint64_t>(
122-
streams, gpu_indexes, gpu_count, sub_result, mem_ptr->rem, low, sub_overflowed,
123-
(const CudaRadixCiphertextFFI *)nullptr, overflow_sub_mem, bsks, ksks,
124-
ms_noise_reduction_key, compute_borrow, uses_input_borrow);
126+
streams, gpu_indexes, gpu_count, sub_result, mem_ptr->rem, low,
127+
sub_overflowed, (const CudaRadixCiphertextFFI *)nullptr,
128+
overflow_sub_mem, bsks, ksks, ms_noise_reduction_key, compute_borrow,
129+
uses_input_borrow);
125130
};
126131

127132
auto cmp_f = [&](cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -153,8 +158,6 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
153158
streams[0], gpu_indexes[0], (Torus *)out_boolean_block->ptr,
154159
(Torus *)out_boolean_block->ptr, encoded_scalar,
155160
radix_params.big_lwe_dimension, 1);
156-
release_radix_ciphertext_async(streams[0], gpu_indexes[0], d_msb,
157-
true);
158161
delete d_msb;
159162
};
160163

@@ -300,7 +303,8 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
300303
conditional_update(mem_ptr->sub_streams_3, gpu_indexes, gpu_count,
301304
mem_ptr->c1, r1, mem_ptr->zero_out_if_not_2_lut_2, 3);
302305
conditional_update(mem_ptr->sub_streams_4, gpu_indexes, gpu_count,
303-
mem_ptr->c0, mem_ptr->rem, mem_ptr->zero_out_if_not_1_lut_2, 2);
306+
mem_ptr->c0, mem_ptr->rem,
307+
mem_ptr->zero_out_if_not_1_lut_2, 2);
304308

305309
calculate_quotient_bits(mem_ptr->sub_streams_5, gpu_indexes, 1, mem_ptr->q1,
306310
mem_ptr->c1, mem_ptr->quotient_lut_1);
@@ -319,12 +323,12 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
319323
cuda_synchronize_stream(mem_ptr->sub_streams_7[j], gpu_indexes[j]);
320324
}
321325

322-
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem, r3,
323-
mem_ptr->rem->num_radix_blocks, 4, 4);
324-
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem, r2,
325-
mem_ptr->rem->num_radix_blocks, 4, 4);
326-
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem, r1,
327-
mem_ptr->rem->num_radix_blocks, 4, 4);
326+
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem,
327+
r3, mem_ptr->rem->num_radix_blocks, 4, 4);
328+
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem,
329+
r2, mem_ptr->rem->num_radix_blocks, 4, 4);
330+
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->rem, mem_ptr->rem,
331+
r1, mem_ptr->rem->num_radix_blocks, 4, 4);
328332

329333
host_addition<Torus>(streams[0], gpu_indexes[0], mem_ptr->q1, mem_ptr->q1,
330334
mem_ptr->q2, 1, 4, 4);
@@ -335,9 +339,9 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
335339
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
336340
}
337341
integer_radix_apply_univariate_lookup_table_kb<Torus>(
338-
mem_ptr->sub_streams_1, gpu_indexes, gpu_count, mem_ptr->rem, mem_ptr->rem, bsks, ksks,
339-
ms_noise_reduction_key, mem_ptr->message_extract_lut_1,
340-
mem_ptr->rem->num_radix_blocks);
342+
mem_ptr->sub_streams_1, gpu_indexes, gpu_count, mem_ptr->rem,
343+
mem_ptr->rem, bsks, ksks, ms_noise_reduction_key,
344+
mem_ptr->message_extract_lut_1, mem_ptr->rem->num_radix_blocks);
341345
integer_radix_apply_univariate_lookup_table_kb<Torus>(
342346
mem_ptr->sub_streams_2, gpu_indexes, gpu_count, mem_ptr->q1,
343347
mem_ptr->q1, bsks, ksks, ms_noise_reduction_key,
@@ -383,6 +387,7 @@ __host__ void host_unsigned_integer_div_rem_kb(
383387
host_unsigned_integer_div_rem_kb_block_by_block_2_2<Torus>(
384388
streams, gpu_indexes, gpu_count, quotient, remainder, numerator,
385389
divisor, bsks, ksks, ms_noise_reduction_key, mem_ptr->div_rem_2_2_mem);
390+
return;
386391
}
387392
auto radix_params = mem_ptr->params;
388393
auto num_blocks = quotient->num_radix_blocks;

0 commit comments

Comments
 (0)