@@ -756,18 +756,20 @@ template <typename Torus> struct int_radix_lut {
756756 CudaStreams streams, uint64_t max_num_radix_blocks,
757757 uint64_t &size_tracker, bool allocate_gpu_memory) {
758758 // We need to create the auxiliary array only in GPU 0
759- lwe_aligned_vec.resize (active_streams.count ());
760- for (uint i = 0 ; i < active_streams.count (); i++) {
761- uint64_t size_tracker_on_array_i = 0 ;
762- auto inputs_on_gpu = std::max (
763- THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu (max_num_radix_blocks, i,
764- active_streams.count ()));
765- Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async (
766- inputs_on_gpu * (params.big_lwe_dimension + 1 ) * sizeof (Torus),
767- streams.stream (0 ), streams.gpu_index (0 ), size_tracker_on_array_i,
768- allocate_gpu_memory);
769- lwe_aligned_vec[i] = d_array;
770- size_tracker += size_tracker_on_array_i;
759+ if (active_streams.count () > 1 ) {
760+ lwe_aligned_vec.resize (active_streams.count ());
761+ for (uint i = 0 ; i < active_streams.count (); i++) {
762+ uint64_t size_tracker_on_array_i = 0 ;
763+ auto inputs_on_gpu = std::max (
764+ THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu (max_num_radix_blocks, i,
765+ active_streams.count ()));
766+ Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async (
767+ inputs_on_gpu * (params.big_lwe_dimension + 1 ) * sizeof (Torus),
768+ streams.stream (0 ), streams.gpu_index (0 ), size_tracker_on_array_i,
769+ allocate_gpu_memory);
770+ lwe_aligned_vec[i] = d_array;
771+ size_tracker += size_tracker_on_array_i;
772+ }
771773 }
772774 }
773775
@@ -1632,8 +1634,19 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
16321634 luts_message_carry = new int_radix_lut<Torus>(
16331635 streams, params, 2 , pbs_count, true , size_tracker);
16341636 allocated_luts_message_carry = true ;
1637+ uint64_t message_modulus_bits =
1638+ (uint64_t )std::log2 (params.message_modulus );
1639+ uint64_t carry_modulus_bits = (uint64_t )std::log2 (params.carry_modulus );
1640+ uint64_t total_bits_per_block =
1641+ message_modulus_bits + carry_modulus_bits;
1642+ uint64_t denominator =
1643+ (uint64_t )std::ceil ((pow (2 , total_bits_per_block) - 1 ) /
1644+ (pow (2 , message_modulus_bits) - 1 ));
1645+
1646+ uint64_t upper_bound_num_blocks =
1647+ max_total_blocks_in_vec * 2 / denominator;
16351648 luts_message_carry->allocate_lwe_vector_for_non_trivial_indexes (
1636- streams, this -> max_total_blocks_in_vec , size_tracker, true );
1649+ streams, upper_bound_num_blocks , size_tracker, true );
16371650 }
16381651 }
16391652 if (allocated_luts_message_carry) {
@@ -1731,9 +1744,17 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
17311744 this ->current_blocks = current_blocks;
17321745 this ->small_lwe_vector = small_lwe_vector;
17331746 this ->luts_message_carry = reused_lut;
1747+
1748+ uint64_t message_modulus_bits = (uint64_t )std::log2 (params.message_modulus );
1749+ uint64_t carry_modulus_bits = (uint64_t )std::log2 (params.carry_modulus );
1750+ uint64_t total_bits_per_block = message_modulus_bits + carry_modulus_bits;
1751+ uint64_t denominator =
1752+ (uint64_t )std::ceil ((pow (2 , total_bits_per_block) - 1 ) /
1753+ (pow (2 , message_modulus_bits) - 1 ));
1754+
1755+ uint64_t upper_bound_num_blocks = max_total_blocks_in_vec * 2 / denominator;
17341756 this ->luts_message_carry ->allocate_lwe_vector_for_non_trivial_indexes (
1735- streams, this ->max_total_blocks_in_vec , size_tracker,
1736- allocate_gpu_memory);
1757+ streams, upper_bound_num_blocks, size_tracker, allocate_gpu_memory);
17371758 setup_index_buffers (streams, size_tracker);
17381759 }
17391760
0 commit comments