Skip to content

Commit 6488677

Browse files
authored
Merge branch 'main' into syr/attn_tp_config
2 parents 56d66e3 + e0253ee commit 6488677

File tree

84 files changed

+2356
-1045
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

84 files changed

+2356
-1045
lines changed

.gitattributes

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
99
docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text
1010
tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text
1111
tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text
12+
docs/source/blogs/media/tech_blog10_baseline_performance_detail.png filter=lfs diff=lfs merge=lfs -text
13+
docs/source/blogs/media/tech_blog10_full_strategy_performance.png filter=lfs diff=lfs merge=lfs -text
14+
docs/source/blogs/media/tech_blog10_context_wait_performance.png filter=lfs diff=lfs merge=lfs -text

.github/workflows/pr-check.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,5 @@ jobs:
6969
- name: Validate PR Checklist
7070
env:
7171
PR_BODY: ${{ github.event.pull_request.body }}
72-
ENFORCE_PR_HAS_CHECKLIST: true
72+
ENFORCE_PR_HAS_CHECKLIST: false
7373
run: python .github/scripts/pr_checklist_check.py

cpp/include/tensorrt_llm/deep_gemm/scheduler.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ struct GroupedMaskedScheduler
379379
}
380380
};
381381

382-
// Need to keep the same as the one in tests/unittest/_torch/thop/deep_gemm_tests.py
382+
// Need to keep the same as the one in tests/unittest/_torch/thop/parallel/deep_gemm_tests.py
383383
template <typename T_offset, typename T_index>
384384
__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx)
385385
{

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,22 @@ struct ScaledAccPerRowBias
373373
static constexpr bool IsPerRowBiasSupported = true;
374374
};
375375

376+
template<
377+
class ElementOutput_,
378+
class ElementCompute_,
379+
class ElementBias_ = ElementOutput_,
380+
class ElementScalar_ = ElementCompute_,
381+
int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
382+
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
383+
>
384+
struct ScaledAccPerColBias
385+
: ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_>
386+
{
387+
using ElementBias = ElementBias_;
388+
static constexpr int AlignmentBias = AlignmentBias_;
389+
static constexpr bool IsPerColBiasSupported = true;
390+
};
391+
376392
template<
377393
class GmemLayoutTagOut,
378394
class ElementOutput,
@@ -393,6 +409,26 @@ struct ScaledAccPerRowBiasPerColScaleScatter
393409
static constexpr bool IsAuxOutSupported = true;
394410
};
395411

412+
template<
413+
class GmemLayoutTagOut,
414+
class ElementOutput,
415+
class ElementCompute,
416+
class ElementBias = ElementOutput,
417+
class ElementScale = ElementCompute,
418+
class ElementScalar = ElementCompute,
419+
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
420+
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
421+
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
422+
>
423+
struct ScaledAccPerColBiasPerRowScaleScatter
424+
: ScaledAccPerColBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
425+
{
426+
using ElementAux = ElementOutput;
427+
using GmemLayoutTagAux = GmemLayoutTagOut;
428+
static constexpr int AlignmentAux = AlignmentOutput;
429+
static constexpr bool IsAuxOutSupported = true;
430+
};
431+
396432
// D = alpha * acc + per-row bias
397433
template<
398434
class CtaTileShapeMNK,
@@ -410,6 +446,22 @@ using Sm90ScaledAccPerRowBiasPtrArray =
410446
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias
411447
>;
412448

449+
template<
450+
class CtaTileShapeMNK,
451+
class ElementOutput,
452+
class ElementCompute,
453+
class ElementBias = ElementOutput,
454+
class ElementScalar = ElementCompute,
455+
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
456+
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
457+
>
458+
using Sm90ScaledAccPerColBiasPtrArray =
459+
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias
460+
Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t>>, // alpha
461+
Sm90AccFetch, // acc
462+
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias
463+
>;
464+
413465
template<
414466
class CtaTileShapeMNK,
415467
class EpilogueTile,
@@ -433,6 +485,29 @@ using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray =
433485
>
434486
>;
435487

488+
template<
489+
class CtaTileShapeMNK,
490+
class EpilogueTile,
491+
class StrideOutput,
492+
class SmemLayoutAtom,
493+
class CopyOpR2S,
494+
class ElementOutput,
495+
class ElementCompute,
496+
class ElementBias = ElementOutput,
497+
class ElementScale = ElementCompute,
498+
class ElementScalar = ElementCompute,
499+
int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
500+
int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
501+
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
502+
>
503+
using Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray =
504+
Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store
505+
Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias)
506+
Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_1,_0,int64_t>, 1>, // scale
507+
Sm90ScaledAccPerColBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias
508+
>
509+
>;
510+
436511
template <
437512
int StagesC,
438513
int StagesD,
@@ -556,6 +631,129 @@ struct FusionCallbacks<
556631

557632
};
558633

