Merged
15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -9,7 +9,8 @@
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl,
MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
from .routing import BaseMoeRoutingMethod


@@ -139,6 +140,18 @@ def __init__(
layer_idx=layer_idx,
)

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl()
else:
raise ValueError(
f"Unsupported quantization mode: {self.quant_config.quant_mode}"
)
else:
return UnquantizedFusedMoEMethod()

def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
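The new `_get_quant_method` override is what lets the CuteDSL MoE backend pick its own FP8 block-scale weight-loading method. Below is a minimal, self-contained sketch of the same selection pattern; the config and method classes are simplified stand-ins, not the real TensorRT-LLM types.

```python
# Illustrative only: stand-in types mimicking the selection logic above.
from dataclasses import dataclass
from typing import Optional


class UnquantizedFusedMoEMethodStub:
    """Stands in for UnquantizedFusedMoEMethod."""


class CuteDslFp8BlockScalesMethodStub:
    """Stands in for DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl."""


@dataclass
class QuantConfigStub:
    has_any_quant: bool = False
    has_fp8_block_scales: bool = False


def get_quant_method(quant_config: Optional[QuantConfigStub]):
    # No quantization (or KV-cache-only quantization) -> unquantized path.
    if quant_config is None or not quant_config.has_any_quant:
        return UnquantizedFusedMoEMethodStub()
    # FP8 block scales is the only quantized mode the CUTEDSL backend accepts.
    if quant_config.has_fp8_block_scales:
        return CuteDslFp8BlockScalesMethodStub()
    raise ValueError(f"Unsupported quantization mode: {quant_config}")


assert isinstance(get_quant_method(None), UnquantizedFusedMoEMethodStub)
assert isinstance(get_quant_method(QuantConfigStub(True, True)),
                  CuteDslFp8BlockScalesMethodStub)
```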
84 changes: 46 additions & 38 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -430,7 +430,7 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
module.fc31_input_dequant.data.copy_(max_fc31_input_scale)


class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
class DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl(FusedMoEMethodBase):

def create_weights(self, module: torch.nn.Module):
weight_dtype = torch.float8_e4m3fn
@@ -468,45 +468,8 @@ def create_weights(self, module: torch.nn.Module):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):

if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transfromed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)

def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -590,6 +553,51 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
})


class DeepSeekFP8BlockScalesFusedMoEMethod(
DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl):

def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
logger.debug("DeepSeekFP8BlockScalesFusedMoEMethod load_weights")
if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)

if get_sm_version() == 100:
transformed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transformed_w3_w1_scale, requires_grad=False)
transformed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transformed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)


class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):

def create_weights(self, module: torch.nn.Module):
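The net effect of this refactor: the renamed `DeepSeekFP8BlockScalesFusedMoEMethodCuteDsl` keeps the plain block-scale load path, while `DeepSeekFP8BlockScalesFusedMoEMethod` subclasses it and wraps `load_weights` with the SM100-only resmoothing and scale-layout transform. A rough sketch of that wrapping pattern, with simplified stand-in names:

```python
# Illustrative stand-ins for the class relationship above; the real methods
# operate on module weights and scaling-factor parameters.
def get_sm_version() -> int:
    # Stand-in for the real helper; 100 corresponds to Blackwell (SM100).
    return 100


class BlockScalesMethodCuteDsl:

    def load_weights(self, module, weights):
        # Plain FP8 block-scale weight loading (details elided).
        module.weights_loaded = True


class BlockScalesMethod(BlockScalesMethodCuteDsl):

    def load_weights(self, module, weights):
        if get_sm_version() == 100:
            # 1) Resmooth each weight/scale pair (resmooth_to_fp8_e8m0)
            #    before the base class copies them into the module.
            pass
        super().load_weights(module, weights)
        if get_sm_version() == 100:
            # 2) Re-layout the scaling factors with the (1, 128, 128) recipe
            #    (transform_sf_into_required_layout), re-register them, and
            #    rebuild module.quant_scales from the new parameters.
            pass
```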
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/gated_mlp.py
@@ -37,7 +37,8 @@ def __init__(self,
config: Optional[ModelConfig] = None,
overridden_tp_size: Optional[int] = None,
reduce_output: bool = True,
layer_idx: Optional[int] = None):
layer_idx: Optional[int] = None,
use_trtllmgen_mm: bool = False):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@@ -74,7 +75,8 @@ def __init__(self,
reduce_output=False,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_trtllmgen_mm=use_trtllmgen_mm)

self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
@@ -91,7 +93,8 @@ def __init__(self,
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_trtllmgen_mm=use_trtllmgen_mm)

# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
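`GatedMLP` itself only forwards the new `use_trtllmgen_mm` flag into the `Linear` layers it creates. A compact sketch of that plumbing pattern, using toy modules rather than the real `GatedMLP`/`Linear` classes:

```python
# Illustrative plumbing only; TinyLinear/TinyGatedMLP are not TensorRT-LLM APIs.
import torch
from torch import nn


class TinyLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int,
                 use_trtllmgen_mm: bool = False):
        super().__init__()
        # The flag is stored on the layer; the quantization method reads it
        # later to decide which GEMM kernel to call.
        self.use_trtllmgen_mm = use_trtllmgen_mm
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.t()


class TinyGatedMLP(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int,
                 use_trtllmgen_mm: bool = False):
        super().__init__()
        # Forward the flag to both projections the MLP owns, mirroring how
        # GatedMLP passes it to its gate-up and down linear layers above.
        self.gate_up = TinyLinear(hidden_size, 2 * intermediate_size,
                                  use_trtllmgen_mm=use_trtllmgen_mm)
        self.down = TinyLinear(intermediate_size, hidden_size,
                               use_trtllmgen_mm=use_trtllmgen_mm)
```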
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/modules/linear.py
@@ -572,7 +572,7 @@ def apply(self, module: Linear, input: torch.Tensor,
input = input.to(torch.bfloat16) * module.input_scale
assert input.dtype == torch.bfloat16

if get_sm_version() == 100:
if get_sm_version() == 100 and not module.use_trtllmgen_mm:
import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
@@ -1461,6 +1461,7 @@ def __init__(
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_trtllmgen_mm: bool = False,
):
from ..distributed import AllReduce

@@ -1477,6 +1478,7 @@
self.tp_mode = tensor_parallel_mode
self.gather_output = gather_output
self.force_dynamic_quantization = force_dynamic_quantization
self.use_trtllmgen_mm = use_trtllmgen_mm

local_in_features = in_features
local_out_features = out_features
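In the FP8 block-scale `Linear` path, the SM100 DeepGEMM branch is now skipped when the layer was built with `use_trtllmgen_mm=True`, so such layers fall through to the trtllm-gen path. A tiny sketch of the dispatch predicate (the backend names here are just labels for illustration):

```python
# Illustrative dispatch helper; not part of tensorrt_llm.
def pick_fp8_block_scale_gemm(sm_version: int, use_trtllmgen_mm: bool) -> str:
    # DeepGEMM is used on SM100 only when the layer has not opted into the
    # trtllm-gen matmul path.
    if sm_version == 100 and not use_trtllmgen_mm:
        return "deep_gemm"
    return "trtllm_gen_or_fallback"


assert pick_fp8_block_scale_gemm(100, False) == "deep_gemm"
assert pick_fp8_block_scale_gemm(100, True) == "trtllm_gen_or_fallback"
assert pick_fp8_block_scale_gemm(90, False) == "trtllm_gen_or_fallback"
```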
8 changes: 4 additions & 4 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -985,7 +985,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1012,7 +1012,7 @@ def test_cute_dsl_fp8_block_scales(
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
@@ -1139,7 +1139,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_no_hopper
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False])
@parametrize_with_ids(
"fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1176,7 +1176,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
)
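Besides retargeting the CUTEDSL accuracy tests from Hopper-only to Blackwell-and-newer, the diff switches them from the removed `use_cuda_graph` flag to `cuda_graph_config`. A hedged sketch of the resulting config dict follows; it assumes `CudaGraphConfig` and `MoeConfig` are exported by `tensorrt_llm.llmapi`, as this test file already imports them (the exact import path may differ).

```python
# Sketch only: the import path below is an assumption.
from tensorrt_llm.llmapi import CudaGraphConfig, MoeConfig

cuda_graph = True
overlap_scheduler = True

pytorch_config = dict(
    disable_overlap_scheduler=not overlap_scheduler,
    # cuda_graph_config replaces the removed use_cuda_graph boolean:
    # pass a CudaGraphConfig() to enable graphs, or None to disable them.
    cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
    moe_config=MoeConfig(backend="CUTEDSL"),
)
```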
44 changes: 16 additions & 28 deletions tests/unittest/_torch/modules/test_fused_moe.py
@@ -14,8 +14,7 @@
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from utils.util import (skip_neither_ada_nor_hopper_unittest,
skip_non_hopper_unittest, skip_pre_blackwell,
skip_pre_hopper)
skip_pre_blackwell, skip_pre_hopper)

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
@@ -559,7 +558,7 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


@skip_non_hopper_unittest
@skip_pre_blackwell
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
product(
@@ -570,12 +569,12 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
[DefaultMoeRoutingMethod],
),
)
def test_fused_moe_fp8_blockwise(dtype,
num_experts,
seq_len,
hidden_size,
RoutingMethodCls,
mapping=None):
def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
num_experts,
seq_len,
hidden_size,
RoutingMethodCls,
mapping=None):
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
INTERMEDIATE_SIZE = 1536
@@ -647,52 +646,39 @@ def test_fused_moe_fp8_blockwise(dtype,
fused_moe.cuda()
fused_moe.load_weights([weights])

fused_moe_origin = CutlassFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
)
fused_moe_origin.cuda()
fused_moe_origin.load_weights([weights])

ref_fused_moe = RefGatedMLPFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
model_config=ModelConfig(quant_config=quant_config),
# Note: the DeepGEMM MM path causes accuracy errors in this reference, so we use the trtllm-gen MM path here
use_trtllmgen_mm=True,
)
ref_fused_moe.load_weights([weights])
ref_fused_moe.cuda()

with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
output_origin = fused_moe_origin.forward(x, router_logits)
ref_output = ref_fused_moe.forward(x, router_logits)

# compare
torch.cuda.synchronize()
torch.testing.assert_close(output_origin, output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output_origin, ref_output, rtol=1e-2, atol=0.1)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
return True


@skip_non_hopper_unittest
@skip_pre_blackwell
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod])
def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method):
def test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu(ep_size, routing_method):
world_size = 4
with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(
test_fused_moe_fp8_blockwise,
test_fused_moe_fp8_blockwise_cute_dsl,
*zip(*[(
torch.bfloat16,
72,
@@ -966,7 +952,8 @@ def __init__(self,
hidden_size: int,
intermediate_size: int,
dtype: Optional[torch.dtype] = None,
model_config: ModelConfig = ModelConfig()):
model_config: ModelConfig = ModelConfig(),
use_trtllmgen_mm: bool = False):
super().__init__()
self.num_experts = num_experts
self.routing_method = routing_method
@@ -983,6 +970,7 @@ def __init__(self,
bias=False,
dtype=self.dtype,
config=model_config,
use_trtllmgen_mm=use_trtllmgen_mm,
) for _ in range(self.num_experts)
])

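The renamed unit test now checks the CuteDSL fused MoE only against the `RefGatedMLPFusedMoE` reference built with `use_trtllmgen_mm=True`, and the multi-GPU variant fans that single-GPU test out over MPI ranks. A minimal sketch of that fan-out pattern, with a trivial stand-in in place of the real per-rank test:

```python
# Illustrative MPI fan-out; requires mpi4py and an MPI launch environment.
from mpi4py.futures import MPIPoolExecutor


def per_rank_check(dtype_name: str, num_experts: int, ep_size: int) -> bool:
    # The real test builds and compares the fused MoE on each rank; here we
    # only verify the experts divide evenly across the expert-parallel group.
    return num_experts % ep_size == 0


def run_multi_gpu(ep_size: int, world_size: int = 4) -> None:
    with MPIPoolExecutor(max_workers=world_size) as executor:
        # One identical argument tuple per rank, unzipped into positional
        # argument iterables for executor.map, matching the test above.
        results = executor.map(
            per_rank_check,
            *zip(*[("bfloat16", 72, ep_size) for _ in range(world_size)]))
        for ok in results:
            assert ok
```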