
Commit 541a62e

byshiue authored and lancelly committed
[Fix][nvbug 5401163][nvbug 5404726][Qwen3] Fix bug of MoE on tp > 1 with trtllm moe backend (NVIDIA#6235)
Signed-off-by: bhsueh <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent cb7f7f3 commit 541a62e

File tree: 5 files changed, +36 -8 lines changed

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 8 additions & 0 deletions
@@ -309,6 +309,13 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
         super().__init__(model_config)
         config = self.model_config
         self.aux_stream = torch.cuda.Stream()
+        self.preload_weight_modules = []
+        if config.moe_backend == "TRTLLM":
+            self.preload_weight_modules = [
+                "experts",
+                "routing_method",
+                "all_reduce",
+            ]
 
         if model_config.mapping.enable_attention_dp:
             # When attention_dp is enabled, we cannot do all_reduce since
@@ -381,6 +388,7 @@ def __init__(
             Qwen3MoEModel(model_config),
             model_config,
         )
+        self.preload_weight_modules = self.model.preload_weight_modules
 
     def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
         super().load_weights(weights, weight_mapper)
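
To make the purpose of preload_weight_modules concrete, here is a minimal, hypothetical sketch of how a weight loader can consume such a list to load the named submodules before everything else; the helper name and loop below are illustrative only and are not the actual _load_weights_impl_v2 implementation (see the TODO in modeling_utils.py below, which describes this ordering workaround):

# Hypothetical helper, not TensorRT-LLM code: load submodules whose names match
# an entry in preload_weight_modules first, then load the remaining modules.
def load_in_preload_order(named_modules, load_fn, preload_weight_modules=None):
    preload_keys = preload_weight_modules or []
    deferred = []
    for name, module in named_modules:
        if any(key in name for key in preload_keys):
            # e.g. "experts", "routing_method", "all_reduce" are loaded in this first pass
            load_fn(name, module)
        else:
            deferred.append((name, module))
    for name, module in deferred:
        # everything else is loaded afterwards, preserving the required ordering
        load_fn(name, module)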

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -865,7 +865,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
                           skip_modules: List[str] = [],
                           params_map: Optional[Dict[str, str]] = None,
                           preload_weight_modules: Optional[List[str]] = None):
-    # TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 model loading where
+    # TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 and Qwen3 model loading where
     # we need some order in the module loading. Once this is resolved, we can remove this workaround.
     weight_mapper.add_skip_modules(skip_modules)
     if params_map is not None:

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,8 @@ Qwen3/Qwen3-30B-A3B:
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
     accuracy: 83.43
+  - spec_dec_algo: Eagle
+    accuracy: 83.43
 Qwen3/Qwen3-235B-A22B:
   - quant_algo: FP8
     kv_cache_quant_algo: FP8

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 25 additions & 4 deletions
@@ -1756,6 +1756,31 @@ def test_nvfp4(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 3, 4, 8]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B"
+
+        draft_len = 1
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir,
+                                          eagle3_one_model=True)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  max_seq_len=8192)
+
+        with llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_32B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-32B"
@@ -1822,10 +1847,6 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
     )
     def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
                    overlap_scheduler, moe_backend):
-        if moe_backend == "TRTLLM":
-            pytest.skip(
-                "TRTLLM moe backend has accuracy issues: https://nvbugspro.nvidia.com/bug/5404726"
-            )
 
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 3 deletions
@@ -391,7 +391,6 @@ examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp
 test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
 examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
 accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
-full:B200/accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5401163)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
@@ -422,8 +421,6 @@ triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---Fals
 triton_server/test_triton.py::test_cpp_unit_tests[cpp-unit-tests] SKIP (https://nvbugs/5401088)
 accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype SKIP (https://nvbugs/5401114)
 test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] SKIP (https://nvbugs/5401114)
-accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
-accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5401156)
