Skip to content

Commit 93a0fd0

Browse files
authored
[TRTLLM-6445] feat: Enable AllReduce-associated fusion patterns in Llama3/4. (#6205)
Signed-off-by: Yukun He <[email protected]>
1 parent 2dd3186 commit 93a0fd0

File tree

2 files changed

+203
-32
lines changed

2 files changed

+203
-32
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ __global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(
520520
}
521521

522522
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
523-
__global__ void allreduce_fusion_kernel_twoshot_sync(
523+
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_twoshot_sync(
524524
AllReduceFusionParams params, std::array<int, NRanks> begin_tokens, std::array<int, NRanks> token_num_per_ranks)
525525
{
526526
IndexHelper<DType> index_helper(params);

0 commit comments

Comments
 (0)