634+
template <
635+
int StagesC,
636+
int StagesD,
637+
int FragmentSize,
638+
bool ReuseSmemC,
639+
bool DelayTmaStore,
640+
int NumEpilogueWarpGroups,
641+
class GmemLayoutTagOut,
642+
class ElementOutput,
643+
class ElementCompute,
644+
class ElementBias,
645+
class ElementScale,
646+
class ElementScalar,
647+
int AlignmentBias,
648+
int AlignmentOutput,
649+
FloatRoundStyle RoundStyle,
650+
class CtaTileShapeMNK,
651+
class EpilogueTile,
652+
class SmemLayoutAtom,
653+
class CopyOpR2S
654+
>
655+
struct FusionCallbacks<
656+
epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
657+
StagesD,
658+
FragmentSize,
659+
ReuseSmemC,
660+
DelayTmaStore,
661+
NumEpilogueWarpGroups
662+
>,
663+
fusion::ScaledAccPerColBiasPerRowScaleScatter<GmemLayoutTagOut,
664+
ElementOutput,
665+
ElementCompute,
666+
ElementBias,
667+
ElementScale,
668+
ElementScalar,
669+
AlignmentBias,
670+
AlignmentOutput,
671+
RoundStyle>,
672+
CtaTileShapeMNK,
673+
EpilogueTile,
674+
SmemLayoutAtom,
675+
CopyOpR2S
676+
> : Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray<
677+
CtaTileShapeMNK,
678+
EpilogueTile,
679+
cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>,
680+
SmemLayoutAtom, CopyOpR2S,
681+
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
682+
AlignmentBias, AlignmentOutput, RoundStyle
683+
> {
684+
685+
using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>;
686+
687+
using Impl = Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray<
688+
CtaTileShapeMNK,
689+
EpilogueTile,
690+
StrideOutput,
691+
SmemLayoutAtom, CopyOpR2S,
692+
ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
693+
AlignmentBias, AlignmentOutput, RoundStyle
694+
>;
695+
using Operation = fusion::ScaledAccPerColBiasPerRowScaleScatter<
696+
GmemLayoutTagOut,
697+
ElementOutput,
698+
ElementCompute,
699+
ElementBias,
700+
ElementScale,
701+
ElementScalar,
702+
AlignmentBias,
703+
AlignmentOutput,
704+
RoundStyle>;
705+
706+
struct Arguments {
707+
708+
using StrideAlpha = Stride<_0,_0,int64_t>;
709+
ElementScalar alpha = ElementScalar(1);
710+
ElementScalar const* alpha_ptr{};
711+
ElementScalar const* const* alpha_ptr_array{};
712+
StrideAlpha dAlpha{};
713+
714+
using StrideBias = Stride<_0,_1,int64_t>;
715+
ElementBias const* const* bias_ptr{};
716+
StrideBias dBias{};
717+
718+
using StrideScale = Stride<_1,_0,int64_t>;
719+
ElementScalar const* const* scale_ptr_array{};
720+
StrideScale dScale{};
721+
722+
// Nested args not usable due to a compiler bug with constexpr evaluation
723+
// using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
724+
// ScatterArguments scatter{};
725+
726+
ElementOutput* ptr_out = nullptr;
727+
StrideOutput dOut = {};
728+
int const* const* ptr_index{}; // per-group pointer to the scatter index
729+
int index_modulo{}; // modulo used to transform the index before store
730+
int shape_override = -1; // override value for contiguous output tensor mode
731+
bool use_reduction = true;
732+
733+
operator typename Impl::Arguments() const {
734+
return
735+
{ // unary op: reduce(scale * (beta * C + (alpha * acc)))
736+
{ // binary op: scale * (beta * C + (alpha * acc))
737+
{ scale_ptr_array, ElementScalar(1), dScale }, // leaf args : scale broadcast
738+
{ // ternary op : alpha * acc + bias
739+
{{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
740+
{}, // leaf args : acc
741+
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
742+
{} // ternary args : multiply_add
743+
}, // end binary op
744+
{} // binary args: multiply
745+
}, // end binary op
746+
//scatter // unary args: reduce
747+
{ ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction }
748+
}; // end unary op
749+
}
750+
};
751+
752+
// Ctor inheritance
753+
using Impl::Impl;
754+
755+
};
756+
559757
} // namespace cutlass::epilogue::fusion
560758

561759
// clang-format on

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -440,6 +440,7 @@ struct CutlassGemmConfig
440440
};
441441

442442
EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE;
443+
bool swap_ab = false;
443444

444445
CutlassGemmConfig() = default;
445446

