
[None][feat] Remove input_sf swizzle for module WideEPMoE #6231
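
What this PR does (as reflected in the diff below): a new `bool const swizzled_input_sf` parameter is added to runMoe and threaded through expandInputRowsKernelLauncher and expandInputRowsKernel down to the writeSF device helper. When the flag is false, writeSF resolves the incoming scale-factor offsets with QuantizationSFLayout::LINEAR instead of QuantizationSFLayout::SWIZZLED, so the WideEPMoE module can pass linearly laid-out input scale factors and skip the swizzle step. Every pre-existing call site passes true, preserving the old behavior.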


Status: Open. Wants to merge 9 commits into base: main. Changes shown from all commits.
@@ -981,7 +981,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         auto stream = streamPtr->get();
         MoeMinLatencyParams min_latency_params;
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
             mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
             mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
@@ -993,7 +993,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
             /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #else
-        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
             mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
             mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
             mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
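
Note: this benchmark call site (like the plugin call sites at the end of the diff) passes nullptr for input_sf, so the value of the new third argument is inert there; true is passed to keep the previous swizzled-layout default.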
32 changes: 16 additions & 16 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
@@ -448,14 +448,14 @@ class CutlassMoeFCRunnerInterface
         = 0;
     virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;

-    virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
-        float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
-        ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
-        QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
-        int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+    virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
+        int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
+        void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
+        void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;

     // Aliases for profiling the gemms
@@ -603,14 +603,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         return RunnerType::getConfigs(sm);
     }

-    void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
-        float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
-        ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
-        QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
-        int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+    void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
+        int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
+        void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
+        void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;

     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
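
For context on what the flag controls: input_sf holds one scale factor per VecSize-element block of each activation row (block-scaled FP4/FP8 formats), and the flag only selects the indexing scheme used to read it. Below is a standalone toy sketch of the linear case the new path enables; the type alias and sizes are illustrative assumptions, not the library's API, and the swizzled case is resolved in the real code by cvt_quant_get_sf_out_offset.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy model of the linear (unswizzled) scale-factor layout: one SF per
// block, stored token-major / block-minor, read back with a plain
// row-major index. Sizes are made up for illustration.
using ElementSF = uint8_t; // assumed: one scale-factor byte per block

constexpr int64_t kNumTokens = 4;
constexpr int64_t kBlocksPerToken = 8; // hidden_size / VecSize in the real code

int64_t linearSFIndex(int64_t token, int64_t block)
{
    return token * kBlocksPerToken + block;
}

int main()
{
    std::vector<ElementSF> input_sf(kNumTokens * kBlocksPerToken);
    for (int64_t t = 0; t < kNumTokens; ++t)
        for (int64_t b = 0; b < kBlocksPerToken; ++b)
            input_sf[linearSFIndex(t, b)] = static_cast<ElementSF>(t * 10 + b);

    // With swizzled_input_sf == false, a lookup is just this direct index.
    std::printf("sf(token=2, block=3) = %u\n",
        static_cast<unsigned>(input_sf[linearSFIndex(2, 3)]));
    return 0;
}
```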
@@ -58,7 +58,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k,
     int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale,
     int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream);
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
+    void const* prequant_scales, cudaStream_t stream);

 template <class OutputType, class GemmOutputType, class ScaleBiasType>
 void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
62 changes: 38 additions & 24 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -1042,7 +1042,7 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s
 template <int VecSize, int ElementsPerThread>
 __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id,
     int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf)
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true)
 {
     static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread;
@@ -1061,12 +1061,24 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int
     {
         if (input_sf)
         {
-            auto const sf_in
-                = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>(
-                    std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-                    num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-                    QuantizationSFLayout::SWIZZLED);
-            *sf_out = *sf_in;
+            if (swizzled_input_sf)
+            {
+                auto const sf_in
+                    = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>(
+                        std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
+                        num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
+                        QuantizationSFLayout::SWIZZLED);
+                *sf_out = *sf_in;
+            }
+            else
+            {
+                auto const sf_in
+                    = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>(
+                        std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
+                        num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
+                        QuantizationSFLayout::LINEAR);
+                *sf_out = *sf_in;
+            }
         }
         else
         {
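
Two small observations on the hunk above. First, the default value `= true` added to the writeSF signature keeps any writeSF callers not touched by this PR on the old swizzled path. Second, the two branches differ only in the layout enum passed to cvt_quant_get_sf_out_offset, so an equivalent, more compact form (a reviewer sketch using the same identifiers, not what the PR ships) would be:

```cpp
// Equivalent sketch: pick the layout once, then do a single lookup.
auto const layout = swizzled_input_sf ? QuantizationSFLayout::SWIZZLED : QuantizationSFLayout::LINEAR;
auto const sf_in = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF, NumThreadsPerSF>(
    std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
    num_cols / VecSize, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf), layout);
*sf_out = *sf_in;
```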
@@ -1460,8 +1472,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
     int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k,
     float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node,
-    InputActivationsType const* prequant_scales = nullptr)
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
+    int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr)
 {
     static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ,
         "AWQ and Block Scaling are mutually exclusive");
@@ -1563,7 +1575,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         {
             assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
             writeSF<VecSize, ELEM_PER_THREAD>(num_tokens_before_expert, expert, source_row, permuted_row,
-                elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf);
+                elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf, swizzled_input_sf);
             dest_row_ptr[elem_index] = in_vec;
         }
     }
@@ -1664,7 +1676,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k,
     int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale,
     int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream)
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
+    void const* prequant_scales, cudaStream_t stream)
 {
 #ifdef ENABLE_FP4
     TLLM_CHECK_WITH_INFO(
@@ -1740,8 +1753,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     config.attrs = attrs;
     cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
         permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale,
-        use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node,
-        reinterpret_cast<InputActivationsType const*>(prequant_scales));
+        use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf,
+        num_experts_per_node, reinterpret_cast<InputActivationsType const*>(prequant_scales));
 }

 #define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \
@@ -1751,8 +1764,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
         int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, \
         QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \
         TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \
-        TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \
-        cudaStream_t stream)
+        TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \
+        void const* prequant_scales, cudaStream_t stream)

 // Instantiate the data types that are used by the external pytorch op
 INSTANTIATE_EXPAND_INPUT_ROWS(float, float);
@@ -3509,14 +3522,14 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab

 template <class T, class WeightType, class OutputType, class InputType, class BackBoneType, class Enable>
 void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::runMoe(
-    void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts,
-    float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void,
-    ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void,
-    QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
-    int const full_num_experts, int const experts_per_token, char* workspace_ptr, void* final_output_void,
-    int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-    bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
-    MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+    void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf,
+    int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void,
+    void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void,
+    void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+    int64_t const inter_size, int const full_num_experts, int const experts_per_token, char* workspace_ptr,
+    void* final_output_void, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
+    bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
+    bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
 {
     static constexpr bool int_scales_required
         = std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value || use_wfp4a16;
@@ -3728,7 +3741,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     expandInputRowsKernelLauncher(input_activations, gemm1_input_expand, token_topk_unpermuted_scales,
         permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token,
         num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_,
-        fc1_fp4_act_scale_, input_sf, use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
+        fc1_fp4_act_scale_, input_sf, swizzled_input_sf,
+        use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
     auto const* gemm1_input = gemm1_input_expand;

     sync_check_cuda_error(stream);
@@ -957,7 +957,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
     MoeMinLatencyParams min_latency_params{};
     mMOERunner->setTactic(gemm1, gemm2);
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-    mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
+    mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true,
         static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
         hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
         inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,
@@ -969,7 +969,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
         /*enable_alltoall=*/false, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false,
         /*min_latency_mode=*/false, min_latency_params, stream);
 #else
-    mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
+    mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true,
         static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
         hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
         inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,