@@ -27,6 +27,138 @@ namespace kernels
2727{
2828
2929using namespace batchedGemm ::batchedGemm;
30+ using namespace batchedGemm ::gemm;
31+ using namespace batchedGemm ::trtllm::gen;
32+
33+ std::vector<int64_t > prioritizePredefinedConfigs (int m, int n, int k, std::vector<int64_t > const & sortedIndices,
34+ batchedGemm::batchedGemm::BatchedGemmConfig const * configs)
35+ {
36+
37+ // Function to bubble up the pre-determined config.
38+ auto bubbleUpConfig = [&configs](std::vector<int64_t > const & sortedIndices, auto && pred) -> std::vector<int64_t >
39+ {
40+ std::vector<int64_t > prioritizedIndices_;
41+ // Copy matching configs to new vector
42+ std::copy_if (sortedIndices.begin (), sortedIndices.end (), std::back_inserter (prioritizedIndices_),
43+ [&configs, &pred](int idx)
44+ {
45+ BatchedGemmConfig const & config = configs[idx];
46+ return (pred (config));
47+ });
48+ // Copy the rest of the configs to new vector, if not already copied
49+ std::copy_if (sortedIndices.begin (), sortedIndices.end (), std::back_inserter (prioritizedIndices_),
50+ [&prioritizedIndices_](int idx) {
51+ return std::find (prioritizedIndices_.begin (), prioritizedIndices_.end (), idx)
52+ == prioritizedIndices_.end ();
53+ });
54+ return prioritizedIndices_;
55+ };
56+
57+ // Init empty vector
58+ std::vector<int64_t > prioritizedIndices;
59+
60+ //
61+ // Qwen3
62+ //
63+
64+ // Qwen3_235B_TP1_EP8_MoE_FC1 m=3072 k=4096
65+ if (n /* out_dim */ == 3072 && k /* in_dim */ == 4096 )
66+ {
67+ auto pred = [](BatchedGemmConfig const & config)
68+ {
69+ BatchedGemmOptions const & options = config.mOptions ;
70+ return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
71+ && options.mTileScheduler == TileScheduler::Static;
72+ };
73+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
74+ }
75+ // Qwen3_235B_TP1_EP8_MoE_FC2 m=4096 k=1536
76+ else if (n /* out_dim */ == 4096 && k /* in_dim */ == 1536 )
77+ {
78+ auto pred = [](BatchedGemmConfig const & config)
79+ {
80+ BatchedGemmOptions const & options = config.mOptions ;
81+ return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
82+ && options.mTileScheduler == TileScheduler::Static;
83+ };
84+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
85+ }
86+ // Qwen3_235B_TP2_EP4_MoE_FC1 m=1536 k=4096
87+ else if (n /* out_dim */ == 1536 && k /* in_dim */ == 4096 )
88+ {
89+ auto pred = [](BatchedGemmConfig const & config)
90+ {
91+ BatchedGemmOptions const & options = config.mOptions ;
92+ return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
93+ && options.mTileScheduler == TileScheduler::Static;
94+ };
95+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
96+ }
97+ // Qwen3_235B_TP2_EP4_MoE_FC2 m=4096 k=768
98+ else if (n /* out_dim */ == 4096 && k /* in_dim */ == 768 )
99+ {
100+ auto pred = [](BatchedGemmConfig const & config)
101+ {
102+ BatchedGemmOptions const & options = config.mOptions ;
103+ return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
104+ && options.mTileScheduler == TileScheduler::Persistent;
105+ };
106+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
107+ }
108+ // Qwen3_235B_TP4_EP2_MoE_FC1 m=768 k=4096
109+ else if (n /* out_dim */ == 768 && k /* in_dim */ == 4096 )
110+ {
111+ auto pred = [](BatchedGemmConfig const & config)
112+ {
113+ BatchedGemmOptions const & options = config.mOptions ;
114+ return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
115+ && options.mTileScheduler == TileScheduler::Static;
116+ };
117+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
118+ }
119+ // Qwen3_235B_TP4_EP2_MoE_FC2 m=4096 k=384
120+ else if (n /* out_dim */ == 4096 && k /* in_dim */ == 384 )
121+ {
122+ auto pred = [](BatchedGemmConfig const & config)
123+ {
124+ BatchedGemmOptions const & options = config.mOptions ;
125+ return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 512
126+ && options.mTileScheduler == TileScheduler::Persistent;
127+ };
128+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
129+ }
130+ // Qwen3_235B_TP8_EP1_MoE_FC1 m=384 k=4096
131+ else if (n /* out_dim */ == 384 && k /* in_dim */ == 4096 )
132+ {
133+ auto pred = [](BatchedGemmConfig const & config)
134+ {
135+ BatchedGemmOptions const & options = config.mOptions ;
136+ return options.mNumStages == 4 && options.mNumStagesMma == 1 && options.mTileK == 512
137+ && options.mTileScheduler == TileScheduler::Static;
138+ };
139+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
140+ }
141+ // Qwen3_235B_TP8_EP1_MoE_FC2 m=4096 k=192
142+ else if (n /* out_dim */ == 4096 && k /* in_dim */ == 192 )
143+ {
144+ auto pred = [](BatchedGemmConfig const & config)
145+ {
146+ BatchedGemmOptions const & options = config.mOptions ;
147+ return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256
148+ && options.mTileScheduler == TileScheduler::Persistent;
149+ };
150+ prioritizedIndices = bubbleUpConfig (sortedIndices, pred);
151+ }
152+ //
153+ // Fall back
154+ //
155+ else
156+ {
157+ prioritizedIndices = sortedIndices;
158+ }
159+
160+ return prioritizedIndices;
161+ }
30162
31163TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner (TrtllmGenBatchedGemmRunnerOptions const & options_)
32164 : mOptions (options_)
@@ -44,7 +176,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
44176 // When we include low-latency kernels we can set transposeMmaOutput via constructor
45177 if (options.mDtypeA == mOptions .eltType && options.mDtypeC == mOptions .outputType
46178 && options.mUseDeepSeekFp8 == mOptions .deepSeekFp8
47- && options.mTransposeMmaOutput == mOptions .transposeMmaOutput && options.mRouteAct == mOptions .routeAct
179+ && options.mTransposeMmaOutput == mOptions .transposeMmaOutput
180+ && (!doesRouteImplUseNoRoute (options.mRouteImpl )) == mOptions .routeAct
48181 && options.mFusedAct == mOptions .fusedAct && options.mIsStaticBatch == mOptions .staticBatch
49182 && tileSize == mOptions .tileSize )
50183 {
@@ -227,9 +360,9 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
227360 gemmData.mProblemDimensions .mWorldSize = 1 ;
228361 gemmData.mProblemDimensions .mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
229362 // Sort configs by options
230- std::vector<int32_t > sortedIndices = mPassingConfigIndices ;
363+ std::vector<int64_t > sortedIndices = mPassingConfigIndices ;
231364 std::sort (sortedIndices.begin (), sortedIndices.end (),
232- [&configs](int32_t idx0, int32_t idx1)
365+ [&configs](int64_t idx0, int64_t idx1)
233366 {
234367 auto const & optionsA = configs[idx0].mOptions ;
235368 auto const & optionsB = configs[idx1].mOptions ;
@@ -247,16 +380,17 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
247380 }
248381
249382 // Then by tile scheduler (persistent scheduler is better for FC2 in MoE)
250- if (! optionsA.mRouteAct )
383+ if (doesRouteImplUseNoRoute ( optionsA.mRouteImpl ) )
251384 {
252385 return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
253386 }
254387
255388 return optionsA.mTileM > optionsB.mTileM ;
256389 });
257390
391+ std::vector<int64_t > prioritizedIndices = prioritizePredefinedConfigs (m, n, k, sortedIndices, configs);
258392 std::vector<int64_t > validConfigIndices;
259- for (auto const & configIndex : sortedIndices )
393+ for (auto const & configIndex : prioritizedIndices )
260394 {
261395 auto const & config = configs[configIndex];
262396 auto isValidConfig = bmm.isValidConfig (config, gemmData);
0 commit comments