
Commit 87b991b

Merge branch 'main' into user/chenfeiz/fix-70b-acc-drop
2 parents: d04d0f3 + e88cb92

File tree

11 files changed: +420 -177 lines changed


cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu

Lines changed: 163 additions & 2 deletions
@@ -27,6 +27,8 @@ static constexpr int MaxNumTopExperts = 8;
 static constexpr int MaxNumExperts = 128;
 static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
 static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
+static constexpr int NumThreadsSingleBlock = MaxNumExperts;
+static constexpr int BlockKernelMaxNumTokens = 4;

 template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
 __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
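Note (editorial, not part of the commit): the two constants added above size the new single-block routing kernel, which uses one thread per expert and one warp per token. A minimal standalone sketch of that relationship, assuming the usual warp size of 32; the trailing underscores mark the names as local to this sketch:

constexpr int WarpSize_ = 32;                          // assumed CUDA warp size
constexpr int MaxNumExperts_ = 128;                    // from the hunk above
constexpr int NumThreadsSingleBlock_ = MaxNumExperts_; // one thread per expert
constexpr int BlockKernelMaxNumTokens_ = 4;            // one warp per token
static_assert(NumThreadsSingleBlock_ / WarpSize_ == BlockKernelMaxNumTokens_,
    "128 threads = 4 warps, so the block kernel can route at most 4 tokens");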
@@ -75,6 +77,156 @@ __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSiz
     }
 }
 
+template <typename KernelParams, bool DoSoftmaxBeforeTopK = false>
+__global__ void __launch_bounds__(NumThreadsSingleBlock) routingIndicesBlockKernel(KernelParams params)
+{
+    // types used in this kernel
+    using OutputT = typename KernelParams::OutputT;
+    using InputT = typename KernelParams::InputT;
+    using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;
+    using TypePacked = PackedScoreIdx<BaseType>;
+
+    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
+    int32_t const laneIdx = cutlass::arch::LaneId();
+    int32_t const expert = threadIdx.x;
+    auto scoreOffset = warpIdx * params.mNumExperts;
+    bool validToken = warpIdx < params.mNumTokens;
+
+    static constexpr int VecSize = MaxNumExperts / WarpSize;
+    static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts;
+    __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts];
+    __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts];
+
+    using Scan = cub::BlockScan<int32_t, NumThreadsSingleBlock, cub::BLOCK_SCAN_WARP_SCANS>;
+    __shared__ typename Scan::TempStorage tempStorage;
+
+    auto block = cg::this_thread_block();
+    auto warp = cg::tiled_partition<WarpSize>(block);
+
+    for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x)
+    {
+        smemOffset[i] = int8_t{-1};
+        smemKIdx[i] = int8_t{-1};
+    }
+    __syncthreads();
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    // then wait on primary grid
+    if constexpr (KernelParams::UsePdl)
+    {
+        cudaGridDependencySynchronize();
+    }
+#endif
+
+    if (params.mPtrScores != nullptr)
+    {
+        // in this case, each warp represents a token
+        BaseType score[VecSize];
+        int32_t idx[VecSize];
+
+        BaseType warpTopKScore[MaxNumTopExperts];
+        int32_t warpTopKExpertIdx[MaxNumTopExperts];
+
+        BaseType minScore = BaseType{-INFINITY};
+        if (validToken)
+        {
+            routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, score, idx,
+                warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
+                params.mPtrScores + scoreOffset, params.mNormTopkProb);
+
+            if (laneIdx < params.mTopK)
+            {
+                int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx];
+                smemKIdx[offset] = static_cast<int8_t>(laneIdx);
+                if (params.mPtrExpertWeights != nullptr)
+                {
+                    params.mPtrExpertWeights[warpIdx * params.mTopK + laneIdx] = OutputT{warpTopKScore[laneIdx]};
+                }
+            }
+        } // end if (validToken)
+    }
+    __syncthreads();
+
+    // set local experts
+    auto localExpertIdx = expert - params.mLocalExpertsStartIdx;
+    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts
+        && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
+    // Get the count of each expert and the offset for each token
+    int accExpertCount = 0;
+
+    if (isLocalExpert)
+    {
+        int offset = expert;
+        for (int j = 0; j < BlockKernelMaxNumTokens; j++)
+        {
+            if (smemKIdx[offset] >= 0)
+            {
+                smemOffset[offset] = static_cast<int8_t>(accExpertCount);
+                accExpertCount++;
+            }
+            offset += MaxNumExperts;
+        }
+    }
+    __syncthreads();
+    // Get the number of CTAs and the offset for each CTA
+    const int32_t numCta = divUpLog2<int32_t>(accExpertCount, params.mPaddingLog2);
+    int32_t ctaOffset = 0;
+    int32_t numNonExitingCtas;
+    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
+
+    int32_t expertScanCounts = 0;
+    Scan(tempStorage).ExclusiveSum(divUpMulLog2(accExpertCount, params.mPaddingLog2), expertScanCounts);
+    __syncthreads();
+
+    if (isLocalExpert)
+    {
+        for (int cta = 0; cta < numCta; ++cta)
+        {
+            const int32_t localExpertIdx = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
+            params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
+            params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
+                = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
+                    mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + accExpertCount);
+        }
+    }
+
+    // at this point, we can write out padded count
+    if (threadIdx.x == 0)
+    {
+        const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
+        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
+        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
+    // we can trigger the next kernel at this point
+    if constexpr (KernelParams::UsePdl)
+    {
+        cudaTriggerProgrammaticLaunchCompletion();
+    }
+#endif
+#endif
+
+    for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++)
+    {
+        int offset = tokenIdx * MaxNumExperts + threadIdx.x;
+        if (smemKIdx[offset] >= 0)
+        {
+            int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset];
+            int const offsetWithinExpert = static_cast<int>(smemOffset[offset]);
+            int const offsetForExpert = expertScanCounts;
+            int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1};
+
+            params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
+            if (isLocalExpert)
+            {
+                params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
+            }
+        }
+    }
+}
+
 template <typename KernelParams>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
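Note (editorial, not part of the commit): a worked example of how routingIndicesBlockKernel above turns per-token expert selections into padded permuted offsets, assuming 2 tokens, topK = 2, mPaddingLog2 = 3 (counts padded to multiples of 8), and all experts local; expert/token numbers are illustrative only:

// token 0 selects experts {3, 7}; token 1 selects experts {3, 5}
// accExpertCount (per expert thread):  e3 = 2, e5 = 1, e7 = 1, all others 0
// divUpMulLog2(count, 3):              e3 = 8, e5 = 8, e7 = 8   (each count padded to 8 slots)
// expertScanCounts (exclusive sum):    e3 = 0, e5 = 8, e7 = 16  (start of each expert's slot range)
// numCta = divUpLog2(count, 3):        e3 = 1, e5 = 1, e7 = 1  -> numNonExitingCtas = 3
// mPtrPermutedIdxSize[0] = mulLog2(3, 3) = 24 padded slots in total
// permutedIdx: (token 0, e3) -> 0, (token 1, e3) -> 1, (token 1, e5) -> 8, (token 0, e7) -> 16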
@@ -234,18 +386,27 @@ void run(Data const& data, void* stream)
         data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
     TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
 
+    bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;
     bool const useSingleCluster
         = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
 
-    if (!useSingleCluster)
+    if (!useSingleCluster && !useSingleBlock)
     {
         TLLM_CHECK_WITH_INFO(
             data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
         TLLM_CHECK_WITH_INFO(
             data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
     }
 
-    if (useSingleCluster)
+    if (useSingleBlock)
+    {
+        // @TODO: For now we use the single-block kernel for cases with no more than 4 tokens.
+        // We will tune this threshold further based on performance.
+        LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesBlockKernel, 1, NumThreadsSingleBlock,
+            /*smemSize=*/0, // No dynamic smem
+            stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
+    }
+    else if (useSingleCluster)
     {
         LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
             /*smemSize=*/0, // No dynamic smem
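Note (editorial, not part of the commit): the run() hunk above adds a third dispatch path ahead of the existing ones. Below is a minimal standalone sketch of the resulting selection order; RoutingPath and selectRoutingPath are illustrative names, and the cluster limits are passed in as parameters because NumBlocksPerCluster and NumWarps are defined outside this hunk:

enum class RoutingPath
{
    SingleBlock,   // new: routingIndicesBlockKernel, one CTA of NumThreadsSingleBlock threads
    SingleCluster, // existing: routingIndicesClusterKernel
    Fallback       // existing large-token path (not shown in this hunk)
};

inline RoutingPath selectRoutingPath(
    int numTokens, bool hasScores, int maxTokensSingleClusterScores, int maxTokensSingleCluster)
{
    // Single-block path is checked first: at most BlockKernelMaxNumTokens (= 4) tokens.
    if (numTokens <= 4)
    {
        return RoutingPath::SingleBlock;
    }
    // Otherwise fall back to the pre-existing single-cluster / fallback selection.
    int const clusterLimit = hasScores ? maxTokensSingleClusterScores : maxTokensSingleCluster;
    return numTokens <= clusterLimit ? RoutingPath::SingleCluster : RoutingPath::Fallback;
}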

cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp

Lines changed: 22 additions & 0 deletions
@@ -178,6 +178,28 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
 
 TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types);
 
+TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization)
+{
+    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4,
+        /*numExperts=*/128, /*topK=*/8,
+        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
+        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
+        /*usePdl=*/true, /*getExpWeights=*/true,
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
+    this->runTest(param);
+};
+
+TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertParallelization)
+{
+    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/14,
+        /*numExperts=*/128, /*topK=*/8,
+        /*expertParallelization=*/2, /*expertParallelizationId=*/1,
+        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
+        /*usePdl=*/true, /*getExpWeights=*/true,
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
+    this->runTest(param);
+};
+
 TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
 {
     RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
