@@ -27,6 +27,8 @@ static constexpr int MaxNumTopExperts = 8;
static constexpr int MaxNumExperts = 128;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
+ static constexpr int NumThreadsSingleBlock = MaxNumExperts;
+ static constexpr int BlockKernelMaxNumTokens = 4;

template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
@@ -75,6 +77,156 @@ __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSiz
    }
}

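+ // Single-block routing kernel for small batches (mNumTokens <= BlockKernelMaxNumTokens):
+ // one thread block handles all tokens. Each warp computes the top-K experts for one token,
+ // each thread owns one expert for the count/offset computation, and a cub::BlockScan turns
+ // the padded per-expert counts into offsets of the permuted index layout.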
+ template <typename KernelParams, bool DoSoftmaxBeforeTopK = false>
+ __global__ void __launch_bounds__(NumThreadsSingleBlock) routingIndicesBlockKernel(KernelParams params)
+ {
+     // types used in this kernel
+     using OutputT = typename KernelParams::OutputT;
+     using InputT = typename KernelParams::InputT;
+     using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;
+     using TypePacked = PackedScoreIdx<BaseType>;
+
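+     // Thread/warp roles: threadIdx.x indexes one expert; the warp index selects the token
+     // this warp computes top-K for (only the first mNumTokens warps hold valid tokens).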
+     int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
+     int32_t const laneIdx = cutlass::arch::LaneId();
+     int32_t const expert = threadIdx.x;
+     auto scoreOffset = warpIdx * params.mNumExperts;
+     bool validToken = warpIdx < params.mNumTokens;
+
+     static constexpr int VecSize = MaxNumExperts / WarpSize;
+     static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts;
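+     // Per-token rows of MaxNumExperts entries: smemKIdx[token * MaxNumExperts + expert] stores the
+     // top-K slot of a selected expert (or -1), smemOffset stores the token's rank among the tokens
+     // routed to that expert (filled for local experts only).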
+     __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];
+     __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];
+
+     using Scan = cub::BlockScan<int32_t, NumThreadsSingleBlock, cub::BLOCK_SCAN_WARP_SCANS>;
+     __shared__ typename Scan::TempStorage tempStorage;
+
+     auto block = cg::this_thread_block();
+     auto warp = cg::tiled_partition<WarpSize>(block);
+
+     for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x)
+     {
+         smemOffset[i] = int8_t{-1};
+         smemKIdx[i] = int8_t{-1};
+     }
+     __syncthreads();
+
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+     // then wait on primary grid
+     if constexpr (KernelParams::UsePdl)
+     {
+         cudaGridDependencySynchronize();
+     }
+ #endif
+
+     if (params.mPtrScores != nullptr)
+     {
+         // in this case, each warp represents a token
+         BaseType score[VecSize];
+         int32_t idx[VecSize];
+
+         BaseType warpTopKScore[MaxNumTopExperts];
+         int32_t warpTopKExpertIdx[MaxNumTopExperts];
+
+         BaseType minScore = BaseType{-INFINITY};
+         if (validToken)
+         {
+             routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, score, idx,
+                 warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
+                 params.mPtrScores + scoreOffset, params.mNormTopkProb);
+
+             if (laneIdx < params.mTopK)
+             {
+                 int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx];
+                 smemKIdx[offset] = static_cast<int8_t>(laneIdx);
+                 if (params.mPtrExpertWeights != nullptr)
+                 {
+                     params.mPtrExpertWeights[warpIdx * params.mTopK + laneIdx] = OutputT{warpTopKScore[laneIdx]};
+                 }
+             }
+         } // end if (validToken)
+     }
+     __syncthreads();
+
+     // set local experts
+     auto localExpertIdx = expert - params.mLocalExpertsStartIdx;
+     auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts
+         && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+     // Get the count of each expert and the offset for each token
+     int accExpertCount = 0;
+
+     if (isLocalExpert)
+     {
+         int offset = expert;
+         for (int j = 0; j < BlockKernelMaxNumTokens; j++)
+         {
+             if (smemKIdx[offset] >= 0)
+             {
+                 smemOffset[offset] = static_cast<int8_t>(accExpertCount);
+                 accExpertCount++;
+             }
+             offset += MaxNumExperts;
+         }
+     }
+     __syncthreads();
+     // Get the number of CTAs and the offset for each CTA
+     const int32_t numCta = divUpLog2<int32_t>(accExpertCount, params.mPaddingLog2);
+     int32_t ctaOffset = 0;
+     int32_t numNonExitingCtas;
+     Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
+
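+     // Scan the padded per-expert counts to get each expert's start offset in the permuted index array.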
+     int32_t expertScanCounts = 0;
+     Scan(tempStorage).ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts);
+     __syncthreads();
+
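+     // For each CTA assigned to this (local) expert, record the local expert it processes and the
+     // limit of valid (non-padded) entries for that CTA.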
+     if (isLocalExpert)
+     {
+         for (int cta = 0; cta < numCta; ++cta)
+         {
+             const int32_t localExpertIdx = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
+             params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
+             params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
+                 = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
+                     mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + accExpertCount);
+         }
+     }
+
+     // at this point, we can write out padded count
+     if (threadIdx.x == 0)
+     {
+         const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
+         params.mPtrPermutedIdxSize[0] = permutedIdxSize;
+         params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
+     }
+
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+ #if !defined(FDL_PROFILE) || FDL_PROFILE == 0
+     // we can trigger the next kernel at this point
+     if constexpr (KernelParams::UsePdl)
+     {
+         cudaTriggerProgrammaticLaunchCompletion();
+     }
+ #endif
+ #endif
+
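+     // Each thread (one expert) walks all tokens: for every (token, expert) pair selected in top-K,
+     // compute its permuted index from the expert's base offset plus the token's rank within the
+     // expert, then fill the expanded->permuted and permuted->token mappings.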
+     for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++)
+     {
+         int offset = tokenIdx * MaxNumExperts + threadIdx.x;
+         if (smemKIdx[offset] >= 0)
+         {
+             int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset];
+             int const offsetWithinExpert = static_cast<int>(smemOffset[offset]);
+             int const offsetForExpert = expertScanCounts;
+             int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1};
+
+             params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
+             if (isLocalExpert)
+             {
+                 params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
+             }
+         }
+     }
+ }
+
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
@@ -234,18 +386,27 @@ void run(Data const& data, void* stream)
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);

+     bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;
    bool const useSingleCluster
        = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);

-     if (!useSingleCluster)
+     if (!useSingleCluster && !useSingleBlock)
    {
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
    }

-     if (useSingleCluster)
+     if (useSingleBlock)
+     {
+         // @TODO: For now we use the single-block kernel for cases with at most 4 tokens.
+         // We will tune this threshold further based on performance.
+         LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesBlockKernel, 1, NumThreadsSingleBlock,
+             /*smemSize=*/0, // No dynamic smem
+             stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
+     }
+     else if (useSingleCluster)
    {
        LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
            /*smemSize=*/0, // No dynamic smem