@@ -56,27 +56,28 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
5656 auto extend_2xd_f = [&](cudaStream_t const *streams,
5757 uint32_t const *gpu_indexes, uint32_t gpu_count) {
5858 // d2 is allocated with num_blocks + 1; so we extend with 1.
59- host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d2 , divisor, streams,
60- gpu_indexes);
59+ host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d2 , divisor,
60+ streams, gpu_indexes);
6161 host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
62- streams, gpu_indexes, gpu_count, mem_ptr->d2 , 1 , mem_ptr->shift_mem , bsks,
63- ksks, ms_noise_reduction_key, mem_ptr->d2 ->num_radix_blocks );
62+ streams, gpu_indexes, gpu_count, mem_ptr->d2 , 1 , mem_ptr->shift_mem ,
63+ bsks, ksks, ms_noise_reduction_key, mem_ptr->d2 ->num_radix_blocks );
6464 };
6565
6666 // Computes 3*d = 4*d - d using block shift and subtraction
6767 auto extend_3xd_f = [&](cudaStream_t const *streams,
6868 uint32_t const *gpu_indexes, uint32_t gpu_count) {
6969 // d1 is allocated with num_blocks + 1; so we extend with 1.
70- host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d1 , divisor, streams,
71- gpu_indexes);
72- host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count, mem_ptr->d3 ,
73- mem_ptr->d1 , 1 , mem_ptr->d1 ->num_radix_blocks );
74- set_zero_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->d3 ,
75- 0 , 1 );
70+ host_extend_radix_with_trivial_zero_blocks_msb<Torus>(mem_ptr->d1 , divisor,
71+ streams, gpu_indexes);
72+ host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count,
73+ mem_ptr->d3 , mem_ptr->d1 , 1 ,
74+ mem_ptr->d1 ->num_radix_blocks );
75+ set_zero_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ],
76+ mem_ptr->d3 , 0 , 1 );
7677 host_sub_and_propagate_single_carry (
77- streams, gpu_indexes, gpu_count, mem_ptr->d3 , mem_ptr->d1 , nullptr , nullptr ,
78- mem_ptr->sub_and_propagate_mem , bsks, ksks, ms_noise_reduction_key ,
79- outputFlag::FLAG_NONE, 0 );
78+ streams, gpu_indexes, gpu_count, mem_ptr->d3 , mem_ptr->d1 , nullptr ,
79+ nullptr , mem_ptr->sub_and_propagate_mem , bsks, ksks,
80+ ms_noise_reduction_key, outputFlag::FLAG_NONE, 0 );
8081 // trim d1 by one msb block
8182 mem_ptr->d1 ->num_radix_blocks -= 1 ;
8283 };
@@ -100,15 +101,18 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
100101 mem_ptr->low2 ->num_radix_blocks = slice_len;
101102 mem_ptr->low3 ->num_radix_blocks = slice_len;
102103 mem_ptr->rem ->num_radix_blocks = slice_len;
103- copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->low1 ,
104- 0 , slice_len, mem_ptr->d1 , 0 , slice_len);
105- copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->low2 ,
106- 0 , slice_len, mem_ptr->d2 , 0 , slice_len);
107- copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->low3 ,
108- 0 , slice_len, mem_ptr->d3 , 0 , slice_len);
109- copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , 0 ,
110- slice_len, remainder, block_index,
111- num_blocks);
104+ copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ],
105+ mem_ptr->low1 , 0 , slice_len,
106+ mem_ptr->d1 , 0 , slice_len);
107+ copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ],
108+ mem_ptr->low2 , 0 , slice_len,
109+ mem_ptr->d2 , 0 , slice_len);
110+ copy_radix_ciphertext_slice_async<Torus>(streams[0 ], gpu_indexes[0 ],
111+ mem_ptr->low3 , 0 , slice_len,
112+ mem_ptr->d3 , 0 , slice_len);
113+ copy_radix_ciphertext_slice_async<Torus>(
114+ streams[0 ], gpu_indexes[0 ], mem_ptr->rem , 0 , slice_len, remainder,
115+ block_index, num_blocks);
112116 uint32_t compute_borrow = 1 ;
113117 uint32_t uses_input_borrow = 0 ;
114118 auto sub_result_f = [&](cudaStream_t const *streams,
@@ -119,9 +123,10 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
119123 CudaRadixCiphertextFFI *low) {
120124 sub_result->num_radix_blocks = low->num_radix_blocks ;
121125 host_integer_overflowing_sub<uint64_t >(
122- streams, gpu_indexes, gpu_count, sub_result, mem_ptr->rem , low, sub_overflowed,
123- (const CudaRadixCiphertextFFI *)nullptr , overflow_sub_mem, bsks, ksks,
124- ms_noise_reduction_key, compute_borrow, uses_input_borrow);
126+ streams, gpu_indexes, gpu_count, sub_result, mem_ptr->rem , low,
127+ sub_overflowed, (const CudaRadixCiphertextFFI *)nullptr ,
128+ overflow_sub_mem, bsks, ksks, ms_noise_reduction_key, compute_borrow,
129+ uses_input_borrow);
125130 };
126131
127132 auto cmp_f = [&](cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -153,8 +158,6 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
153158 streams[0 ], gpu_indexes[0 ], (Torus *)out_boolean_block->ptr ,
154159 (Torus *)out_boolean_block->ptr , encoded_scalar,
155160 radix_params.big_lwe_dimension , 1 );
156- release_radix_ciphertext_async (streams[0 ], gpu_indexes[0 ], d_msb,
157- true );
158161 delete d_msb;
159162 };
160163
@@ -300,7 +303,8 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
300303 conditional_update (mem_ptr->sub_streams_3 , gpu_indexes, gpu_count,
301304 mem_ptr->c1 , r1, mem_ptr->zero_out_if_not_2_lut_2 , 3 );
302305 conditional_update (mem_ptr->sub_streams_4 , gpu_indexes, gpu_count,
303- mem_ptr->c0 , mem_ptr->rem , mem_ptr->zero_out_if_not_1_lut_2 , 2 );
306+ mem_ptr->c0 , mem_ptr->rem ,
307+ mem_ptr->zero_out_if_not_1_lut_2 , 2 );
304308
305309 calculate_quotient_bits (mem_ptr->sub_streams_5 , gpu_indexes, 1 , mem_ptr->q1 ,
306310 mem_ptr->c1 , mem_ptr->quotient_lut_1 );
@@ -319,12 +323,12 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
319323 cuda_synchronize_stream (mem_ptr->sub_streams_7 [j], gpu_indexes[j]);
320324 }
321325
322- host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem , r3,
323- mem_ptr->rem ->num_radix_blocks , 4 , 4 );
324- host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem , r2,
325- mem_ptr->rem ->num_radix_blocks , 4 , 4 );
326- host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem , r1,
327- mem_ptr->rem ->num_radix_blocks , 4 , 4 );
326+ host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem ,
327+ r3, mem_ptr->rem ->num_radix_blocks , 4 , 4 );
328+ host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem ,
329+ r2, mem_ptr->rem ->num_radix_blocks , 4 , 4 );
330+ host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->rem , mem_ptr->rem ,
331+ r1, mem_ptr->rem ->num_radix_blocks , 4 , 4 );
328332
329333 host_addition<Torus>(streams[0 ], gpu_indexes[0 ], mem_ptr->q1 , mem_ptr->q1 ,
330334 mem_ptr->q2 , 1 , 4 , 4 );
@@ -335,9 +339,9 @@ __host__ void host_unsigned_integer_div_rem_kb_block_by_block_2_2(
335339 cuda_synchronize_stream (streams[j], gpu_indexes[j]);
336340 }
337341 integer_radix_apply_univariate_lookup_table_kb<Torus>(
338- mem_ptr->sub_streams_1 , gpu_indexes, gpu_count, mem_ptr->rem , mem_ptr-> rem , bsks, ksks,
339- ms_noise_reduction_key, mem_ptr->message_extract_lut_1 ,
340- mem_ptr->rem ->num_radix_blocks );
342+ mem_ptr->sub_streams_1 , gpu_indexes, gpu_count, mem_ptr->rem ,
343+ mem_ptr->rem , bsks, ksks, ms_noise_reduction_key ,
344+ mem_ptr->message_extract_lut_1 , mem_ptr-> rem ->num_radix_blocks );
341345 integer_radix_apply_univariate_lookup_table_kb<Torus>(
342346 mem_ptr->sub_streams_2 , gpu_indexes, gpu_count, mem_ptr->q1 ,
343347 mem_ptr->q1 , bsks, ksks, ms_noise_reduction_key,
@@ -383,6 +387,7 @@ __host__ void host_unsigned_integer_div_rem_kb(
383387 host_unsigned_integer_div_rem_kb_block_by_block_2_2<Torus>(
384388 streams, gpu_indexes, gpu_count, quotient, remainder, numerator,
385389 divisor, bsks, ksks, ms_noise_reduction_key, mem_ptr->div_rem_2_2_mem );
390+ return ;
386391 }
387392 auto radix_params = mem_ptr->params ;
388393 auto num_blocks = quotient->num_radix_blocks ;
0 commit comments