
Commit c25c9a3

NicholasTao (nicholastao) authored and committed
qwen3_moe/qwen25 support torchair graph
Signed-off-by: taoyuxiang <[email protected]>
1 parent 9554116 commit c25c9a3

File tree

7 files changed: +1015, -9 lines

tests/ut/test_ascend_config.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self):

     def test_check_torchair_supported(self):
         test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
-                      ('qwen', False), ('llama', False)]
+                      ('qwen', True), ('llama', False)]
         for model_type, expected_output in test_cases:
            self.assertEqual(_check_torchair_supported(model_type),
                             expected_output)
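
The flipped expectation for 'qwen' follows from the new entry in TORCHAIR_MODEL_LIST (see the vllm_ascend/ascend_config.py hunk below). As a minimal sketch consistent with these test cases, and assuming the helper simply substring-matches the allow-list (the actual implementation lives in vllm_ascend/ascend_config.py):

    TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]

    def _check_torchair_supported(model_type: str) -> bool:
        # A model counts as torchair-capable if any allow-listed family name
        # appears in its (lower-cased) hf_config.model_type.
        return any(name in model_type.lower() for name in TORCHAIR_MODEL_LIST)

    # Consistent with the updated expectations:
    #   'deepseek_v3' -> True, 'PanguProMoE' -> True, 'qwen' -> True, 'llama' -> False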

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@

 from vllm.logger import logger

-TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
+TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]


 def _check_torchair_supported(model_type: str):

@@ -162,7 +162,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     else:
         # torchair_graph case
         if ascend_config.torchair_graph_config.enabled:
-            # torchair_graph is supported for deepseek/pangu model only.
+            # torchair_graph is supported for deepseek/pangu/qwen model only.
             if vllm_config.model_config:
                 model_type = vllm_config.model_config.hf_config.model_type
                 if not _check_torchair_supported(model_type):
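
With "qwen" in the allow-list, check_ascend_config no longer rejects Qwen models when torchair graph mode is requested. A hedged usage sketch, assuming the usual vllm-ascend plumbing of torchair_graph_config through additional_config; the checkpoint name and exact option names are illustrative, so check the vllm-ascend docs for your release:

    from vllm import LLM

    # Assumption: the additional_config keys mirror ascend_config.torchair_graph_config
    # referenced in the hunk above; the model name is a placeholder.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        enforce_eager=False,  # torchair graph requires graph (non-eager) mode
        additional_config={"torchair_graph_config": {"enabled": True}},
    )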

vllm_ascend/ops/rotary_embedding.py

Lines changed: 100 additions & 4 deletions
@@ -19,6 +19,8 @@
 from typing import Optional, Tuple

 import torch
+import torch.nn.functional as F
+import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)

@@ -37,17 +39,18 @@ def rope_forward_oot(
     query: torch.Tensor,
     key: torch.Tensor,
     offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
+    is_neox_style_override: Optional[bool] = None,
+    is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config().torchair_graph_config.enabled:
+    if get_ascend_config(
+    ).torchair_graph_config.enabled and not is_qwen_torchair:
         return self.forward_native(
             positions,
             query,
             key,
             offsets,
         )

-    import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)

@@ -246,6 +249,98 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     self.register_buffer("sin_cached", sin_cached, persistent=False)


+def __set_cos_sin_cache(self, seq_len, device, dtype):
+    inv_freq = 1.0 / (self.base**(torch.arange(
+        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
+        (1 / self.rotary_dim)))
+    self.register_buffer("inv_freq", inv_freq)
+
+    t = torch.arange(self.max_position_embeddings,
+                     device=self.inv_freq.device,
+                     dtype=torch.float32)
+    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+    emb = torch.cat((freqs, freqs), dim=-1)
+    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
+    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
+    self.embed = F.embedding
+
+
+def qwen_rope_init_func(
+    self,
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: float,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+) -> None:
+    super(RotaryEmbedding, self).__init__()
+    self.head_size = head_size
+    self.rotary_dim = rotary_dim
+    self.max_position_embeddings = max_position_embeddings
+    self.base = base
+    self.is_neox_style = is_neox_style
+    self.dtype = dtype
+
+    cache = self._compute_cos_sin_cache()
+    cache = cache.to(dtype)
+    self.cos_sin_cache: torch.Tensor  # type: ignore[misc]
+    self.register_buffer("cos_sin_cache", cache, persistent=False)
+    if get_ascend_config().torchair_graph_config.enabled:
+        __set_cos_sin_cache(self,
+                            seq_len=max_position_embeddings,
+                            device="npu",
+                            dtype=dtype)
+
+
+def rope_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    is_neox_style_override: Optional[bool] = None,
+    max_seq_len: Optional[int] = None,
+    is_prefill: Optional[bool] = True,
+    is_qwen_torchair: Optional[bool] = False,
+):
+    if (not get_ascend_config().torchair_graph_config.enabled
+            or not is_qwen_torchair or is_prefill):
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
+
+    if max_seq_len is not None and torch.gt(max_seq_len,
+                                            self.max_position_embeddings):
+        __set_cos_sin_cache(self,
+                            seq_len=max_seq_len,
+                            device=query.device,
+                            dtype=torch.float32)
+
+    # bsnd/bnsd
+    if positions is not None:
+        cos = self.embed(positions, self.cos)
+        sin = self.embed(positions, self.sin)
+        self.cos_embed = cos
+        self.sin_embed = sin
+    else:
+        cos = self.cos_embed
+        sin = self.sin_embed
+
+    query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
+    key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
+
+    cos = cos.unsqueeze(-2).unsqueeze(-2)
+    sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+    query = query.unsqueeze(1)
+    key = key.unsqueeze(1)
+
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+    return q_embed.flatten(-2), k_embed.flatten(-2)
+
+
 def deepseek_rope_init_func(
     self,
     head_size: int,

@@ -283,7 +378,8 @@ def deepseek_rope_init_func(
                                  device="npu")


-RotaryEmbedding.forward_oot = rope_forward_oot
+RotaryEmbedding.__init__ = qwen_rope_init_func
+RotaryEmbedding.forward_oot = rope_forward

 # Note: we adopt the native huggingface deepseek rope initialization code from
 # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
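
For intuition about what the new rope_forward computes, here is a small CPU-only sketch in plain PyTorch (not the vllm-ascend code path) of the neox-style rotation that torch_npu.npu_apply_rotary_pos_emb fuses on the NPU, using the same cos/sin table layout and F.embedding lookup built in __set_cos_sin_cache; all shapes and values below are illustrative:

    import torch
    import torch.nn.functional as F

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Neox-style rotation: split the last dim in half and map (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    head_size, max_pos, base = 64, 128, 10000.0
    inv_freq = 1.0 / (base**(torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)      # [max_pos, head_size]
    cos_table, sin_table = emb.cos(), emb.sin()

    positions = torch.tensor([0, 1, 2, 3])       # flattened token positions
    query = torch.randn(4, 2, head_size)         # [tokens, heads, head_size]
    key = torch.randn(4, 2, head_size)

    # Per-token cos/sin lookup, mirroring self.embed = F.embedding in the patch.
    cos = F.embedding(positions, cos_table).unsqueeze(-2)   # [tokens, 1, head_size]
    sin = F.embedding(positions, sin_table).unsqueeze(-2)

    q_embed = query * cos + rotate_half(query) * sin
    k_embed = key * cos + rotate_half(key) * sin
    print(q_embed.shape, k_embed.shape)          # both torch.Size([4, 2, 64])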
