Commit 56ac826

init qwen torchair graph mode
Signed-off-by: taoyuxiang <[email protected]>
1 parent cfe91da commit 56ac826

8 files changed: +1029 / -5 lines changed


tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 62 additions & 0 deletions
@@ -162,3 +162,65 @@ def test_e2e_pangu_with_torchair():
         },
     }
     _pangu_torchair_test_fixture(additional_config)
+
+
+def _qwen_torchair_test_fixture(
+    model,
+    tp,
+    enable_expert_parallel,
+):
+    # The current access control does not support 16 cards,
+    # so the MC2 operator in Qwen's graph mode cannot run.
+    # Once 16-card support is available,
+    # this e2e can be switched to graph mode.
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": False,
+        },
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+
+    with VllmRunner(
+            model,
+            dtype="half",
+            tensor_parallel_size=tp,
+            distributed_executor_backend="mp",
+            enforce_eager=True,
+            additional_config=additional_config,
+            enable_expert_parallel=enable_expert_parallel,
+    ) as vllm_model:
+        # use a greedy sampler to make sure the generated results are deterministic
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
+    # with 2 hidden layers, thus the golden results seem inaccurate.
+    # This will only change if accuracy changes with the official weights
+    # of PanguProMoE.
+    golden_results = [
+        'Hello, my name is Remempondeprecatedmiot忱',
+        'The president of the United States is Remem下的一个 rever ceremoni Segnali',
+        'The capital of France is Rememvoud administrativ Remem投',
+        'The future of AI isotope Segnali Zoeken精细化 supus',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
+def test_e2e_qwen2_with_torchair():
+    _qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)
+
+
+def test_e2e_qwen3_moe_with_torchair():
+    _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
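
The fixture above deliberately runs eager (torchair_graph_config disabled, enforce_eager=True): per the in-code comment, the MC2 operator that Qwen's graph mode relies on cannot run until 16-card access control is available. Once it is, switching this e2e to graph mode should mostly be a matter of flipping the config, roughly as sketched below (an assumption based on that comment, not part of this commit; note that check_ascend_config in this same commit only takes the torchair branch when eager mode is not enforced, so enforce_eager=True would have to be dropped as well):

    # Hypothetical follow-up once 16-card support lands (not in this commit):
    additional_config = {
        "torchair_graph_config": {
            "enabled": True,  # run the Qwen model through the torchair graph
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
        "refresh": True,
    }
    # ...and VllmRunner(..., enforce_eager=False, ...) so the graph path is taken.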

tests/ut/models/test_qwen3_moe.py

Lines changed: 52 additions & 0 deletions
@@ -12,11 +12,15 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+import math
+import unittest
 
 import pytest
+import torch
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
 from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
+from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
 
 
 class TestCustomQwen3MoeForCausalLM:
@@ -44,3 +48,51 @@ def test_packed_modules_mapping_structure(self):
             ]
         }
         assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
+
+
+class DummyRMSNorm:
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        self.dim = dim
+        self.eps = eps
+
+    def __call__(self, x):
+        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
+        denom = (mean_sq + self.eps).sqrt()
+        return x / denom
+
+
+class TestCustomQwen3MoeAttention(unittest.TestCase):
+
+    def setUp(self):
+        self.batch = 2
+        self.seq_len = 3
+        self.q_size = 8
+        self.kv_size = 8
+        self.head_dim = 4
+        self.rms_eps = 1e-6
+
+        total_dim = self.q_size + 2 * self.kv_size
+
+        self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
+                                dtype=torch.float32).reshape(
+                                    self.batch, self.seq_len, total_dim)
+
+    def test_constant_input_normalization(self):
+        ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
+                              dtype=torch.float32)
+
+        q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
+        k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
+        q, k, v = CustomQwen3MoeAttention.normalize_qkv(
+            ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
+
+        norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
+
+        expected_q = torch.full((1, 1, self.q_size), norm_val)
+        expected_k = torch.full((1, 1, self.kv_size), norm_val)
+        expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
+
+        self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
+        self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
+        self.assertTrue(torch.equal(v, expected_v))
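
This unit test pins down the contract of CustomQwen3MoeAttention.normalize_qkv: RMSNorm computes x / sqrt(mean(x^2) + eps), so for an all-ones input every element becomes 1 / sqrt(1 + eps), exactly the norm_val the test expects for q and k, while v must pass through untouched. Below is a minimal sketch of a function satisfying that contract; the split/reshape details are inferred from the test, not copied from the implementation in vllm_ascend/torchair/models/qwen3_moe.py:

    import torch

    def normalize_qkv(qkv, q_size, kv_size, head_dim, q_norm, k_norm):
        # Split the fused QKV projection along the last dimension.
        q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

        # Apply RMSNorm per attention head: view as (..., n_heads, head_dim),
        # normalize over head_dim, then flatten back to the original shape.
        q_by_head = q.reshape(*q.shape[:-1], q_size // head_dim, head_dim)
        q = q_norm(q_by_head).reshape(q.shape)

        k_by_head = k.reshape(*k.shape[:-1], kv_size // head_dim, head_dim)
        k = k_norm(k_by_head).reshape(k.shape)

        # Qwen3 norms only q and k; v is returned as-is.
        return q, k, v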

tests/ut/test_ascend_config.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self):
 
     def test_check_torchair_supported(self):
         test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
-                      ('qwen', False), ('llama', False)]
+                      ('qwen', True), ('llama', False)]
         for model_type, expected_output in test_cases:
             self.assertEqual(_check_torchair_supported(model_type),
                              expected_output)

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 from vllm.logger import logger
 
-TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
+TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
 
 
 def _check_torchair_supported(model_type: str):
@@ -162,7 +162,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     else:
         # torchair_graph case
         if ascend_config.torchair_graph_config.enabled:
-            # torchair_graph is supported for deepseek/pangu models only.
+            # torchair_graph is supported for deepseek/pangu/qwen models only.
             if vllm_config.model_config:
                 model_type = vllm_config.model_config.hf_config.model_type
                 if not _check_torchair_supported(model_type):
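
Given the unit-test expectations above ('deepseek_v3' and 'PanguProMoE' both count as supported), _check_torchair_supported presumably does a case-insensitive substring match of the model type against TORCHAIR_MODEL_LIST. A sketch consistent with those test cases — the actual function body is not shown in this diff, so this is an assumption:

    def _check_torchair_supported(model_type: str) -> bool:
        # 'deepseek_v3' matches 'deepseek', 'PanguProMoE' matches 'pangu',
        # qwen variants match 'qwen'; 'llama' matches nothing.
        return any(supported in model_type.lower()
                   for supported in TORCHAIR_MODEL_LIST)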
