From 3c496cfa43f34066ba5bd9621ae9736527fdb203 Mon Sep 17 00:00:00 2001
From: bhsueh <11360707+byshiue@users.noreply.github.com>
Date: Mon, 21 Jul 2025 15:00:53 -0700
Subject: [PATCH 1/3] add some modules to the serialized loading list for
 Qwen3 to prevent a bug on tp > 1 with the TRTLLM MoE backend

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_qwen3_moe.py | 8 ++++++++
 tensorrt_llm/_torch/models/modeling_utils.py     | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index 4d1210fc93f..2d447dd527b 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -309,6 +309,13 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
         super().__init__(model_config)
         config = self.model_config
         self.aux_stream = torch.cuda.Stream()
+        self.preload_weight_modules = []
+        if config.moe_backend == "TRTLLM":
+            self.preload_weight_modules = [
+                "experts",
+                "routing_method",
+                "all_reduce",
+            ]
 
         if model_config.mapping.enable_attention_dp:
             # When attention_dp is enabled, we cannot do all_reduce since
@@ -381,6 +388,7 @@ def __init__(
             Qwen3MoEModel(model_config),
             model_config,
         )
+        self.preload_weight_modules = self.model.preload_weight_modules
 
     def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
         super().load_weights(weights, weight_mapper)
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index c751bdcbb01..fe0005398d7 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -863,7 +863,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
                           skip_modules: List[str] = [],
                           params_map: Optional[Dict[str, str]] = None,
                           preload_weight_modules: Optional[List[str]] = None):
-    # TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 model loading where
+    # TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 and Qwen3 model loading where
     # we need some order in the module loading. Once this is resolved, we can remove this workaround.
     weight_mapper.add_skip_modules(skip_modules)
     if params_map is not None:
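The list registered above is passed through to _load_weights_impl_v2 as preload_weight_modules, which loads the named submodules before the rest of the model. A minimal sketch of the ordering idea, assuming a hypothetical order_weights_for_loading helper (the actual logic lives in _load_weights_impl_v2):

from typing import Dict, List


def order_weights_for_loading(weights: Dict[str, object],
                              preload_weight_modules: List[str]) -> List[str]:
    """Return weight names with preload-listed modules first, order preserved.

    Hypothetical helper for illustration; it mirrors the intent of the
    preload_weight_modules workaround, not TensorRT-LLM's exact code.
    """
    preload = [name for name in weights
               if any(m in name for m in preload_weight_modules)]
    rest = [name for name in weights if name not in preload]
    return preload + rest


# With the list added in the patch, expert weights sort ahead of attention
# weights, so the TRTLLM MoE backend sees them materialized first.
names = order_weights_for_loading(
    {
        "model.layers.0.self_attn.qkv_proj.weight": None,
        "model.layers.0.mlp.experts.w3_w1_weight": None,
    },
    preload_weight_modules=["experts", "routing_method", "all_reduce"],
)
assert names == ["model.layers.0.mlp.experts.w3_w1_weight",
                 "model.layers.0.self_attn.qkv_proj.weight"]

Loading the expert, routing, and all-reduce weights first appears to be what avoids the tp > 1 failure the commit message describes.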
From 2ed4e243b461de0567860c31702a4d5ed1b6012e Mon Sep 17 00:00:00 2001
From: bhsueh <11360707+byshiue@users.noreply.github.com>
Date: Mon, 21 Jul 2025 15:02:45 -0700
Subject: [PATCH 2/3] unwaive related CI tests

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
---
 tests/integration/test_lists/waives.txt | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 346aab5adf5..08b9c808254 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -393,7 +393,6 @@ examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp
 test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
 examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
 accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
-full:B200/accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5401163)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
 examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
@@ -424,8 +423,6 @@ triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---Fals
 triton_server/test_triton.py::test_cpp_unit_tests[cpp-unit-tests] SKIP (https://nvbugs/5401088)
 accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype SKIP (https://nvbugs/5401114)
 test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] SKIP (https://nvbugs/5401114)
-accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
-accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
 examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
 triton_server/test_triton_llm.py::test_gpt_disaggregated_serving_bls[test_basic-False-1-top_k_top_p--False-True-True-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-max_utilization---1-1-1-True-tensorrt_llm_bls] SKIP (https://nvbugs/5401261)
 triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregated-serving-bls] SKIP (https://nvbugs/5401261)

From 1966e9941b3c89a9c4c6b9cb1ed61b1fc348a85b Mon Sep 17 00:00:00 2001
From: bhsueh <11360707+byshiue@users.noreply.github.com>
Date: Tue, 22 Jul 2025 20:19:22 +0000
Subject: [PATCH 3/3] add qwen3 30b eagle3 test

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
---
 .../defs/accuracy/references/gsm8k.yaml       |  2 ++
 .../defs/accuracy/test_llm_api_pytorch.py     | 29 ++++++++++++++++---
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 41dce7f1837..850f27389b8 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -77,6 +77,8 @@ Qwen3/Qwen3-30B-A3B:
   - quant_algo: NVFP4
     kv_cache_quant_algo: FP8
     accuracy: 83.43
+  - spec_dec_algo: Eagle
+    accuracy: 83.43
 Qwen3/Qwen3-235B-A22B:
   - quant_algo: FP8
     kv_cache_quant_algo: FP8
     accuracy: 85.97
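The spec_dec_algo: Eagle entry above supplies the reference score the accuracy harness checks when the new Eagle3 test in the next file runs. As a rough illustration only, with a hypothetical pick_reference helper standing in for the harness's real lookup code, entry selection can work like this:

from typing import List, Optional


def pick_reference(entries: List[dict],
                   spec_dec_algo: Optional[str] = None,
                   quant_algo: Optional[str] = None) -> dict:
    """Pick the first reference entry whose optional keys match the run.

    Hypothetical sketch; the actual matching lives in the accuracy harness.
    """
    for entry in entries:
        if (entry.get("spec_dec_algo") == spec_dec_algo
                and entry.get("quant_algo") == quant_algo):
            return entry
    raise KeyError("no matching accuracy reference")


entries = [
    {"quant_algo": "NVFP4", "kv_cache_quant_algo": "FP8", "accuracy": 83.43},
    {"spec_dec_algo": "Eagle", "accuracy": 83.43},
]
# An Eagle speculative-decoding run resolves to the new reference entry.
assert pick_reference(entries, spec_dec_algo="Eagle")["accuracy"] == 83.43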
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index fb46cd337e8..20409478704 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1756,6 +1756,31 @@ def test_nvfp4(
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    def test_eagle3(self):
+        pytorch_config = dict(
+            disable_overlap_scheduler=True,
+            cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 3, 4, 8]),
+        )
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False)
+
+        eagle_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-eagle3"
+        target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B"
+
+        draft_len = 1
+        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
+                                          speculative_model_dir=eagle_model_dir,
+                                          eagle3_one_model=True)
+
+        llm = LLM(model=target_model_dir,
+                  **pytorch_config,
+                  kv_cache_config=kv_cache_config,
+                  speculative_config=spec_config,
+                  max_seq_len=8192)
+
+        with llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestQwen3_32B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "Qwen3/Qwen3-32B"
@@ -1822,10 +1847,6 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
     )
     def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
                    overlap_scheduler, moe_backend):
-        if moe_backend == "TRTLLM":
-            pytest.skip(
-                "TRTLLM moe backend has accuracy issues: https://nvbugspro.nvidia.com/bug/5404726"
-            )
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,