From eb122aa1539d11a29a850d399f8f3ba637638cb1 Mon Sep 17 00:00:00 2001 From: taoyuxiang Date: Mon, 18 Aug 2025 08:38:48 +0800 Subject: [PATCH] qwen3_moe/qwen25 support torchair graph Signed-off-by: taoyuxiang --- .../e2e/multicard/test_torchair_graph_mode.py | 62 ++ tests/ut/models/test_qwen3_moe.py | 52 ++ tests/ut/test_ascend_config.py | 2 +- vllm_ascend/ascend_config.py | 4 +- vllm_ascend/ops/rotary_embedding.py | 98 +++- vllm_ascend/torchair/models/qwen2.py | 364 ++++++++++++ vllm_ascend/torchair/models/qwen3_moe.py | 537 ++++++++++++++++++ vllm_ascend/torchair/torchair_attention.py | 5 +- vllm_ascend/torchair/utils.py | 8 + 9 files changed, 1123 insertions(+), 9 deletions(-) create mode 100644 vllm_ascend/torchair/models/qwen2.py create mode 100644 vllm_ascend/torchair/models/qwen3_moe.py diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py index 71d33f0c82..a889f4ff7c 100644 --- a/tests/e2e/multicard/test_torchair_graph_mode.py +++ b/tests/e2e/multicard/test_torchair_graph_mode.py @@ -162,3 +162,65 @@ def test_e2e_pangu_with_torchair(): }, } _pangu_torchair_test_fixture(additional_config) + + +def _qwen_torchair_test_fixture( + model, + tp, + enable_expert_parallel, +): + # The current access control does not support 16 cards, + # so the MC2 operator in Qwen's graph mode cannot run. + # Once 16-card support is available, + # this e2e can be switched to graph mode. + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + additional_config = { + "torchair_graph_config": { + "enabled": False, + }, + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + + with VllmRunner( + model, + dtype="half", + tensor_parallel_size=tp, + distributed_executor_backend="mp", + enforce_eager=True, + additional_config=additional_config, + enable_expert_parallel=enable_expert_parallel, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE + # with 2 hidden layers, thus the golden results seems inaccurate. + # This will only change if accuracy changes with the official weights + # of PanguProMoE. + golden_results = [ + 'Hello, my name is Remempondeprecatedmiot忱', + 'The president of the United States is Remem下的一个 rever ceremoni Segnali', + 'The capital of France is Rememvoud administrativ Remem投', + 'The future of AI isotope Segnali Zoeken精细化 supus', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + print(f"Generated text: {vllm_output[i][1]!r}") + + +def test_e2e_qwen2_with_torchair(): + _qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False) + + +def test_e2e_qwen3_moe_with_torchair(): + _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True) diff --git a/tests/ut/models/test_qwen3_moe.py b/tests/ut/models/test_qwen3_moe.py index 71be045a64..e882fe21bd 100644 --- a/tests/ut/models/test_qwen3_moe.py +++ b/tests/ut/models/test_qwen3_moe.py @@ -12,11 +12,15 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
# +import math +import unittest import pytest +import torch from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM +from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention class TestCustomQwen3MoeForCausalLM: @@ -44,3 +48,51 @@ def test_packed_modules_mapping_structure(self): ] } assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping + + +class DummyRMSNorm: + + def __init__(self, dim: int, eps: float = 1e-6): + self.dim = dim + self.eps = eps + + def __call__(self, x): + mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + denom = (mean_sq + self.eps).sqrt() + return x / denom + + +class TestCustomQwen3MoeAttention(unittest.TestCase): + + def setUp(self): + self.batch = 2 + self.seq_len = 3 + self.q_size = 8 + self.kv_size = 8 + self.head_dim = 4 + self.rms_eps = 1e-6 + + total_dim = self.q_size + 2 * self.kv_size + + self.qkv = torch.arange(self.batch * self.seq_len * total_dim, + dtype=torch.float32).reshape( + self.batch, self.seq_len, total_dim) + + def test_constant_input_normalization(self): + ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size), + dtype=torch.float32) + + q_norm = DummyRMSNorm(self.head_dim, self.rms_eps) + k_norm = DummyRMSNorm(self.head_dim, self.rms_eps) + q, k, v = CustomQwen3MoeAttention.normalize_qkv( + ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm) + + norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps) + + expected_q = torch.full((1, 1, self.q_size), norm_val) + expected_k = torch.full((1, 1, self.kv_size), norm_val) + expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32) + + self.assertTrue(torch.allclose(q, expected_q, atol=1e-6)) + self.assertTrue(torch.allclose(k, expected_k, atol=1e-6)) + self.assertTrue(torch.equal(v, expected_v)) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index ec00c0d965..49b9abeabc 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self): def test_check_torchair_supported(self): test_cases = [('deepseek_v3', True), ('PanguProMoE', True), - ('qwen', False), ('llama', False)] + ('qwen', True), ('llama', False)] for model_type, expected_output in test_cases: self.assertEqual(_check_torchair_supported(model_type), expected_output) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 777ff9ffac..9b355783b0 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -17,7 +17,7 @@ from vllm.logger import logger -TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"] +TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] def _check_torchair_supported(model_type: str): @@ -162,7 +162,7 @@ def check_ascend_config(vllm_config, enforce_eager): else: # torchair_graph case if ascend_config.torchair_graph_config.enabled: - # torchair_graph is supported for deepseek/pangu model only. + # torchair_graph is supported for deepseek/pangu/qwen model only. 
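            # As the updated unit test above implies, _check_torchair_supported
            # does a case-insensitive substring match of hf_config.model_type
            # against TORCHAIR_MODEL_LIST, so "qwen2", "qwen2_moe", "qwen3" and
            # "qwen3_moe" are all accepted via the new "qwen" entry, while
            # "llama" is still rejected.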
if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type if not _check_torchair_supported(model_type): diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 3dd91ea63f..806a210744 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -19,6 +19,8 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F +import torch_npu from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -37,9 +39,11 @@ def rope_forward_oot( query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None + is_neox_style_override: Optional[bool] = None, + is_qwen_torchair: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - if get_ascend_config().torchair_graph_config.enabled: + if get_ascend_config( + ).torchair_graph_config.enabled and not is_qwen_torchair: return self.forward_native( positions, query, @@ -47,7 +51,6 @@ def rope_forward_oot( offsets, ) - import torch_npu query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) @@ -246,6 +249,92 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", sin_cached, persistent=False) +def __set_cos_sin_cache(self, seq_len, device, dtype): + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * + (1 / self.rotary_dim))) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False) + self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False) + self.embed = F.embedding + + +_original_re_init = RotaryEmbedding.__init__ + + +def qwen_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, +) -> None: + _original_re_init(self, head_size, rotary_dim, max_position_embeddings, + base, is_neox_style, dtype) + if get_ascend_config().torchair_graph_config.enabled: + __set_cos_sin_cache(self, + seq_len=max_position_embeddings, + device="npu", + dtype=dtype) + + +def rope_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + max_seq_len: Optional[int] = None, + is_prefill: Optional[bool] = True, + is_qwen_torchair: Optional[bool] = False, +): + if get_ascend_config().torchair_graph_config.enabled \ + and is_qwen_torchair and not is_prefill: + if max_seq_len is not None and torch.gt(max_seq_len, + self.max_position_embeddings): + __set_cos_sin_cache(self, + seq_len=max_seq_len, + device=query.device, + dtype=torch.float32) + + # bsnd/bnsd + if positions is not None: + cos = self.embed(positions, self.cos) + sin = self.embed(positions, self.sin) + self.cos_embed = cos + self.sin_embed = sin + else: + cos = self.cos_embed + sin = self.sin_embed + + query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() + key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() + + cos = cos.unsqueeze(-2).unsqueeze(-2) + sin = 
sin.unsqueeze(-2).unsqueeze(-2) + + query = query.unsqueeze(1) + key = key.unsqueeze(1) + + q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb( + query, key, cos, sin) + return q_embed.flatten(-2), k_embed.flatten(-2) + else: + return rope_forward_oot(self, positions, query, key, offsets, + is_neox_style_override, + is_qwen_torchair) # type: ignore + + def deepseek_rope_init_func( self, head_size: int, @@ -283,7 +372,8 @@ def deepseek_rope_init_func( device="npu") -RotaryEmbedding.forward_oot = rope_forward_oot +RotaryEmbedding.__init__ = qwen_rope_init_func +RotaryEmbedding.forward_oot = rope_forward # Note: we adopt the native huggingface deepseek rope initialization code from # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py new file mode 100644 index 0000000000..3537aa84e8 --- /dev/null +++ b/vllm_ascend/torchair/models/qwen2.py @@ -0,0 +1,364 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from collections.abc import Iterable +from typing import Any, List, Optional, Union + +import torch +import torch.nn.functional as F +import vllm +import vllm.envs as envs +from torch import nn +from transformers import Qwen2Config +from vllm.attention import AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model +from vllm.model_executor.models.utils import (AutoWeightsLoader, + PPMissingLayer, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState + + +def all_gather_and_maybe_unpad( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + if pad_size > 0: + return hidden_states[:-pad_size, :] + return hidden_states + + +def maybe_pad_and_reduce_scatter( + hidden_states: torch.Tensor, + 
pad_size: int, +) -> torch.Tensor: + if pad_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) + hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) + return hidden_states + + +class CustomQwen2Attention(Qwen2Attention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=prefix, + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = q.shape + output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + if type(self.rotary_emb) is RotaryEmbedding: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = CustomQwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class CustomQwen2Model(Qwen2Model): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + kv_cache = kv_caches[i - self.start_layer] \ + if kv_caches is not None else None + hidden_states, residual = layer(positions, + hidden_states, + residual, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + # add `CustomQwen2Model` to init self.model + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = CustomQwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: 
Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py new file mode 100644 index 0000000000..dd4a592d65 --- /dev/null +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -0,0 +1,537 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from vllm/model_executor/models/qwen3_moe.py +# This file is a part of the vllm-ascend project. +from typing import Any, List, Optional, Union + +import torch +import vllm.envs as envs +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, CompilationLevel, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (MixtureOfExperts, + SupportsLoRA, SupportsPP) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeMLP, Qwen3MoeModel, + Qwen3MoeSparseMoeBlock) +from vllm.model_executor.models.utils import ( + PPMissingLayer, extract_layer_index, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, + init_metadata_for_sp) + + +class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + 
nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states, + attn_metadata=None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ): + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + _metadata_for_padding=_metadata_for_padding, + ) + + return hidden_states + + +class CustomQwen3MoeAttention(Qwen3MoeAttention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + @staticmethod + def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int, + head_dim: int, q_norm, k_norm): + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm(k_by_head) + k = k_by_head.view(k.shape) + + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size, + self.head_dim, self.q_norm, self.k_norm) + + if (self.torchair_graph_enabled and attn_metadata is not None and + attn_metadata.attn_state == AscendAttentionState.DecodeOnly): + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = q.shape + output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + vllm_config: Optional[VllmConfig] = None, + prefix: str = "", + ) -> None: + + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = CustomQwen3MoeAttention( + 
hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + self.use_aclgraph = (vllm_config is not None + and vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE + and not vllm_config.model_config.enforce_eager) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + if not self.use_aclgraph: + # FIXME: custom sparse moe block doesn't work with aclgraph. + self.mlp = CustomSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.enable_sequence_parallelism = ( + vllm_config.compilation_config.pass_config. + enable_sequence_parallelism if vllm_config is not None else False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> torch.Tensor: + + # To prevent precision issues during the decoder phase when only prefilling enables SP + if not self.enable_sequence_parallelism: + self.self_attn.o_proj.reduce_results = True + else: + self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True + + # Self Attention + if residual is None: + residual = hidden_states + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + residual = _metadata_for_padding.padding_slice(residual) + + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned( + hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter( + hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if not self.use_aclgraph: + hidden_states = self.mlp( + hidden_states, _metadata_for_padding=_metadata_for_padding) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + 
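# A condensed sketch (not part of the patch) of what the DecodeOnly torchair
# branch above triggers when it calls
#   self.rotary_emb(positions, q, k, is_prefill=False, is_qwen_torchair=True):
# the patched rope_forward in vllm_ascend/ops/rotary_embedding.py gathers the
# cached cos/sin rows and applies the fused NPU rotary kernel. Assumes an
# Ascend environment with torch_npu; the helper name and standalone signature
# below are illustrative, not the exact patched method. Prefill and non-qwen
# calls fall back to rope_forward_oot instead.
import torch
import torch.nn.functional as F
import torch_npu


def decode_rope(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
                cos_table: torch.Tensor, sin_table: torch.Tensor,
                head_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Look up per-position cos/sin rows from the tables built at init time.
    cos = F.embedding(positions, cos_table).unsqueeze(-2).unsqueeze(-2)
    sin = F.embedding(positions, sin_table).unsqueeze(-2).unsqueeze(-2)
    # Split the flat hidden dim into heads and add the axes the fused kernel expects.
    q = query.view(*query.shape[:-1], -1, head_size).contiguous().unsqueeze(1)
    k = key.view(*key.shape[:-1], -1, head_size).contiguous().unsqueeze(1)
    # Fused rotary position embedding on NPU, then flatten the head dim back out.
    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
    return q_embed.flatten(-2), k_embed.flatten(-2)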
+@support_torch_compile +class CustomQwen3MoeModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + parallel_config = vllm_config.parallel_config + self.num_redundant_experts = parallel_config.num_redundant_experts + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CustomQwen3MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + vllm_config=vllm_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, + _metadata_for_padding=_metadata_for_padding) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: + hidden_states = _metadata_for_padding.allgather_unpadding_aligned( + hidden_states) + + return hidden_states + + +class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + SupportsPP.__init__(self) + SupportsLoRA.__init__(self) + MixtureOfExperts.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomQwen3MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = 
LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism + # Set MoE hyperparameters + self.expert_weights: list[torch.Tensor] = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3MoeDecoderLayer) + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3MoE layer found in the model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + _metadata_for_padding = init_metadata_for_sp( + input_ids, self.enable_sequence_parallelism) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds, _metadata_for_padding) + return hidden_states diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index a3fda61036..46ee794dcd 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -325,8 +325,9 @@ def forward( shape = [batch_size * seq_len, num_heads, head_size] """ num_tokens = query.shape[0] - use_kv_cache_quant = kv_cache is not None and kv_cache[0].numel( - ) > 0 and kv_cache[0].dtype == torch.int8 + use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0 + and kv_cache[0].numel() > 0 + and kv_cache[0].dtype == torch.int8) if output is None: output = torch.empty(num_tokens, self.num_heads, diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 0a94494cb2..5b8b393322 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -115,3 +115,11 @@ def register_torchair_model(): "DeepseekV3ForCausalLM", "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" ) + + ModelRegistry.register_model( + "Qwen2ForCausalLM", + "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM") + + ModelRegistry.register_model( + "Qwen3ForCausalLM", + "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
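
For reference, a minimal usage sketch (not part of the patch): with the Qwen2 and
Qwen3-MoE torchair models registered above, graph mode is selected purely through
`additional_config`, mirroring the e2e fixture. The model name, parallel sizes, and
the assumption that `additional_config` and `enable_expert_parallel` are forwarded to
vLLM's engine arguments (as the VllmRunner test helper does) are illustrative, not
prescriptive.

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",        # illustrative; Qwen2.5 dense models take the same config
    tensor_parallel_size=2,
    enable_expert_parallel=True,       # MoE only; omit for Qwen2.5
    additional_config={
        "torchair_graph_config": {"enabled": True},   # switch torchair graph mode on
        "ascend_scheduler_config": {"enabled": True},
        "refresh": True,
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)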