From 3270b8d3a8ba07c1b76dc8f19c5b39dbeb64923b Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:39:47 -0700 Subject: [PATCH 1/8] make fused_moe_cute_dsl work on blackwell. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_cute_dsl.py | 15 +++- .../_torch/modules/fused_moe/quantization.py | 84 ++++++++++--------- tensorrt_llm/_torch/modules/gated_mlp.py | 9 +- tensorrt_llm/_torch/modules/linear.py | 4 +- .../defs/accuracy/test_llm_api_pytorch.py | 8 +- .../unittest/_torch/modules/test_fused_moe.py | 44 ++++------ 6 files changed, 89 insertions(+), 75 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 815dae64766..94ca1c294ea 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -9,7 +9,8 @@ from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE -from .quantization import MoEWeightLoadingMode +from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl, + MoEWeightLoadingMode, UnquantizedFusedMoEMethod) from .routing import BaseMoeRoutingMethod @@ -139,6 +140,18 @@ def __init__( layer_idx=layer_idx, ) + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.layer_quant_mode.has_fp8_block_scales(): + return DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl() + else: + raise ValueError( + f"Unsupported quantization mode: {self.quant_config.quant_mode}" + ) + else: + return UnquantizedFusedMoEMethod() + def forward_chunk( self, x: Union[torch.Tensor, Fp4QuantizedTensor], diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 18e9c7cc98a..c9702a8aabd 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -430,7 +430,7 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): module.fc31_input_dequant.data.copy_(max_fc31_input_scale) -class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): +class DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): weight_dtype = torch.float8_e4m3fn @@ -468,45 +468,8 @@ def create_weights(self, module: torch.nn.Module): def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): - - if get_sm_version() == 100: - expert_ids = set(module.initial_local_expert_ids) - if self.need_load_shared_weights(module): - expert_ids.update( - module.layer_load_balancer.get_load_expert_ids()) - for name in list(weights.keys()): - if name.endswith("weight_scale_inv"): - if int(name.split(".")[0]) not in expert_ids: - continue - weight_name = name.replace("weight_scale_inv", "weight") - logger.debug(f"Resmoothing {weight_name}") - weight = weights[weight_name][:] - scale = weights[name][:] - weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( - weight, scale) super().load_weights(module, weights, weight_loading_mode) - if get_sm_version() == 100: - transfromed_w3_w1_scale = transform_sf_into_required_layout( - module.quant_scales[0], - mn=module.w3_w1_weight.shape[1], - k=module.w3_w1_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w3_w1_weight_scaling_factor = nn.Parameter( - transfromed_w3_w1_scale, requires_grad=False) - transfromed_w2_scale = transform_sf_into_required_layout( - module.quant_scales[1], - mn=module.w2_weight.shape[1], - k=module.w2_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, - requires_grad=False) - self.setup_quant_scales(module) - def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, @@ -590,6 +553,51 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): }) +class DeepSeekFP8BlockScalesFusedMoEMethod( + DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl): + + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + print(f"DeepSeekFP8BlockScalesFusedMoEMethod load_weights") + if get_sm_version() == 100: + expert_ids = set(module.initial_local_expert_ids) + if self.need_load_shared_weights(module): + expert_ids.update( + module.layer_load_balancer.get_load_expert_ids()) + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + if int(name.split(".")[0]) not in expert_ids: + continue + weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + super().load_weights(module, weights, weight_loading_mode) + + if get_sm_version() == 100: + transfromed_w3_w1_scale = transform_sf_into_required_layout( + module.quant_scales[0], + mn=module.w3_w1_weight.shape[1], + k=module.w3_w1_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w3_w1_weight_scaling_factor = nn.Parameter( + transfromed_w3_w1_scale, requires_grad=False) + transfromed_w2_scale = transform_sf_into_required_layout( + module.quant_scales[1], + mn=module.w2_weight.shape[1], + k=module.w2_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, + requires_grad=False) + self.setup_quant_scales(module) + + class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 3f45ae80651..aae8db42184 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -37,7 +37,8 @@ def __init__(self, config: Optional[ModelConfig] = None, overridden_tp_size: Optional[int] = None, reduce_output: bool = True, - layer_idx: Optional[int] = None): + layer_idx: Optional[int] = None, + use_trtllmgen_mm: bool = False): super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -74,7 +75,8 @@ def __init__(self, reduce_output=False, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_trtllmgen_mm=use_trtllmgen_mm) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -91,7 +93,8 @@ def __init__(self, skip_create_weights_in_init=config.skip_create_weights_in_init, lora=self.down_lora, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_trtllmgen_mm=use_trtllmgen_mm) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 9653a3530e5..95554f2f27f 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -572,7 +572,7 @@ def apply(self, module: Linear, input: torch.Tensor, input = input.to(torch.bfloat16) * module.input_scale assert input.dtype == torch.bfloat16 - if get_sm_version() == 100: + if get_sm_version() == 100 and not module.use_trtllmgen_mm: import deep_gemm a, a_sf = fp8_utils.per_token_quant_and_transform(input) output = torch.empty((input.shape[0], module.weight.shape[0]), @@ -1461,6 +1461,7 @@ def __init__( lora: Optional[LoraLayer] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, + use_trtllmgen_mm: bool = False, ): from ..distributed import AllReduce @@ -1477,6 +1478,7 @@ def __init__( self.tp_mode = tensor_parallel_mode self.gather_output = gather_output self.force_dynamic_quantization = force_dynamic_quantization + self.use_trtllmgen_mm = use_trtllmgen_mm local_in_features = in_features local_out_features = out_features diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 276ad131217..1816860d529 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -985,7 +985,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", @@ -1012,7 +1012,7 @@ def test_cute_dsl_fp8_block_scales( max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, - use_cuda_graph=cuda_graph, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, torch_compile_config=torch_compile_config, moe_config=MoeConfig(backend="CUTEDSL"), ) @@ -1139,7 +1139,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn, task.evaluate(llm) @pytest.mark.skip_less_device(4) - @skip_no_hopper + @skip_pre_blackwell @parametrize_with_ids("torch_compile", [False]) @parametrize_with_ids( "fp8kv,attention_dp,cuda_graph,overlap_scheduler", @@ -1176,7 +1176,7 @@ def test_cute_dsl_fp8_block_scales_4gpus( max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, - use_cuda_graph=cuda_graph, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, torch_compile_config=torch_compile_config, moe_config=MoeConfig(backend="CUTEDSL"), ) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 51a7758d281..5f83df6dd9e 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -14,8 +14,7 @@ from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from utils.util import (skip_neither_ada_nor_hopper_unittest, - skip_non_hopper_unittest, skip_pre_blackwell, - skip_pre_hopper) + skip_pre_blackwell, skip_pre_hopper) from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig @@ -559,7 +558,7 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) -@skip_non_hopper_unittest +@skip_pre_blackwell @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", product( @@ -570,12 +569,12 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, [DefaultMoeRoutingMethod], ), ) -def test_fused_moe_fp8_blockwise(dtype, - num_experts, - seq_len, - hidden_size, - RoutingMethodCls, - mapping=None): +def test_fused_moe_fp8_blockwise_cute_dsl(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + mapping=None): SEQ_LEN = seq_len HIDDEN_SIZE = hidden_size INTERMEDIATE_SIZE = 1536 @@ -647,18 +646,6 @@ def test_fused_moe_fp8_blockwise(dtype, fused_moe.cuda() fused_moe.load_weights([weights]) - fused_moe_origin = CutlassFusedMoE( - num_experts=NUM_EXPERTS, - routing_method=routing_method, - hidden_size=HIDDEN_SIZE, - intermediate_size=INTERMEDIATE_SIZE, - dtype=dtype, - reduce_results=True, - model_config=ModelConfig(quant_config=quant_config, mapping=mapping), - ) - fused_moe_origin.cuda() - fused_moe_origin.load_weights([weights]) - ref_fused_moe = RefGatedMLPFusedMoE( num_experts=NUM_EXPERTS, routing_method=routing_method, @@ -666,33 +653,32 @@ def test_fused_moe_fp8_blockwise(dtype, intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, model_config=ModelConfig(quant_config=quant_config), + # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here + use_trtllmgen_mm=True, ) ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda() with torch.inference_mode(): output = fused_moe.forward(x, router_logits) - output_origin = fused_moe_origin.forward(x, router_logits) ref_output = ref_fused_moe.forward(x, router_logits) # compare torch.cuda.synchronize() - torch.testing.assert_close(output_origin, output, rtol=1e-2, atol=0.1) - torch.testing.assert_close(output_origin, ref_output, rtol=1e-2, atol=0.1) torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) return True -@skip_non_hopper_unittest +@skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) -def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method): +def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method): world_size = 4 with MPIPoolExecutor(max_workers=world_size) as executor: results = executor.map( - test_fused_moe_fp8_blockwise, + test_fused_moe_fp8_blockwise_cute_dsl, *zip(*[( torch.bfloat16, 72, @@ -966,7 +952,8 @@ def __init__(self, hidden_size: int, intermediate_size: int, dtype: Optional[torch.dtype] = None, - model_config: ModelConfig = ModelConfig()): + model_config: ModelConfig = ModelConfig(), + use_trtllmgen_mm: bool = False): super().__init__() self.num_experts = num_experts self.routing_method = routing_method @@ -983,6 +970,7 @@ def __init__(self, bias=False, dtype=self.dtype, config=model_config, + use_trtllmgen_mm=use_trtllmgen_mm, ) for _ in range(self.num_experts) ]) From 66e4a7d2c047a289f20c3e482ade194a251426bf Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:13:17 -0700 Subject: [PATCH 2/8] rename. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/modules/gated_mlp.py | 6 ++-- tensorrt_llm/_torch/modules/linear.py | 34 ++++++++++++------- tensorrt_llm/evaluate/lm_eval.py | 3 +- .../unittest/_torch/modules/test_fused_moe.py | 6 ++-- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index aae8db42184..1991db843a4 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -38,7 +38,7 @@ def __init__(self, overridden_tp_size: Optional[int] = None, reduce_output: bool = True, layer_idx: Optional[int] = None, - use_trtllmgen_mm: bool = False): + use_cute_dsl_blockscaling_mm: bool = False): super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -76,7 +76,7 @@ def __init__(self, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - use_trtllmgen_mm=use_trtllmgen_mm) + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], [self.hidden_size]) @@ -94,7 +94,7 @@ def __init__(self, lora=self.down_lora, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, - use_trtllmgen_mm=use_trtllmgen_mm) + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 95554f2f27f..72afd6d9c84 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -572,22 +572,30 @@ def apply(self, module: Linear, input: torch.Tensor, input = input.to(torch.bfloat16) * module.input_scale assert input.dtype == torch.bfloat16 - if get_sm_version() == 100 and not module.use_trtllmgen_mm: - import deep_gemm - a, a_sf = fp8_utils.per_token_quant_and_transform(input) - output = torch.empty((input.shape[0], module.weight.shape[0]), - device=input.device, - dtype=torch.bfloat16) - deep_gemm.fp8_gemm_nt((a, a_sf), - (module.weight, module.weight_scale), - output, - disable_ue8m0_cast=True) + if get_sm_version() == 100: + if module.use_cute_dsl_blockscaling_mm: + # TODO (@lmin): replace with cute_dsl gemm + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + input) + output = torch.ops.trtllm.fp8_block_scaling_gemm( + act_input_fp8, module.weight, act_input_sf, + module.weight_scale) + else: + import deep_gemm + a, a_sf = fp8_utils.per_token_quant_and_transform(input) + output = torch.empty((input.shape[0], module.weight.shape[0]), + device=input.device, + dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt((a, a_sf), + (module.weight, module.weight_scale), + output, + disable_ue8m0_cast=True) else: act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( input) - output = torch.ops.trtllm.fp8_block_scaling_gemm( act_input_fp8, module.weight, act_input_sf, module.weight_scale) + if bias is not None: output = output + bias return output @@ -1461,7 +1469,7 @@ def __init__( lora: Optional[LoraLayer] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO, force_dynamic_quantization: bool = False, - use_trtllmgen_mm: bool = False, + use_cute_dsl_blockscaling_mm: bool = False, ): from ..distributed import AllReduce @@ -1478,7 +1486,7 @@ def __init__( self.tp_mode = tensor_parallel_mode self.gather_output = gather_output self.force_dynamic_quantization = force_dynamic_quantization - self.use_trtllmgen_mm = use_trtllmgen_mm + self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm local_in_features = in_features local_out_features = out_features diff --git a/tensorrt_llm/evaluate/lm_eval.py b/tensorrt_llm/evaluate/lm_eval.py index bdddbcbb736..6a24e07f79a 100644 --- a/tensorrt_llm/evaluate/lm_eval.py +++ b/tensorrt_llm/evaluate/lm_eval.py @@ -25,6 +25,7 @@ try: from lm_eval.api.model import TemplateLM + from lm_eval.tasks import TaskManager except ImportError: TemplateLM = object @@ -147,7 +148,7 @@ def __init__(self, self.dataset_path = dataset_path self.num_samples = num_samples - task_manager = lm_eval.tasks.TaskManager( + task_manager = TaskManager( include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks") with self._patch_lm_eval(): self.task_dict = lm_eval.tasks.get_task_dict( diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 5f83df6dd9e..dd3fbd978cd 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -654,7 +654,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype, dtype=dtype, model_config=ModelConfig(quant_config=quant_config), # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here - use_trtllmgen_mm=True, + use_cute_dsl_blockscaling_mm=True, ) ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda() @@ -953,7 +953,7 @@ def __init__(self, intermediate_size: int, dtype: Optional[torch.dtype] = None, model_config: ModelConfig = ModelConfig(), - use_trtllmgen_mm: bool = False): + use_cute_dsl_blockscaling_mm: bool = False): super().__init__() self.num_experts = num_experts self.routing_method = routing_method @@ -970,7 +970,7 @@ def __init__(self, bias=False, dtype=self.dtype, config=model_config, - use_trtllmgen_mm=use_trtllmgen_mm, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, ) for _ in range(self.num_experts) ]) From 9cd1f27b15c4c61c805bf72d2682034bff703527 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:58:24 -0700 Subject: [PATCH 3/8] remove print. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/quantization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index c9702a8aabd..04b920c5a59 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -558,7 +558,6 @@ class DeepSeekFP8BlockScalesFusedMoEMethod( def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): - print(f"DeepSeekFP8BlockScalesFusedMoEMethod load_weights") if get_sm_version() == 100: expert_ids = set(module.initial_local_expert_ids) if self.need_load_shared_weights(module): From d80cabec90697e70848d980a8369f6ac5a42deb6 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 5 Aug 2025 00:32:43 -0700 Subject: [PATCH 4/8] refactor moe loading logics. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 3 +++ .../modules/fused_moe/fused_moe_cute_dsl.py | 15 +-------------- .../modules/fused_moe/fused_moe_deepgemm.py | 15 ++++++++++++++- .../_torch/modules/fused_moe/quantization.py | 6 +++--- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 66ea5e3a0eb..9b6fb1d3cbc 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1344,6 +1344,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.named_modules()) + # moe_backend: cute_dsl_group_gemm + # use_cute_dsl_gemm, use_cute_dsl_bmm; use_cute_dsl + # attention/mla, gated_mlp, linear if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( ) and get_sm_version() == 100: for name in list(weights.keys()): diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 94ca1c294ea..815dae64766 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -9,8 +9,7 @@ from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE -from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl, - MoEWeightLoadingMode, UnquantizedFusedMoEMethod) +from .quantization import MoEWeightLoadingMode from .routing import BaseMoeRoutingMethod @@ -140,18 +139,6 @@ def __init__( layer_idx=layer_idx, ) - def _get_quant_method(self): - if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( - exclude_kv_cache=True): - if self.quant_config.layer_quant_mode.has_fp8_block_scales(): - return DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl() - else: - raise ValueError( - f"Unsupported quantization mode: {self.quant_config.quant_mode}" - ) - else: - return UnquantizedFusedMoEMethod() - def forward_chunk( self, x: Union[torch.Tensor, Fp4QuantizedTensor], diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 3721a5d2afd..9f07975ab8f 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -13,7 +13,8 @@ from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE -from .quantization import MoEWeightLoadingMode +from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, + MoEWeightLoadingMode, UnquantizedFusedMoEMethod) from .routing import BaseMoeRoutingMethod @@ -340,6 +341,18 @@ def __init__( layer_idx=layer_idx, ) + def _get_quant_method(self): + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( + exclude_kv_cache=True): + if self.quant_config.layer_quant_mode.has_fp8_block_scales(): + return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() + else: + raise ValueError( + f"Unsupported quantization mode: {self.quant_config.quant_mode}" + ) + else: + return UnquantizedFusedMoEMethod() + @nvtx_range("[DG] forward") def forward_chunk( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 04b920c5a59..0b653b32153 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -430,7 +430,7 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): module.fc31_input_dequant.data.copy_(max_fc31_input_scale) -class DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl(FusedMoEMethodBase): +class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): def create_weights(self, module: torch.nn.Module): weight_dtype = torch.float8_e4m3fn @@ -553,8 +553,8 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): }) -class DeepSeekFP8BlockScalesFusedMoEMethod( - DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl): +class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( + DeepSeekFP8BlockScalesFusedMoEMethod): def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): From 0a74c9961558504608e4a16280139d4e15ca067b Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 5 Aug 2025 00:35:34 -0700 Subject: [PATCH 5/8] minor. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 9b6fb1d3cbc..66ea5e3a0eb 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1344,9 +1344,6 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.named_modules()) - # moe_backend: cute_dsl_group_gemm - # use_cute_dsl_gemm, use_cute_dsl_bmm; use_cute_dsl - # attention/mla, gated_mlp, linear if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( ) and get_sm_version() == 100: for name in list(weights.keys()): From 99b56a7a0a82cef96bafadef3893e812c12467b8 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 5 Aug 2025 19:57:30 -0700 Subject: [PATCH 6/8] do not delete fused_moe_cutlass hopper test Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- .../unittest/_torch/modules/test_fused_moe.py | 137 +++++++++++++++++- 1 file changed, 136 insertions(+), 1 deletion(-) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 0298247a3ca..b305b68a3c6 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -14,7 +14,8 @@ from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from utils.util import (skip_neither_ada_nor_hopper_unittest, - skip_pre_blackwell, skip_pre_hopper) + skip_non_hopper_unittest, skip_pre_blackwell, + skip_pre_hopper) from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig @@ -693,6 +694,140 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype, return True +@skip_non_hopper_unittest +@pytest.mark.parametrize( + "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode", + product( + [torch.bfloat16], + [72], + [128, 256, 384, 512, 1024, 2048, 4096, 8192], + [2560], + [DefaultMoeRoutingMethod], + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ], + ), +) +def test_fused_moe_fp8_blockwise_cutlass(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + WeightLoadingMode, + mapping=None): + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 1536 + NUM_EXPERTS = num_experts + TOP_K = 6 + + routing_method = RoutingMethodCls(top_k=TOP_K) + + mapping = mapping or Mapping() + mapping.rank = mpi_rank() + torch.cuda.set_device(mapping.rank) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype, device="cuda") + # Note: we use some special values init x and weight, otherwise the test will false positive failed. + set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE) + + x = x.cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), + dtype=dtype, + device="cuda") + + weights = {} + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'] = {} + weights['down_proj'] = {} + weights['gate_up_proj_weight_scale'] = {} + weights['down_proj_weight_scale'] = {} + + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype, + device="cuda") + w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE), + dtype=dtype, + device="cuda") + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + if WeightLoadingMode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + weights['gate_up_proj'][expert_id] = torch.cat( + [w3_weight_fp8, w1_weight_fp8], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj'][expert_id] = w2_weight_fp8.transpose( + 0, 1).contiguous() + weights['gate_up_proj_weight_scale'][expert_id] = torch.cat( + [w3_weight_scale, w1_weight_scale], + dim=-2).transpose(0, 1).contiguous() + weights['down_proj_weight_scale'][ + expert_id] = w2_weight_scale.transpose(0, 1).contiguous() + elif WeightLoadingMode == MoEWeightLoadingMode.VANILLA: + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + fused_moe = CutlassFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(quant_config=quant_config, mapping=mapping), + weight_loading_mode=WeightLoadingMode, + ) + fused_moe.cuda() + fused_moe.load_weights([weights]) + + ref_fused_moe = RefGatedMLPFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + model_config=ModelConfig(quant_config=quant_config), + # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here + use_cute_dsl_blockscaling_mm=True, + ) + ref_fused_moe.load_weights([weights]) + ref_fused_moe.cuda() + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + ref_output = ref_fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + return True + + @skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") From b59a9adfadf1288801917cd8eff42de4631b66fa Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 5 Aug 2025 20:08:11 -0700 Subject: [PATCH 7/8] recover fused_moe_cutlass test on hopper. Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- .../unittest/_torch/modules/test_fused_moe.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index b305b68a3c6..d6cc9853d89 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -828,12 +828,49 @@ def test_fused_moe_fp8_blockwise_cutlass(dtype, return True +@skip_non_hopper_unittest +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="needs 4 GPUs to run this test") +@pytest.mark.parametrize("ep_size", [1, 2, 4]) +@pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cutlass_multi_gpu(ep_size, routing_method, + weight_loading_mode): + world_size = 4 + with MPIPoolExecutor(max_workers=world_size) as executor: + results = executor.map( + test_fused_moe_fp8_blockwise_cutlass, + *zip(*[( + torch.bfloat16, + 72, + 384, + 384, + routing_method, + weight_loading_mode, + Mapping( + world_size=world_size, + tp_size=world_size, + moe_ep_size=ep_size, + moe_tp_size=world_size // ep_size, + ), + )] * world_size), + ) + for r in results: + assert r is True + + @skip_pre_blackwell @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="needs 4 GPUs to run this test") @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod]) -def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method): +@pytest.mark.parametrize( + "weight_loading_mode", + [MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ]) +def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method, + weight_loading_mode): world_size = 4 with MPIPoolExecutor(max_workers=world_size) as executor: results = executor.map( @@ -844,7 +881,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method): 384, 384, routing_method, - MoEWeightLoadingMode.VANILLA, + weight_loading_mode, Mapping( world_size=world_size, tp_size=world_size, From eddf7eac7aa2d7767359bf69e734686bc6afbab2 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 5 Aug 2025 20:20:49 -0700 Subject: [PATCH 8/8] minor Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com> --- tests/unittest/_torch/modules/test_fused_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index d6cc9853d89..6c7c408d616 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -812,8 +812,6 @@ def test_fused_moe_fp8_blockwise_cutlass(dtype, intermediate_size=INTERMEDIATE_SIZE, dtype=dtype, model_config=ModelConfig(quant_config=quant_config), - # Note: use deepgemm mm will cause accuracy error, so we use trtllmgen mm here - use_cute_dsl_blockscaling_mm=True, ) ref_fused_moe.load_weights([weights]) ref_fused_moe.cuda()