Skip to content

Commit 978011b

Browse files
fix(gpu): avoid out of memory when benchmarking throughput
1 parent 0ed97cf commit 978011b

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,8 +1681,19 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
16811681
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count, params, 2,
16821682
pbs_count, true, size_tracker);
16831683
allocated_luts_message_carry = true;
1684+
uint64_t message_modulus_bits =
1685+
(uint64_t)std::log2(params.message_modulus);
1686+
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
1687+
uint64_t total_bits_per_block =
1688+
message_modulus_bits + carry_modulus_bits;
1689+
uint64_t denominator =
1690+
(uint64_t)std::ceil((pow(2, total_bits_per_block) - 1) /
1691+
(pow(2, message_modulus_bits) - 1));
1692+
1693+
uint64_t upper_bound_num_blocks =
1694+
num_blocks_in_radix * num_blocks_in_radix * 2 / denominator;
16841695
luts_message_carry->allocate_lwe_vector_for_non_trivial_indexes(
1685-
streams, gpu_indexes, gpu_count, this->max_total_blocks_in_vec,
1696+
streams, gpu_indexes, gpu_count, upper_bound_num_blocks,
16861697
size_tracker, true);
16871698
}
16881699
}
@@ -1781,9 +1792,19 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
17811792
this->current_blocks = current_blocks;
17821793
this->small_lwe_vector = small_lwe_vector;
17831794
this->luts_message_carry = reused_lut;
1795+
1796+
uint64_t message_modulus_bits = (uint64_t)std::log2(params.message_modulus);
1797+
uint64_t carry_modulus_bits = (uint64_t)std::log2(params.carry_modulus);
1798+
uint64_t total_bits_per_block = message_modulus_bits + carry_modulus_bits;
1799+
uint64_t denominator =
1800+
(uint64_t)std::ceil((pow(2, total_bits_per_block) - 1) /
1801+
(pow(2, message_modulus_bits) - 1));
1802+
1803+
uint64_t upper_bound_num_blocks =
1804+
num_blocks_in_radix * num_blocks_in_radix * 2 / denominator;
17841805
this->luts_message_carry->allocate_lwe_vector_for_non_trivial_indexes(
1785-
streams, gpu_indexes, gpu_count, this->max_total_blocks_in_vec,
1786-
size_tracker, allocate_gpu_memory);
1806+
streams, gpu_indexes, gpu_count, upper_bound_num_blocks, size_tracker,
1807+
allocate_gpu_memory);
17871808
setup_index_buffers(streams, gpu_indexes, size_tracker);
17881809
}
17891810

tfhe-benchmark/src/utilities.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,12 +437,18 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 {
437437
let total_num_sm = total_blocks_per_sm * total_num_sm;
438438
let min_num_waves = 4u64; //Enforce at least 4 waves in the GPU
439439
let elements_per_wave = total_num_sm as u64 / (num_block as u64);
440-
440+
// Operations whose PBS count exceeds the number of blocks squared are capped
441+
// at 200 elements.
442+
let min_elements = if op_pbs_count > num_block as u64 * num_block as u64 {
443+
200u64
444+
} else {
445+
elements_per_wave * min_num_waves
446+
};
441447
let operation_loading = ((total_num_sm as u64 / op_pbs_count) as f64).max(minimum_loading);
442448
let elements = (total_num_sm as f64 * block_multiplicator * operation_loading) as u64;
443-
elements.min(elements_per_wave * min_num_waves) // This threshold is useful for operation
444-
// with both a small number of
445-
// block and low PBs count.
449+
elements.min(min_elements) // This threshold is useful for operations
450+
// with both a small number of
451+
// blocks and a low PBS count.
446452
}
447453
#[cfg(feature = "hpu")]
448454
{

0 commit comments

Comments
 (0)