Commit 4f9fa9f

feat: MoE trtllm backend kernel update (#5183)
Signed-off-by: Anthony Chang <[email protected]>
1 parent: 1d2b0d3 · commit: 4f9fa9f

88 files changed, +3650 -1459 lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 139 additions & 5 deletions
@@ -27,6 +27,138 @@ namespace kernels
 {
 
 using namespace batchedGemm::batchedGemm;
+using namespace batchedGemm::gemm;
+using namespace batchedGemm::trtllm::gen;
+
+std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
+    batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
+{
+
+    // Function to bubble up the pre-determined config.
+    auto bubbleUpConfig = [&configs](std::vector<int64_t> const& sortedIndices, auto&& pred) -> std::vector<int64_t>
+    {
+        std::vector<int64_t> prioritizedIndices_;
+        // Copy matching configs to new vector
+        std::copy_if(sortedIndices.begin(), sortedIndices.end(), std::back_inserter(prioritizedIndices_),
+            [&configs, &pred](int idx)
+            {
+                BatchedGemmConfig const& config = configs[idx];
+                return (pred(config));
+            });
+        // Copy the rest of the configs to new vector, if not already copied
+        std::copy_if(sortedIndices.begin(), sortedIndices.end(), std::back_inserter(prioritizedIndices_),
+            [&prioritizedIndices_](int idx) {
+                return std::find(prioritizedIndices_.begin(), prioritizedIndices_.end(), idx)
+                    == prioritizedIndices_.end();
+            });
+        return prioritizedIndices_;
+    };
+
+    // Init empty vector
+    std::vector<int64_t> prioritizedIndices;
+
+    //
+    // Qwen3
+    //
+
+    // Qwen3_235B_TP1_EP8_MoE_FC1 m=3072 k=4096
+    if (n /* out_dim */ == 3072 && k /* in_dim */ == 4096)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Static;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP1_EP8_MoE_FC2 m=4096 k=1536
+    else if (n /* out_dim */ == 4096 && k /* in_dim */ == 1536)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Static;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP2_EP4_MoE_FC1 m=1536 k=4096
+    else if (n /* out_dim */ == 1536 && k /* in_dim */ == 4096)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Static;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP2_EP4_MoE_FC2 m=4096 k=768
+    else if (n /* out_dim */ == 4096 && k /* in_dim */ == 768)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Persistent;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP4_EP2_MoE_FC1 m=768 k=4096
+    else if (n /* out_dim */ == 768 && k /* in_dim */ == 4096)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Static;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP4_EP2_MoE_FC2 m=4096 k=384
+    else if (n /* out_dim */ == 4096 && k /* in_dim */ == 384)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Persistent;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP8_EP1_MoE_FC1 m=384 k=4096
+    else if (n /* out_dim */ == 384 && k /* in_dim */ == 4096)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
+                && options.mTileScheduler == TileScheduler::Static;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    // Qwen3_235B_TP8_EP1_MoE_FC2 m=4096 k=192
+    else if (n /* out_dim */ == 4096 && k /* in_dim */ == 192)
+    {
+        auto pred = [](BatchedGemmConfig const& config)
+        {
+            BatchedGemmOptions const& options = config.mOptions;
+            return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256
+                && options.mTileScheduler == TileScheduler::Persistent;
+        };
+        prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
+    }
+    //
+    // Fall back
+    //
+    else
+    {
+        prioritizedIndices = sortedIndices;
+    }
+
+    return prioritizedIndices;
+}
 
 TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options_)
     : mOptions(options_)
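The `bubbleUpConfig` lambda above is a stable partition: indices whose config matches the predicate move to the front, the rest follow, and relative order is preserved within both groups. A minimal standalone sketch of the same two-pass `std::copy_if` pattern, with toy integer values standing in for `BatchedGemmConfig` (names here are illustrative, not the library's):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Stable "bubble up": indices whose value satisfies pred come first,
// the rest follow; relative order is preserved within both groups.
std::vector<int64_t> bubbleUp(
    std::vector<int64_t> const& sorted, std::vector<int> const& values, bool (*pred)(int))
{
    std::vector<int64_t> out;
    // Pass 1: copy the matching indices.
    std::copy_if(sorted.begin(), sorted.end(), std::back_inserter(out),
        [&](int64_t idx) { return pred(values[idx]); });
    // Pass 2: copy everything not copied in pass 1.
    std::copy_if(sorted.begin(), sorted.end(), std::back_inserter(out),
        [&](int64_t idx) { return std::find(out.begin(), out.end(), idx) == out.end(); });
    return out;
}

int main()
{
    std::vector<int> values{10, 25, 7, 30, 12}; // toy stand-ins for configs
    std::vector<int64_t> sorted{0, 1, 2, 3, 4}; // pretend heuristic-sorted order
    for (auto idx : bubbleUp(sorted, values, [](int v) { return v >= 20; }))
        std::cout << idx << ' '; // prints: 1 3 0 2 4
    std::cout << '\n';
    return 0;
}

`std::stable_partition` would produce the same ordering in one pass; the two-pass `copy_if` variant keeps the source vector untouched at the cost of a quadratic membership scan, which is negligible for config lists of this size.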
@@ -44,7 +176,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
     // When we include low-latency kernels we can set transposeMmaOutput via constructor
     if (options.mDtypeA == mOptions.eltType && options.mDtypeC == mOptions.outputType
         && options.mUseDeepSeekFp8 == mOptions.deepSeekFp8
-        && options.mTransposeMmaOutput == mOptions.transposeMmaOutput && options.mRouteAct == mOptions.routeAct
+        && options.mTransposeMmaOutput == mOptions.transposeMmaOutput
+        && (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct
         && options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch
         && tileSize == mOptions.tileSize)
     {
@@ -227,9 +360,9 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
     gemmData.mProblemDimensions.mWorldSize = 1;
     gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
     // Sort configs by options
-    std::vector<int32_t> sortedIndices = mPassingConfigIndices;
+    std::vector<int64_t> sortedIndices = mPassingConfigIndices;
     std::sort(sortedIndices.begin(), sortedIndices.end(),
-        [&configs](int32_t idx0, int32_t idx1)
+        [&configs](int64_t idx0, int64_t idx1)
         {
             auto const& optionsA = configs[idx0].mOptions;
             auto const& optionsB = configs[idx1].mOptions;
@@ -247,16 +380,17 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
             }
 
             // Then by tile scheduler (persistent scheduler is better for FC2 in MoE)
-            if (!optionsA.mRouteAct)
+            if (doesRouteImplUseNoRoute(optionsA.mRouteImpl))
             {
                 return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
             }
 
             return optionsA.mTileM > optionsB.mTileM;
         });
 
+    std::vector<int64_t> prioritizedIndices = prioritizePredefinedConfigs(m, n, k, sortedIndices, configs);
     std::vector<int64_t> validConfigIndices;
-    for (auto const& configIndex : sortedIndices)
+    for (auto const& configIndex : prioritizedIndices)
     {
         auto const& config = configs[configIndex];
         auto isValidConfig = bmm.isValidConfig(config, gemmData);
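Taken together, these hunks give `getValidConfigIndices` a three-stage pipeline: sort all passing configs by the generic heuristic, bubble shape-specific known-good configs to the front, then keep only the configs that validate against the problem dimensions. A toy sketch of that flow, with made-up sort keys and validity flags in place of real kernel configs:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main()
{
    // Toy stand-ins: index i is "config i"; smaller key sorts first.
    std::vector<int> key{3, 1, 2, 0};          // heuristic sort key per config
    std::vector<bool> valid{true, true, false, true};
    std::vector<int64_t> passing{0, 1, 2, 3};  // like mPassingConfigIndices

    // 1) Sort by the generic heuristic (tile sizes, scheduler, ...).
    std::sort(passing.begin(), passing.end(),
        [&](int64_t a, int64_t b) { return key[a] < key[b]; });

    // 2) Shape-specific prioritization would reorder `passing` here
    //    (see prioritizePredefinedConfigs above); identity in this toy case.
    std::vector<int64_t> prioritized = passing;

    // 3) Keep only configs that validate against the problem dimensions.
    std::vector<int64_t> validConfigIndices;
    std::copy_if(prioritized.begin(), prioritized.end(),
        std::back_inserter(validConfigIndices), [&](int64_t i) { return valid[i]; });

    for (auto i : validConfigIndices)
        std::cout << i << ' '; // prints: 3 1 0
    std::cout << '\n';
    return 0;
}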

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 6 additions & 3 deletions
@@ -50,23 +50,26 @@ class TrtllmGenBatchedGemmRunner
         std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
         std::optional<int32_t> configIndex = std::nullopt);
 
+    // Generic GEMM interface
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
         int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
         void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
         float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
         int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
         void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);
 
+    // NVFP4 per-block scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
         std::optional<int32_t> configIndex = std::nullopt);
 
+    // FP8 per-tensor scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
         float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
         std::optional<int32_t> configIndex = std::nullopt);
 
     // Get the list of configs that passed the validation based on the constructor options
-    [[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
+    [[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
     {
         return mPassingConfigIndices;
     }
@@ -88,8 +91,8 @@ class TrtllmGenBatchedGemmRunner
 
 private:
     TrtllmGenBatchedGemmRunnerOptions mOptions;
-    std::vector<int32_t> mPassingConfigIndices;
-    std::optional<int32_t> mSelectedConfigIndex;
+    std::vector<int64_t> mPassingConfigIndices;
+    std::optional<int64_t> mSelectedConfigIndex;
 };
 } // namespace kernels
 } // namespace tensorrt_llm
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
+ * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cassert>
+#include <string>
+
+namespace batchedGemm
+{
+
+namespace batchedGemm
+{
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+enum class RouteImpl
+{
+    // No Routing
+    NoRoute = 0,
+    // Use LDGSTS to do the routing
+    Ldgsts = 1,
+    // Use UTMALDG.GATHER4 to do the routing
+    Tma = 2
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline bool doesRouteImplUseNoRoute(RouteImpl mode)
+{
+    return (mode == RouteImpl::NoRoute);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline bool doesRouteImplUseLdgsts(RouteImpl mode)
+{
+    return (mode == RouteImpl::Ldgsts);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline bool doesRouteImplUseTma(RouteImpl mode)
+{
+    return (mode == RouteImpl::Tma);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace batchedGemm
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace batchedGemm
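The new `RouteImpl` enum generalizes the old boolean `mRouteAct` flag, which is why the `KernelRunner.cpp` hunk above compares `(!doesRouteImplUseNoRoute(options.mRouteImpl))` against `mOptions.routeAct`: any mode other than `NoRoute` means the activations are routed. A small self-contained check of that equivalence, with the enum and helper copied from the header above:

#include <cassert>
#include <initializer_list>

enum class RouteImpl
{
    NoRoute = 0, // no routing
    Ldgsts = 1,  // routing via LDGSTS
    Tma = 2      // routing via UTMALDG.GATHER4
};

inline bool doesRouteImplUseNoRoute(RouteImpl mode)
{
    return mode == RouteImpl::NoRoute;
}

int main()
{
    // The legacy boolean `routeAct` means "some routing implementation is used".
    for (RouteImpl mode : {RouteImpl::NoRoute, RouteImpl::Ldgsts, RouteImpl::Tma})
    {
        bool const routeAct = !doesRouteImplUseNoRoute(mode);
        assert(routeAct == (mode != RouteImpl::NoRoute));
    }
    return 0;
}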

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 40 additions & 9 deletions
@@ -96,12 +96,16 @@ struct BatchedGemmData
     // Logical strides are [K, 1].
     //
     // If batchN:
-    //   If transposeMatrixA is false
+    //   If layoutA is MatrixLayout::MajorK
     //     Logical shape is [B, divUpMul(M, tileM), K].
     //     Logical strides are [divUpMul(M, tileM) * K, K, 1].
-    //   If transposeMatrixA is true
+    //   If layoutA is MatrixLayout::MajorMn
     //     Logical shape is [B, K, divUpMul(M, tileM)].
     //     Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM), 1].
+    //   If layoutA is MatrixLayout::BlockMajorK
+    //     Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
+    //     Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
+    //     where blockK is 128B.
     void const* mPtrA{nullptr};
 
     // The block scaling factors to dequantize A.
@@ -154,12 +158,16 @@ struct BatchedGemmData
     // Logical strides are [K, 1].
     //
     // If batchM:
-    //   If transposeMatrixB is true
+    //   If layoutB is MatrixLayout::MajorK
     //     Logical shape is [B, divUpMul(N, tileN), K].
     //     Logical strides are [divUpMul(N, tileN) * K, K, 1].
-    //   If transposeMatrixB is false
+    //   If layoutB is MatrixLayout::MajorMn
     //     Logical shape is [B, K, divUpMul(N, tileN)].
     //     Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN), 1].
+    //   If layoutB is MatrixLayout::BlockMajorK
+    //     Logical shape is [B, K / blockK, divUpMul(N, tileN), blockK].
+    //     Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1].
+    //     where blockK is 128B.
     void const* mPtrB{nullptr};
 
     // The scaling factors to dequantize B.
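For the new `MatrixLayout::BlockMajorK` layout, the K dimension is split into blocks of `blockK` elements (128 bytes' worth, so the element count depends on the dtype) and each block is stored contiguously per padded row. A sketch of the resulting element-offset computation for matrix B, using the shape and strides documented above; `divUpMul` is reimplemented here for illustration, the real helper lives in the library:

#include <cstdint>
#include <iostream>

// divUpMul(x, y): round x up to the next multiple of y.
int64_t divUpMul(int64_t x, int64_t y)
{
    return ((x + y - 1) / y) * y;
}

// Linear element offset of logical element (b, n, k) of matrix B stored as
// MatrixLayout::BlockMajorK: shape [B, K / blockK, divUpMul(N, tileN), blockK],
// strides [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1].
int64_t blockMajorKOffset(int64_t b, int64_t n, int64_t k, int64_t N, int64_t K,
    int64_t tileN, int64_t blockK /* elements: 128 bytes / sizeof(dtype) */)
{
    int64_t const nPadded = divUpMul(N, tileN);
    return b * (K * nPadded)                  // batch stride
        + (k / blockK) * (nPadded * blockK)   // which K-block
        + n * blockK                          // row inside the block
        + (k % blockK);                       // element inside the block
}

int main()
{
    // Toy check with a 1-byte dtype (e.g. fp8), so blockK = 128 elements.
    std::cout << blockMajorKOffset(0, 0, 0, 512, 4096, 128, 128) << '\n';   // 0
    std::cout << blockMajorKOffset(0, 1, 130, 512, 4096, 128, 128) << '\n'; // 65666
    return 0;
}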
@@ -210,6 +218,21 @@
     // Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B)]
     void const* mPtrPerTokenSfB{nullptr};
 
+    // The bias applied after the GEMM and before the activation function.
+    // The bias is applied before applying the global scaling factor. I.e.
+    // C = act(A * B + bias') * scaleC
+    // scaleC = dequantA * dequantB * quantC
+    // Thus, bias' = bias / (dequantA * dequantB), where bias is the original bias.
+    //
+    // If batchM, BiasType must be N, and bias shape is [B, N].
+    // The bias is broadcasted along the M dimension.
+    //
+    // If batchN, BiasType must be M, and bias shape is [B, M].
+    // The bias is broadcasted along the N dimension.
+    //
+    // The dtype is float32.
+    void const* mPtrBias{nullptr};
+
     // The output tensor scaling factor for MxFp{4,8}, Fp8 and NvFp4 quantization.
     // TensorRT-LLM API requires a scaling factor on the device.
     // Shape is [B].
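The bias folding documented above keeps the bias addition inside the quantized accumulator domain: the kernel adds `bias' = bias / (dequantA * dequantB)` before the activation, and the subsequent multiplication by `scaleC = dequantA * dequantB * quantC` restores the real-valued result. The identity holds for activations that commute with a positive scale, such as ReLU or the identity. A scalar check with made-up values:

#include <cassert>
#include <cmath>

int main()
{
    // C = act(A * B + bias') * scaleC, scaleC = dequantA * dequantB * quantC,
    // bias' = bias / (dequantA * dequantB).
    double const dequantA = 0.5, dequantB = 0.25, quantC = 8.0;
    double const a = 3.0, b = 4.0; // values as the kernel sees them (quantized domain)
    double const bias = 1.5;       // bias in real (dequantized) units

    auto relu = [](double x) { return x > 0.0 ? x : 0.0; };

    // What the kernel computes:
    double const biasPrime = bias / (dequantA * dequantB);
    double const scaleC = dequantA * dequantB * quantC;
    double const kernelC = relu(a * b + biasPrime) * scaleC;

    // Reference: dequantize, add bias in real units, activate, quantize.
    double const refC = relu((a * dequantA) * (b * dequantB) + bias) * quantC;

    assert(std::fabs(kernelC - refC) < 1e-12); // both are 24.0
    return 0;
}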
@@ -220,6 +243,12 @@
     // Shape is [B].
     float const* mPtrScaleGate{nullptr};
 
+    // The alpha and beta for SwiGlu.
+    // gatedActivation <- (x0 + beta) * sigmoid(alpha * x1)
+    // Shape is [B]
+    float const* mPtrSwiGluAlpha{nullptr};
+    float const* mPtrSwiGluBeta{nullptr};
+
     // Param is used when the kernel is configured with -routeAct true.
     // The inputs are not padded, but the outputs are padded to divUpMul(M[bi], tileM) for batchM or
     // divUpMul(N[bi], tileN) for batchN.
609638
batchedGemmData.mInputBuffers.mPtrB, batchedGemmData.mOutputBuffers.mPtrC,
610639
batchedGemmData.mInputBuffers.mPtrSfA, batchedGemmData.mInputBuffers.mPtrSfB,
611640
batchedGemmData.mInputBuffers.mPtrPerTokenSfA, batchedGemmData.mInputBuffers.mPtrPerTokenSfB,
612-
batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
613-
batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax,
614-
dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
615-
batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
616-
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, maxNumCtasInBatchDim);
641+
batchedGemmData.mInputBuffers.mPtrBias, batchedGemmData.mOutputBuffers.mPtrSfC,
642+
batchedGemmData.mInputBuffers.mPtrScaleC, batchedGemmData.mInputBuffers.mPtrScaleGate,
643+
batchedGemmData.mInputBuffers.mPtrSwiGluAlpha, batchedGemmData.mInputBuffers.mPtrSwiGluBeta,
644+
batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax, dPtrRowMaxBars,
645+
batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
646+
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit,
647+
maxNumCtasInBatchDim);
617648

618649
// The size of the grid.
619650
std::vector<int32_t> grid{numCtaX, numCtaY, numCtaZ};
