
[TRTLLM-6898][feat] make fused_moe_cute_dsl work on blackwell #6616

Merged: 12 commits, Aug 8, 2025
15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -13,7 +13,8 @@
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
from .routing import BaseMoeRoutingMethod


@@ -340,6 +341,18 @@ def __init__(
layer_idx=layer_idx,
)

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
raise ValueError(
f"Unsupported quantization mode: {self.quant_config.quant_mode}"
)
else:
return UnquantizedFusedMoEMethod()

@nvtx_range("[DG] forward")
def forward_chunk(
self,
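For reviewers skimming the diff: the new _get_quant_method override keeps the same dispatch shape as the other fused-MoE backends, routing FP8 block-scale checkpoints to the new DeepGemm method and unquantized models to the generic one. A minimal standalone sketch of that dispatch follows; the quant-mode class is a stand-in, not the real TensorRT-LLM API.

# Illustrative stand-in for the dispatch added in _get_quant_method above.
# Only the branch structure and the returned method names mirror the diff.
class FakeQuantMode:
    def __init__(self, fp8_block_scales: bool = False):
        self.fp8_block_scales = fp8_block_scales

    def has_any_quant(self, exclude_kv_cache: bool = True) -> bool:
        return self.fp8_block_scales

    def has_fp8_block_scales(self) -> bool:
        return self.fp8_block_scales


def pick_quant_method(quant_mode) -> str:
    if quant_mode is not None and quant_mode.has_any_quant(exclude_kv_cache=True):
        if quant_mode.has_fp8_block_scales():
            return "DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
        raise ValueError(f"Unsupported quantization mode: {quant_mode}")
    return "UnquantizedFusedMoEMethod"


assert pick_quant_method(FakeQuantMode(fp8_block_scales=True)) == \
    "DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
assert pick_quant_method(None) == "UnquantizedFusedMoEMethod"
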
81 changes: 44 additions & 37 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):

if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)

def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
})


class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
DeepSeekFP8BlockScalesFusedMoEMethod):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)


class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):

def create_weights(self, module: torch.nn.Module):
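The recipe=(1, 128, 128) argument passed to transform_sf_into_required_layout above corresponds to DeepSeek-style FP8 block scaling: one scale per 1x128 activation tile and one per 128x128 weight tile. As a rough sanity check, the sketch below computes the resulting per-expert tile counts; the numbers are illustrative and ignore any padding or swizzling the real layout transform may apply.

import math


def fp8_block_scale_tile_counts(num_experts: int, n: int, k: int,
                                weight_tile=(128, 128), act_tile_k: int = 128):
    """Tile counts for DeepSeek-style FP8 block scales (illustrative only)."""
    # One scale per 128x128 weight tile, per expert.
    weight_scales = (num_experts,
                     math.ceil(n / weight_tile[0]),
                     math.ceil(k / weight_tile[1]))
    # One scale per 1x128 slice of each activation row.
    act_scales_per_token = math.ceil(k / act_tile_k)
    return weight_scales, act_scales_per_token


# Example: 256 experts and a fused w3/w1 weight of shape (256, 4096, 7168).
print(fp8_block_scale_tile_counts(256, 4096, 7168))  # ((256, 32, 56), 56)
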
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/gated_mlp.py
@@ -27,7 +27,8 @@ def __init__(self,
config: Optional[ModelConfig] = None,
overridden_tp_size: Optional[int] = None,
reduce_output: bool = True,
layer_idx: Optional[int] = None):
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@@ -64,7 +65,8 @@ def __init__(self,
reduce_output=False,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -81,7 +83,8 @@ def __init__(self,
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
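The gated_mlp.py change is pure plumbing: GatedMLP accepts the new use_cute_dsl_blockscaling_mm keyword and forwards it unchanged to both of its Linear projections. A toy sketch of that propagation; the class and attribute names below are stand-ins, not the real signatures.

# Stand-in classes that only demonstrate the flag propagation in this diff.
class ToyLinear:
    def __init__(self, use_cute_dsl_blockscaling_mm: bool = False):
        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm


class ToyGatedMLP:
    def __init__(self, use_cute_dsl_blockscaling_mm: bool = False):
        # Both projections see the same flag, mirroring the up and down paths.
        self.gate_up_proj = ToyLinear(use_cute_dsl_blockscaling_mm)
        self.down_proj = ToyLinear(use_cute_dsl_blockscaling_mm)


mlp = ToyGatedMLP(use_cute_dsl_blockscaling_mm=True)
assert mlp.gate_up_proj.use_cute_dsl_blockscaling_mm
assert mlp.down_proj.use_cute_dsl_blockscaling_mm
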
30 changes: 20 additions & 10 deletions tensorrt_llm/_torch/modules/linear.py
@@ -583,21 +583,29 @@ def apply(self, module: Linear, input: torch.Tensor,
assert input.dtype == torch.bfloat16

if get_sm_version() == 100:
from tensorrt_llm import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
device=input.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((a, a_sf),
(module.weight, module.weight_scale),
output,
disable_ue8m0_cast=True)
if module.use_cute_dsl_blockscaling_mm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf,
module.weight_scale)
else:
from tensorrt_llm import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
device=input.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((a, a_sf),
(module.weight, module.weight_scale),
output,
disable_ue8m0_cast=True)
else:
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)

output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)

if bias is not None:
output = output + bias
return output
@@ -1488,6 +1496,7 @@ def __init__(
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
):
from ..distributed import AllReduce

@@ -1504,6 +1513,7 @@ def __init__(
self.tp_mode = tensor_parallel_mode
self.gather_output = gather_output
self.force_dynamic_quantization = force_dynamic_quantization
self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm

local_in_features = in_features
local_out_features = out_features
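In words, the linear.py branch above does the following: on SM100 (Blackwell) the default stays on the DeepGEMM path, while use_cute_dsl_blockscaling_mm=True selects the 1x128-quantize plus block-scaling GEMM path (with a CuTe DSL GEMM planned, per the TODO); on other architectures the block-scaling GEMM path is used unconditionally. A label-only sketch of that dispatch, mirroring the diff:

def fp8_block_gemm_backend(sm_version: int,
                           use_cute_dsl_blockscaling_mm: bool) -> str:
    """Which kernel path Linear.apply takes in the diff above (labels only)."""
    if sm_version == 100:
        if use_cute_dsl_blockscaling_mm:
            # Today: fp8_quantize_1x128 + fp8_block_scaling_gemm; a CuTe DSL
            # GEMM is planned to replace this (see the TODO in the diff).
            return "trtllm.fp8_block_scaling_gemm"
        return "deep_gemm.fp8_gemm_nt"
    return "trtllm.fp8_block_scaling_gemm"


assert fp8_block_gemm_backend(100, False) == "deep_gemm.fp8_gemm_nt"
assert fp8_block_gemm_backend(100, True) == "trtllm.fp8_block_scaling_gemm"
assert fp8_block_gemm_backend(90, False) == "trtllm.fp8_block_scaling_gemm"
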
3 changes: 2 additions & 1 deletion tensorrt_llm/evaluate/lm_eval.py
@@ -25,6 +25,7 @@

try:
from lm_eval.api.model import TemplateLM
from lm_eval.tasks import TaskManager
except ImportError:
TemplateLM = object

@@ -147,7 +148,7 @@ def __init__(self,
self.dataset_path = dataset_path
self.num_samples = num_samples

task_manager = lm_eval.tasks.TaskManager(
task_manager = TaskManager(
include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
with self._patch_lm_eval():
self.task_dict = lm_eval.tasks.get_task_dict(
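The lm_eval.py change hoists TaskManager into the existing guarded import so the evaluator no longer reaches into lm_eval.tasks at construction time. The general optional-dependency pattern, sketched independently of the real file (the fallback and error message below are illustrative):

# Optional-dependency import pattern, as used for lm_eval above (illustrative).
try:
    from lm_eval.tasks import TaskManager
except ImportError:  # lm_eval is not installed
    TaskManager = None


def build_task_manager(include_path: str):
    if TaskManager is None:
        raise ImportError(
            "lm_eval is required for this evaluator (pip install lm-eval).")
    return TaskManager(include_path=include_path)
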
4 changes: 2 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1021,7 +1021,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1171,7 +1171,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
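The test changes replace @skip_no_hopper with @skip_pre_blackwell, so the FP8 block-scale accuracy cases now run only on SM100-class (Blackwell) GPUs, matching the get_sm_version() == 100 branches earlier in the diff. A hedged sketch of what such a marker typically looks like; the real decorator lives in the TensorRT-LLM test utilities and may differ.

import pytest
import torch


def _sm_version() -> int:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


# Illustrative equivalent of the marker used in the diff, not the real one.
skip_pre_blackwell = pytest.mark.skipif(
    not torch.cuda.is_available() or _sm_version() < 100,
    reason="requires Blackwell (SM100) or newer",
)
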