62 changes: 62 additions & 0 deletions tests/e2e/multicard/test_torchair_graph_mode.py
@@ -162,3 +162,65 @@ def test_e2e_pangu_with_torchair():
},
}
_pangu_torchair_test_fixture(additional_config)


def _qwen_torchair_test_fixture(
model,
tp,
enable_expert_parallel,
):
# The current access control does not support 16 cards,
# so the MC2 operator in Qwen's graph mode cannot run.
# Once 16-card support is available,
# this e2e can be switched to graph mode.
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"ascend_scheduler_config": {
"enabled": True,
},
"refresh": True,
}

with VllmRunner(
model,
dtype="half",
tensor_parallel_size=tp,
distributed_executor_backend="mp",
enforce_eager=True,
additional_config=additional_config,
enable_expert_parallel=enable_expert_parallel,
) as vllm_model:
# Use the greedy sampler to make sure the generated results are deterministic.
vllm_output = vllm_model.generate_greedy(example_prompts, 5)

# NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
# with 2 hidden layers, so the golden results seem inaccurate.
# This will only change if accuracy changes with the official weights
# of PanguProMoE.
golden_results = [
'Hello, my name is Remempondeprecatedmiot忱',
'The president of the United States is Remem下的一个 rever ceremoni Segnali',
'The capital of France is Rememvoud administrativ Remem投',
'The future of AI isotope Segnali Zoeken精细化 supus',
]

assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
print(f"Generated text: {vllm_output[i][1]!r}")


def test_e2e_qwen2_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)


def test_e2e_qwen3_moe_with_torchair():
_qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
52 changes: 52 additions & 0 deletions tests/ut/models/test_qwen3_moe.py
@@ -12,11 +12,15 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import unittest

import pytest
import torch
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM

from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention


class TestCustomQwen3MoeForCausalLM:
@@ -44,3 +48,51 @@ def test_packed_modules_mapping_structure(self):
]
}
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping


class DummyRMSNorm:

def __init__(self, dim: int, eps: float = 1e-6):
self.dim = dim
self.eps = eps

def __call__(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
denom = (mean_sq + self.eps).sqrt()
return x / denom


class TestCustomQwen3MoeAttention(unittest.TestCase):

def setUp(self):
self.batch = 2
self.seq_len = 3
self.q_size = 8
self.kv_size = 8
self.head_dim = 4
self.rms_eps = 1e-6

total_dim = self.q_size + 2 * self.kv_size

self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
dtype=torch.float32).reshape(
self.batch, self.seq_len, total_dim)

def test_constant_input_normalization(self):
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
dtype=torch.float32)

q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)

norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)

expected_q = torch.full((1, 1, self.q_size), norm_val)
expected_k = torch.full((1, 1, self.kv_size), norm_val)
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)

self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
self.assertTrue(torch.equal(v, expected_v))
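The expected values follow directly from the RMSNorm definition that the dummy class mirrors: for a constant input of ones, the mean square is 1, so every element of q and k is scaled by 1/sqrt(1 + eps), which is the `norm_val` asserted above, while v passes through unchanged, matching `expected_v`. In LaTeX:

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}
\quad\Longrightarrow\quad
\mathrm{RMSNorm}(\mathbf{1})_i = \frac{1}{\sqrt{1 + \epsilon}} \approx 0.9999995
\qquad (\epsilon = 10^{-6}).
```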
2 changes: 1 addition & 1 deletion tests/ut/test_ascend_config.py
@@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self):

def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
('qwen', False), ('llama', False)]
('qwen', True), ('llama', False)]
for model_type, expected_output in test_cases:
self.assertEqual(_check_torchair_supported(model_type),
expected_output)
4 changes: 2 additions & 2 deletions vllm_ascend/ascend_config.py
@@ -17,7 +17,7 @@

from vllm.logger import logger

TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]


def _check_torchair_supported(model_type: str):
@@ -162,7 +162,7 @@ def check_ascend_config(vllm_config, enforce_eager):
else:
# torchair_graph case
if ascend_config.torchair_graph_config.enabled:
# torchair_graph is supported for deepseek/pangu model only.
# torchair_graph is supported for deepseek/pangu/qwen model only.
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if not _check_torchair_supported(model_type):
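The updated unit test expects 'deepseek_v3', 'PanguProMoE', and 'qwen' to pass while 'llama' fails, which is consistent with matching the Hugging Face model_type against the family names in TORCHAIR_MODEL_LIST. A minimal sketch of such a check, assuming case-insensitive substring matching (the real helper lives in vllm_ascend/ascend_config.py and may differ):

```python
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]

def _check_torchair_supported(model_type: str) -> bool:
    # Assumed behavior: a model is torchair-capable if its HF model_type
    # contains one of the known family names, e.g. "deepseek_v3" -> True,
    # "PanguProMoE" -> True, "llama" -> False.
    return any(family in model_type.lower() for family in TORCHAIR_MODEL_LIST)
```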
98 changes: 94 additions & 4 deletions vllm_ascend/ops/rotary_embedding.py
@@ -19,6 +19,8 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)

@@ -37,17 +39,18 @@ def rope_forward_oot(
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None
is_neox_style_override: Optional[bool] = None,
is_qwen_torchair: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if get_ascend_config().torchair_graph_config.enabled:
if get_ascend_config(
).torchair_graph_config.enabled and not is_qwen_torchair:
return self.forward_native(
positions,
query,
key,
offsets,
)

import torch_npu
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -246,6 +249,92 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", sin_cached, persistent=False)


def __set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
(1 / self.rotary_dim)))
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)

emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
self.embed = F.embedding


_original_re_init = RotaryEmbedding.__init__


def qwen_rope_init_func(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
_original_re_init(self, head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
if get_ascend_config().torchair_graph_config.enabled:
__set_cos_sin_cache(self,
seq_len=max_position_embeddings,
device="npu",
dtype=dtype)


def rope_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
max_seq_len: Optional[int] = None,
is_prefill: Optional[bool] = True,
is_qwen_torchair: Optional[bool] = False,
):
if get_ascend_config().torchair_graph_config.enabled \
and is_qwen_torchair and not is_prefill:
if max_seq_len is not None and torch.gt(max_seq_len,
self.max_position_embeddings):
__set_cos_sin_cache(self,
seq_len=max_seq_len,
device=query.device,
dtype=torch.float32)

# bsnd/bnsd
if positions is not None:
cos = self.embed(positions, self.cos)
sin = self.embed(positions, self.sin)
self.cos_embed = cos
self.sin_embed = sin
else:
cos = self.cos_embed
sin = self.sin_embed

query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()

cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)

query = query.unsqueeze(1)
key = key.unsqueeze(1)

q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
return q_embed.flatten(-2), k_embed.flatten(-2)
else:
return rope_forward_oot(self, positions, query, key, offsets,
is_neox_style_override,
is_qwen_torchair) # type: ignore


def deepseek_rope_init_func(
self,
head_size: int,
@@ -283,7 +372,8 @@ def deepseek_rope_init_func(
device="npu")


RotaryEmbedding.forward_oot = rope_forward_oot
RotaryEmbedding.__init__ = qwen_rope_init_func
RotaryEmbedding.forward_oot = rope_forward

# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
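For context on what the new cache precomputes: `__set_cos_sin_cache` builds the standard rotary-embedding angle table (inverse frequencies, then per-position cos/sin rows), and `rope_forward` looks rows up by position and applies the rotation through `torch_npu.npu_apply_rotary_pos_emb`. The underlying math, restated here as standard neox-style RoPE rather than taken from this diff, is:

```latex
% Angle table precomputed by __set_cos_sin_cache:
\theta_j = \mathrm{base}^{-2j / d_{\mathrm{rot}}}, \qquad j = 0, \dots, \tfrac{d_{\mathrm{rot}}}{2} - 1,
\qquad
\cos_{t,j} = \cos(t\,\theta_j), \quad \sin_{t,j} = \sin(t\,\theta_j).

% Rotation applied at position t (assumed to be what npu_apply_rotary_pos_emb
% computes under the neox-style convention):
q'_t = q_t \odot \cos_t + \mathrm{rotate\_half}(q_t) \odot \sin_t,
\qquad
k'_t = k_t \odot \cos_t + \mathrm{rotate\_half}(k_t) \odot \sin_t.
```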