@@ -511,7 +512,8 @@ struct CutlassGemmConfig
511512
<< "\n\tcluster shape ID: " << (int) cluster_shape
512513
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
513514
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false")
514-
<< "\n\tepilogue fusion type: " << (int) epilogue_fusion_type;
515+
<< "\n\tepilogue fusion type: " << (int) epilogue_fusion_type
516+
<< "\n\tswap_ab: " << (swap_ab ? "true" : "false");
515517
}
516518
else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
517519
{
@@ -544,7 +546,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
544546
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
545547
<< ", cluster_shape_enum: " << int(config.cluster_shape)
546548
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false")
547-
<< ", epilogue_fusion_type: " << int(config.epilogue_fusion_type);
549+
<< ", epilogue_fusion_type: " << int(config.epilogue_fusion_type)
550+
<< ", swap_ab: " << (config.swap_ab ? "true" : "false");
548551
}
549552
else
550553
{

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,24 @@ struct TmaWarpSpecializedGroupedGemmInput
7474
static_assert(std::is_same_v<cutlass::layout::RowMajor, TransposeLayoutTag<cutlass::layout::ColumnMajor>>);
7575
static_assert(std::is_same_v<cutlass::layout::ColumnMajor, TransposeLayoutTag<cutlass::layout::RowMajor>>);
7676

77-
// Layout for A and B is transposed and then swapped in the implementation
78-
// This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM
79-
using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for A matrix operand
80-
using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
81-
using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for C matrix operand
82-
using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
77+
// These are always the layout of A & B matrices, activations and weights will be assigned to either A or B based on
78+
// swap_ab
79+
using LayoutA = cutlass::layout::RowMajor;
80+
using LayoutB = cutlass::layout::ColumnMajor;
81+
82+
// When using Swap A&B we need to transpose the output matrix
83+
using LayoutC = cutlass::layout::RowMajor;
84+
using LayoutD = cutlass::layout::RowMajor;
85+
using LayoutC_T = TransposeLayoutTag<LayoutC>;
86+
using LayoutD_T = TransposeLayoutTag<LayoutD>;
87+
88+
using StrideA = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
89+
using StrideB = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
90+
91+
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
92+
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
93+
using StrideC_T = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC_T*>>;
94+
using StrideD_T = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD_T*>>;
8395

8496
constexpr static int NVFP4BlockScaleVectorSize = 16;
8597
constexpr static int MXFPXBlockScaleVectorSize = 32;
@@ -110,13 +122,6 @@ struct TmaWarpSpecializedGroupedGemmInput
110122
return (dim + alignment - 1) / alignment * alignment;
111123
}
112124

113-
using StrideA
114-
= std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
115-
using StrideB
116-
= std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
117-
using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
118-
using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
119-
120125
#ifdef ENABLE_FP8
121126
template <class T>
122127
constexpr static bool IsFP8_v = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
@@ -131,26 +136,29 @@ struct TmaWarpSpecializedGroupedGemmInput
131136

132137
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int64_t, int64_t, int64_t>>;
133138

139+
bool swap_ab = false;
134140
ProblemShape shape_info{};
135-
StrideA* stride_a = nullptr;
136-
StrideB* stride_b = nullptr;
141+
void* stride_act = nullptr;
142+
void* stride_weight = nullptr;
137143

138-
void const** ptr_a = nullptr;
139-
void const** ptr_b = nullptr;
144+
void const** ptr_act = nullptr;
145+
void const** ptr_weight = nullptr;
140146

141147
// C is currently the same in both epilogues
142-
StrideC* stride_c = nullptr;
148+
void* stride_c = nullptr;
143149
void const** ptr_c = nullptr;
144150

145151
// D is used in all cases except fused finalize
146-
StrideD* stride_d = nullptr;
152+
void* stride_d = nullptr;
147153
void** ptr_d = nullptr;
148154

149155
struct FusedFinalizeEpilogue
150156
{
157+
using StrideFinalOutput_T = cutlass::detail::TagToStrideC_t<LayoutD_T>;
151158
using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;
152159

153160
void* ptr_final_output = nullptr;
161+
StrideFinalOutput_T stride_final_output_transposed{};
154162
StrideFinalOutput stride_final_output{};
155163

156164
void const** ptr_bias = nullptr;
@@ -179,11 +187,11 @@ struct TmaWarpSpecializedGroupedGemmInput
179187
using ElementSF = uint8_t;
180188
using MXFPXElementSF = ElementSF; // Just an alias for now
181189
using NVFP4ElementSF = ElementSF; // Just an alias for now
182-
ElementSF const** fpX_block_scaling_factors_A = nullptr;
183-
ElementSF const** fpX_block_scaling_factors_B = nullptr;
190+
ElementSF const** fpX_block_scaling_factors_act = nullptr;
191+
ElementSF const** fpX_block_scaling_factors_weight = nullptr;
184192

185-
void* fpX_block_scaling_factors_stride_A = nullptr;
186-
void* fpX_block_scaling_factors_stride_B = nullptr;
193+
void* fpX_block_scaling_factors_stride_act = nullptr;
194+
void* fpX_block_scaling_factors_stride_weight = nullptr;
187195

188196
enum class FpXBlockScalingType
189197
{
@@ -229,7 +237,7 @@ struct TmaWarpSpecializedGroupedGemmInput
229237

230238
bool isValid() const
231239
{
232-
return stride_a != nullptr && ptr_a != nullptr;
240+
return stride_act != nullptr && ptr_act != nullptr;
233241
}
234242

235243
void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);

0 commit comments

Comments
 (0)