Commit d913955

[TRTLLM-6898][feat] make fused_moe_cute_dsl work on blackwell (#6616)
Signed-off-by: Mindy Li <[email protected]>
1 parent 9687bb4 commit d913955

File tree

7 files changed: +262 -70 lines changed


tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 14 additions & 1 deletion
@@ -13,7 +13,8 @@
 from ...model_config import ModelConfig
 from ...utils import Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
-from .quantization import MoEWeightLoadingMode
+from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
+                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod
 
 
@@ -340,6 +341,18 @@ def __init__(
             layer_idx=layer_idx,
         )
 
+    def _get_quant_method(self):
+        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
+                exclude_kv_cache=True):
+            if self.quant_config.layer_quant_mode.has_fp8_block_scales():
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                raise ValueError(
+                    f"Unsupported quantization mode: {self.quant_config.quant_mode}"
+                )
+        else:
+            return UnquantizedFusedMoEMethod()
+
     @nvtx_range("[DG] forward")
     def forward_chunk(
         self,
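
Note for readers skimming the hunk above: the new `_get_quant_method` simply dispatches on the layer's quant mode. A minimal, self-contained sketch of that dispatch follows; `QuantMode` and `pick_quant_method` are illustrative stand-ins, not TensorRT-LLM APIs.

from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantMode:
    # Stand-in for quant_config.layer_quant_mode in the real code.
    fp8_block_scales: bool = False

    def has_any_quant(self) -> bool:
        return self.fp8_block_scales

    def has_fp8_block_scales(self) -> bool:
        return self.fp8_block_scales


def pick_quant_method(mode: Optional[QuantMode]) -> str:
    # FP8 block scales -> the DeepGEMM-specific method; any other quantized
    # mode is rejected; no quantization falls back to the unquantized method.
    if mode is not None and mode.has_any_quant():
        if mode.has_fp8_block_scales():
            return "DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
        raise ValueError(f"Unsupported quantization mode: {mode}")
    return "UnquantizedFusedMoEMethod"


assert pick_quant_method(QuantMode(fp8_block_scales=True)).endswith("DeepGemm")
assert pick_quant_method(None) == "UnquantizedFusedMoEMethod"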

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 44 additions & 37 deletions
@@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module):
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-
-        if get_sm_version() == 100:
-            expert_ids = set(module.initial_local_expert_ids)
-            if self.need_load_shared_weights(module):
-                expert_ids.update(
-                    module.layer_load_balancer.get_load_expert_ids())
-            for name in list(weights.keys()):
-                if name.endswith("weight_scale_inv"):
-                    if int(name.split(".")[0]) not in expert_ids:
-                        continue
-                    weight_name = name.replace("weight_scale_inv", "weight")
-                    logger.debug(f"Resmoothing {weight_name}")
-                    weight = weights[weight_name][:]
-                    scale = weights[name][:]
-                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
-                        weight, scale)
         super().load_weights(module, weights, weight_loading_mode)
 
-        if get_sm_version() == 100:
-            transfromed_w3_w1_scale = transform_sf_into_required_layout(
-                module.quant_scales[0],
-                mn=module.w3_w1_weight.shape[1],
-                k=module.w3_w1_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w3_w1_weight_scaling_factor = nn.Parameter(
-                transfromed_w3_w1_scale, requires_grad=False)
-            transfromed_w2_scale = transform_sf_into_required_layout(
-                module.quant_scales[1],
-                mn=module.w2_weight.shape[1],
-                k=module.w2_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
-                                                           requires_grad=False)
-            self.setup_quant_scales(module)
-
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
             fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         })
 
 
+class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
+        DeepSeekFP8BlockScalesFusedMoEMethod):
+
+    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
+                     weight_loading_mode: MoEWeightLoadingMode):
+        if get_sm_version() == 100:
+            expert_ids = set(module.initial_local_expert_ids)
+            if self.need_load_shared_weights(module):
+                expert_ids.update(
+                    module.layer_load_balancer.get_load_expert_ids())
+            for name in list(weights.keys()):
+                if name.endswith("weight_scale_inv"):
+                    if int(name.split(".")[0]) not in expert_ids:
+                        continue
+                    weight_name = name.replace("weight_scale_inv", "weight")
+                    logger.debug(f"Resmoothing {weight_name}")
+                    weight = weights[weight_name][:]
+                    scale = weights[name][:]
+                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
+                        weight, scale)
+        super().load_weights(module, weights, weight_loading_mode)
+
+        if get_sm_version() == 100:
+            transfromed_w3_w1_scale = transform_sf_into_required_layout(
+                module.quant_scales[0],
+                mn=module.w3_w1_weight.shape[1],
+                k=module.w3_w1_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w3_w1_weight_scaling_factor = nn.Parameter(
+                transfromed_w3_w1_scale, requires_grad=False)
+            transfromed_w2_scale = transform_sf_into_required_layout(
+                module.quant_scales[1],
+                mn=module.w2_weight.shape[1],
+                k=module.w2_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
+                                                           requires_grad=False)
+            self.setup_quant_scales(module)
+
+
 class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self, module: torch.nn.Module):
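
The block moved into the new DeepGemm subclass keys the Blackwell-only resmoothing on tensor names: every entry ending in `weight_scale_inv` is paired with its `weight` sibling and the two are rewritten together before the base loader runs. A standalone sketch of that key-matching loop, with `resmooth` standing in for `resmooth_to_fp8_e8m0`:

def resmooth(weight, scale):
    # Placeholder for resmooth_to_fp8_e8m0; returns the pair unchanged here.
    return weight, scale


def resmooth_expert_weights(weights: dict, expert_ids: set) -> dict:
    # Pair each "<expert_id>.<...>.weight_scale_inv" entry with its "weight"
    # sibling and rewrite both, skipping experts this rank does not load.
    for name in list(weights.keys()):
        if not name.endswith("weight_scale_inv"):
            continue
        if int(name.split(".")[0]) not in expert_ids:
            continue
        weight_name = name.replace("weight_scale_inv", "weight")
        weights[weight_name], weights[name] = resmooth(
            weights[weight_name], weights[name])
    return weights


w = {"0.w1.weight": [1.0], "0.w1.weight_scale_inv": [0.5]}
resmooth_expert_weights(w, expert_ids={0})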

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 6 additions & 3 deletions
@@ -27,7 +27,8 @@ def __init__(self,
                  config: Optional[ModelConfig] = None,
                  overridden_tp_size: Optional[int] = None,
                  reduce_output: bool = True,
-                 layer_idx: Optional[int] = None):
+                 layer_idx: Optional[int] = None,
+                 use_cute_dsl_blockscaling_mm: bool = False):
         super().__init__()
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
@@ -64,7 +65,8 @@ def __init__(self,
             reduce_output=False,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
 
         self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                    [self.hidden_size])
@@ -81,7 +83,8 @@ def __init__(self,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
 
         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
         # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora

tensorrt_llm/_torch/modules/linear.py

Lines changed: 20 additions & 10 deletions
@@ -583,21 +583,29 @@ def apply(self, module: Linear, input: torch.Tensor,
         assert input.dtype == torch.bfloat16
 
         if get_sm_version() == 100:
-            from tensorrt_llm import deep_gemm
-            a, a_sf = fp8_utils.per_token_quant_and_transform(input)
-            output = torch.empty((input.shape[0], module.weight.shape[0]),
-                                 device=input.device,
-                                 dtype=torch.bfloat16)
-            deep_gemm.fp8_gemm_nt((a, a_sf),
-                                  (module.weight, module.weight_scale),
-                                  output,
-                                  disable_ue8m0_cast=True)
+            if module.use_cute_dsl_blockscaling_mm:
+                # TODO (@lmin): replace with cute_dsl gemm
+                act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
+                    input)
+                output = torch.ops.trtllm.fp8_block_scaling_gemm(
+                    act_input_fp8, module.weight, act_input_sf,
+                    module.weight_scale)
+            else:
+                from tensorrt_llm import deep_gemm
+                a, a_sf = fp8_utils.per_token_quant_and_transform(input)
+                output = torch.empty((input.shape[0], module.weight.shape[0]),
+                                     device=input.device,
+                                     dtype=torch.bfloat16)
+                deep_gemm.fp8_gemm_nt((a, a_sf),
+                                      (module.weight, module.weight_scale),
+                                      output,
+                                      disable_ue8m0_cast=True)
         else:
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
-
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, module.weight, act_input_sf, module.weight_scale)
+
         if bias is not None:
             output = output + bias
         return output
@@ -1488,6 +1496,7 @@ def __init__(
         lora: Optional[LoraLayer] = None,
         allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
         force_dynamic_quantization: bool = False,
+        use_cute_dsl_blockscaling_mm: bool = False,
     ):
         from ..distributed import AllReduce
 
@@ -1504,6 +1513,7 @@ def __init__(
         self.tp_mode = tensor_parallel_mode
         self.gather_output = gather_output
         self.force_dynamic_quantization = force_dynamic_quantization
+        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
 
         local_in_features = in_features
         local_out_features = out_features
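
The branch added to `apply` above amounts to a three-way choice of FP8 GEMM path. A rough sketch of just that selection logic, with string labels standing in for the `torch.ops.trtllm.fp8_quantize_1x128` / `torch.ops.trtllm.fp8_block_scaling_gemm` and `deep_gemm.fp8_gemm_nt` calls:

def choose_fp8_gemm_path(sm_version: int,
                         use_cute_dsl_blockscaling_mm: bool) -> str:
    # On SM 100 (Blackwell), the new flag selects the 1x128-quantize +
    # block-scaling GEMM path (to be replaced by a CuTe DSL GEMM per the
    # TODO); otherwise Blackwell uses DeepGEMM. Pre-Blackwell keeps the
    # block-scaling GEMM path.
    if sm_version == 100:
        if use_cute_dsl_blockscaling_mm:
            return "fp8_quantize_1x128 + fp8_block_scaling_gemm"
        return "per_token_quant_and_transform + deep_gemm.fp8_gemm_nt"
    return "fp8_quantize_1x128 + fp8_block_scaling_gemm"


assert choose_fp8_gemm_path(100, False).endswith("fp8_gemm_nt")
assert choose_fp8_gemm_path(100, True) == choose_fp8_gemm_path(90, False)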

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 
 try:
     from lm_eval.api.model import TemplateLM
+    from lm_eval.tasks import TaskManager
 except ImportError:
     TemplateLM = object
 
@@ -147,7 +148,7 @@ def __init__(self,
         self.dataset_path = dataset_path
         self.num_samples = num_samples
 
-        task_manager = lm_eval.tasks.TaskManager(
+        task_manager = TaskManager(
             include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
         with self._patch_lm_eval():
             self.task_dict = lm_eval.tasks.get_task_dict(
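
The `TaskManager` import moves inside the existing `try`/`except ImportError` guard so the module still imports when `lm_eval` is not installed. A small sketch of that guarded-import pattern; the `TaskManager = None` fallback and `build_task_manager` helper below are illustrative additions, not part of the actual change:

try:
    from lm_eval.api.model import TemplateLM
    from lm_eval.tasks import TaskManager
except ImportError:
    # Keep the module importable without lm_eval installed.
    TemplateLM = object
    TaskManager = None  # hypothetical fallback; the real diff adds none


def build_task_manager(include_path: str):
    if TaskManager is None:
        raise RuntimeError("lm_eval must be installed to build a TaskManager")
    return TaskManager(include_path=include_path)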

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -1020,7 +1020,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @skip_no_hopper
+    @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False])
     @parametrize_with_ids(
         "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1170,7 +1170,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_blackwell
    @parametrize_with_ids("torch_compile", [False])
    @parametrize_with_ids(
        "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
