Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,13 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]):
super().__init__(model_config)
config = self.model_config
self.aux_stream = torch.cuda.Stream()
self.preload_weight_modules = []
if config.moe_backend == "TRTLLM":
self.preload_weight_modules = [
"experts",
"routing_method",
"all_reduce",
]

if model_config.mapping.enable_attention_dp:
# When attention_dp is enabled, we cannot do all_reduce since
Expand Down Expand Up @@ -381,6 +388,7 @@ def __init__(
Qwen3MoEModel(model_config),
model_config,
)
self.preload_weight_modules = self.model.preload_weight_modules

def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
super().load_weights(weights, weight_mapper)
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
skip_modules: List[str] = [],
params_map: Optional[Dict[str, str]] = None,
preload_weight_modules: Optional[List[str]] = None):
# TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 model loading where
# TODO: remove preload_weight_modules - it is a workaround for min-latency llama4 and Qwen3 model loading where
# we need some order in the module loading. Once this is resolved, we can remove this workaround.
weight_mapper.add_skip_modules(skip_modules)
if params_map is not None:
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ Qwen3/Qwen3-30B-A3B:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 83.43
- spec_dec_algo: Eagle
accuracy: 83.43
Qwen3/Qwen3-235B-A22B:
- quant_algo: FP8
kv_cache_quant_algo: FP8
Expand Down
29 changes: 25 additions & 4 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,31 @@ def test_nvfp4(
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

def test_eagle3(self):
pytorch_config = dict(
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 3, 4, 8]),
)
kv_cache_config = KvCacheConfig(enable_block_reuse=False)

eagle_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-eagle3"
target_model_dir = f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B"

draft_len = 1
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=True)

llm = LLM(model=target_model_dir,
**pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
max_seq_len=8192)

with llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


class TestQwen3_32B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-32B"
Expand Down Expand Up @@ -1822,10 +1847,6 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
)
def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler, moe_backend):
if moe_backend == "TRTLLM":
pytest.skip(
"TRTLLM moe backend has accuracy issues: https://nvbugspro.nvidia.com/bug/5404726"
)

pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
Expand Down
3 changes: 0 additions & 3 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,6 @@ examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp
test_e2e.py::test_openai_multinodes_chat_tp16pp1 SKIP (https://nvbugs/5112075)
examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5322488)
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
full:B200/accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5401163)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
Expand Down Expand Up @@ -424,8 +423,6 @@ triton_server/test_triton_llm.py::test_llava_onevision[test_video-False-1---Fals
triton_server/test_triton.py::test_cpp_unit_tests[cpp-unit-tests] SKIP (https://nvbugs/5401088)
accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype SKIP (https://nvbugs/5401114)
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] SKIP (https://nvbugs/5401114)
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm] SKIP (https://nvbugs/5401163)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5401233)
triton_server/test_triton_llm.py::test_gpt_disaggregated_serving_bls[test_basic-False-1-top_k_top_p--False-True-True-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-max_utilization---1-1-1-True-tensorrt_llm_bls] SKIP (https://nvbugs/5401261)
triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregated-serving-bls] SKIP (https://nvbugs/5401261)
Expand Down