Skip to content

Commit 022cb3b

Browse files
fix(gpu): avoid out of memory when benchmarking throughput
1 parent c4feabb commit 022cb3b

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -756,18 +756,20 @@ template <typename Torus> struct int_radix_lut {
756756
CudaStreams streams, uint64_t max_num_radix_blocks,
757757
uint64_t &size_tracker, bool allocate_gpu_memory) {
758758
// We need to create the auxiliary array only in GPU 0
759-
lwe_aligned_vec.resize(active_streams.count());
760-
for (uint i = 0; i < active_streams.count(); i++) {
761-
uint64_t size_tracker_on_array_i = 0;
762-
auto inputs_on_gpu = std::max(
763-
THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu(max_num_radix_blocks, i,
764-
active_streams.count()));
765-
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
766-
inputs_on_gpu * (params.big_lwe_dimension + 1) * sizeof(Torus),
767-
streams.stream(0), streams.gpu_index(0), size_tracker_on_array_i,
768-
allocate_gpu_memory);
769-
lwe_aligned_vec[i] = d_array;
770-
size_tracker += size_tracker_on_array_i;
759+
if (active_streams.count() > 1) {
760+
lwe_aligned_vec.resize(active_streams.count());
761+
for (uint i = 0; i < active_streams.count(); i++) {
762+
uint64_t size_tracker_on_array_i = 0;
763+
auto inputs_on_gpu = std::max(
764+
THRESHOLD_MULTI_GPU, get_num_inputs_on_gpu(max_num_radix_blocks, i,
765+
active_streams.count()));
766+
Torus *d_array = (Torus *)cuda_malloc_with_size_tracking_async(
767+
inputs_on_gpu * (params.big_lwe_dimension + 1) * sizeof(Torus),
768+
streams.stream(0), streams.gpu_index(0), size_tracker_on_array_i,
769+
allocate_gpu_memory);
770+
lwe_aligned_vec[i] = d_array;
771+
size_tracker += size_tracker_on_array_i;
772+
}
771773
}
772774
}
773775

@@ -1632,8 +1634,19 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
16321634
luts_message_carry = new int_radix_lut<Torus>(
16331635
streams, params, 2, pbs_count, true, size_tracker);
16341636
allocated_luts_message_carry = true;
1637+
uint64_t message_modulus_bits =
1638+
(uint64_t)std::log2(params.message_modulus);
1639+
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
1640+
uint64_t total_bits_per_block =
1641+
message_modulus_bits + carry_modulus_bits;
1642+
uint64_t denominator =
1643+
(uint64_t)std::ceil((pow(2, total_bits_per_block) - 1) /
1644+
(pow(2, message_modulus_bits) - 1));
1645+
1646+
uint64_t upper_bound_num_blocks =
1647+
max_total_blocks_in_vec * 2 / denominator;
16351648
luts_message_carry->allocate_lwe_vector_for_non_trivial_indexes(
1636-
streams, this->max_total_blocks_in_vec, size_tracker, true);
1649+
streams, upper_bound_num_blocks, size_tracker, true);
16371650
}
16381651
}
16391652
if (allocated_luts_message_carry) {
@@ -1731,9 +1744,17 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
17311744
this->current_blocks = current_blocks;
17321745
this->small_lwe_vector = small_lwe_vector;
17331746
this->luts_message_carry = reused_lut;
1747+
1748+
uint64_t message_modulus_bits = (uint64_t)std::log2(params.message_modulus);
1749+
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
1750+
uint64_t total_bits_per_block = message_modulus_bits + carry_modulus_bits;
1751+
uint64_t denominator =
1752+
(uint64_t)std::ceil((pow(2, total_bits_per_block) - 1) /
1753+
(pow(2, message_modulus_bits) - 1));
1754+
1755+
uint64_t upper_bound_num_blocks = max_total_blocks_in_vec * 2 / denominator;
17341756
this->luts_message_carry->allocate_lwe_vector_for_non_trivial_indexes(
1735-
streams, this->max_total_blocks_in_vec, size_tracker,
1736-
allocate_gpu_memory);
1757+
streams, upper_bound_num_blocks, size_tracker, allocate_gpu_memory);
17371758
setup_index_buffers(streams, size_tracker);
17381759
}
17391760

tfhe-benchmark/src/utilities.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -421,23 +421,32 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
421421
let block_multiplicator = (ref_block_count as f64 / num_block as f64).ceil().min(1.0);
422422
// Some operations with a high serial workload (e.g. division) would yield an operation
423423
// loading value so low that the number of elements in the end wouldn't be meaningful.
424-
let minimum_loading = if num_block < 64 { 0.2 } else { 0.01 };
424+
let minimum_loading = if num_block < 64 { 1.0 } else { 0.015 };
425425

426426
#[cfg(feature = "gpu")]
427427
{
428428
let num_sms_per_gpu = get_number_of_sms();
429429
let total_num_sm = num_sms_per_gpu * get_number_of_gpus();
430430

431-
let total_blocks_per_sm = 4u32; // Assume each SM can handle 4 blocks concurrently
432-
let total_num_sm = total_blocks_per_sm * total_num_sm;
431+
let total_blocks_per_sm = 4u64; // Assume each SM can handle 4 blocks concurrently
433432
let min_num_waves = 4u64; //Enforce at least 4 waves in the GPU
434-
let elements_per_wave = total_num_sm as u64 / (num_block as u64);
435-
433+
let block_factor = ((2.0f64 * num_block as f64) / 4.0f64).ceil() as u64;
434+
let elements_per_wave = total_blocks_per_sm * total_num_sm as u64 / block_factor;
435+
// We need to enable the new load for PBS benches, and for sizes larger than 16 blocks in
436+
// demanding operations; for the rest of the operations we maintain a minimum of 200
437+
// elements
438+
let min_elements = if op_pbs_count == 1
439+
|| (op_pbs_count > (num_block * num_block) as u64 && num_block >= 16)
440+
{
441+
elements_per_wave * min_num_waves
442+
} else {
443+
200u64
444+
};
436445
let operation_loading = ((total_num_sm as u64 / op_pbs_count) as f64).max(minimum_loading);
437446
let elements = (total_num_sm as f64 * block_multiplicator * operation_loading) as u64;
438-
elements.min(elements_per_wave * min_num_waves) // This threshold is useful for operation
439-
// with both a small number of
440-
// block and low PBs count.
447+
elements.min(min_elements) // This threshold is useful for operations
448+
// with both a small number of
449+
// blocks and a low PBS count.
441450
}
442451
#[cfg(feature = "hpu")]
443452
{

0 commit comments

Comments
 (0)