Skip to content

Commit 9f590c7

Browse files
[Bugfix] Fix the bug that qwen3 moe doesn't work with aclgraph (#2478)
### What this PR does / why we need it? What this PR does: 1. Move AscendSparseMoeBlock to the qwen3 model, since it is only used by the qwen3 model. 2. Disable AscendSparseMoeBlock when aclgraph is enabled, because AscendSparseMoeBlock does not currently work with aclgraph. --------- Signed-off-by: shen-shanshan <[email protected]>
1 parent f64208b commit 9f590c7

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

tests/multicard/test_qwen3_moe.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# Copyright 2023 The vLLM team.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# This file is a part of the vllm-ascend project.
17+
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
18+
#
19+
"""Compare the short outputs of HF and vLLM when using greedy sampling.
20+
21+
Run `pytest tests/multicard/test_qwen3_moe.py`.
22+
"""
23+
24+
from tests.conftest import VllmRunner
25+
26+
27+
def test_models_distributed_Qwen3_MOE_TP2():
    """Smoke-test multicard greedy generation for Qwen3-30B-A3B (MoE).

    Spins up the model with tensor parallelism and the multiprocessing
    distributed executor, then runs a short greedy decode to verify the
    MoE path works end to end (regression check for the aclgraph fix).

    NOTE(review): the test name says TP2 but ``tensor_parallel_size`` is
    4 — confirm which value is intended and rename or adjust accordingly.
    The executed behavior (TP=4) is preserved here.
    """
    example_prompts = [
        "Hello, my name is",
    ]
    # half precision and a tiny token budget keep the smoke test cheap
    dtype = "half"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=4,
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
40+
41+
42+
def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
    """Smoke-test multicard greedy generation for Qwen3-30B-A3B with EP.

    Same as the plain TP smoke test, but with expert parallelism enabled
    so the expert-parallel MoE dispatch path is exercised as well.

    NOTE(review): the test name says TP2 but ``tensor_parallel_size`` is
    4 — confirm which value is intended and rename or adjust accordingly.
    The executed behavior (TP=4, EP on) is preserved here.
    """
    example_prompts = [
        "Hello, my name is",
    ]
    # half precision and a tiny token budget keep the smoke test cheap
    dtype = "half"
    max_tokens = 5
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            dtype=dtype,
            tensor_parallel_size=4,
            enable_expert_parallel=True,
            distributed_executor_backend="mp",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/models/qwen3_moe.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch import nn
2323
from transformers import PretrainedConfig
2424
from vllm.compilation.decorators import support_torch_compile
25-
from vllm.config import CacheConfig, VllmConfig
25+
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
2626
from vllm.distributed import get_pp_group
2727
from vllm.model_executor.layers.layernorm import RMSNorm
2828
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -32,7 +32,8 @@
3232
from vllm.model_executor.models.interfaces import SupportsPP
3333
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
3434
Qwen3MoeForCausalLM,
35-
Qwen3MoeMLP, Qwen3MoeModel)
35+
Qwen3MoeMLP, Qwen3MoeModel,
36+
Qwen3MoeSparseMoeBlock)
3637
from vllm.model_executor.models.utils import (
3738
extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
3839
maybe_prefix)
@@ -78,12 +79,21 @@ def __init__(
7879
layer_idx = extract_layer_index(prefix)
7980
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
8081
config.mlp_only_layers)
82+
use_aclgraph = (vllm_config is not None
83+
and vllm_config.compilation_config.level
84+
== CompilationLevel.PIECEWISE
85+
and not vllm_config.model_config.enforce_eager)
8186
if (layer_idx not in mlp_only_layers) and (
8287
config.num_experts > 0 and
8388
(layer_idx + 1) % config.decoder_sparse_step == 0):
84-
self.mlp = AscendSparseMoeBlock(config=config,
85-
quant_config=quant_config,
86-
prefix=f"{prefix}.mlp")
89+
if not use_aclgraph:
90+
self.mlp = AscendSparseMoeBlock(config=config,
91+
quant_config=quant_config,
92+
prefix=f"{prefix}.mlp")
93+
else:
94+
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
95+
quant_config=quant_config,
96+
prefix=f"{prefix}.mlp")
8797
else:
8898
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
8999
intermediate_size=config.intermediate_size,

0 commit comments

Comments
 (0)