
Conversation

@NicholasTao NicholasTao commented Aug 17, 2025

What this PR does / why we need it?

Added support for the TorchAir graph mode in qwen3_moe and qwen2.5

Does this PR introduce any user-facing change?

No

How was this patch tested?

from vllm import LLM

llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=True,
    max_model_len=4096,
    max_num_seqs=16,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.4,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": False,
            "graph_batch_sizes_init": False,
            "graph_batch_sizes": [16],
        },
        "ascend_scheduler_config": {
            "enabled": True,
            "chunked_prefill_enabled": True,
        },
        "refresh": True,
    },
)
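As a side note on the config: `graph_batch_sizes` lists the decode batch sizes captured as graphs, and here the single captured size (16) matches `max_num_seqs`. The selection logic can be illustrated with a standalone sketch; the function name and fallback behavior below are hypothetical, not the actual vllm-ascend internals.

```python
# Illustration only: pick the smallest captured graph batch size that can
# serve a given runtime batch. Batches larger than every captured size
# would have to fall back to eager execution (hypothetical sketch).
def pick_graph_batch_size(batch_size, graph_batch_sizes):
    """Return the smallest captured size >= batch_size, or None if none fits."""
    for size in sorted(graph_batch_sizes):
        if size >= batch_size:
            return size
    return None

# With graph_batch_sizes=[16] and max_num_seqs=16, every decode batch
# can be padded up to the single captured size.
print(pick_graph_batch_size(5, [16]))   # 16
print(pick_graph_batch_size(17, [16]))  # None
```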


👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for DeepSeek models on the torchair backend, involving significant refactoring and the addition of new model implementation files. My review has identified two critical bugs in tensor padding logic which could lead to runtime errors in distributed environments. Additionally, there is a high-severity issue with an incorrect type hint in a method signature that could cause confusion and runtime errors. The remaining refactoring of attention backends and model registration appears to be in order.

Comment on lines 165 to 167
if is_force_scatter and num_tokens % self.tp_size:
    output_parallel = nn.functional.pad(
        output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
Severity: critical

The padding logic here is incorrect. nn.functional.pad does not support negative values for adding padding; negative values are for cropping. The expression -num_tokens % self.tp_size will be negative if num_tokens is not divisible by self.tp_size, which will likely cause a runtime error. The intention is to pad the tensor to have a number of tokens divisible by self.tp_size for the reduce_scatter operation. The correct padding amount should be (self.tp_size - num_tokens % self.tp_size) % self.tp_size.

            if is_force_scatter and num_tokens % self.tp_size:
                pad_size = (self.tp_size - num_tokens % self.tp_size) % self.tp_size
                output_parallel = nn.functional.pad(
                    output_parallel, (0, 0, 0, pad_size))
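The suggested fix pads the token dimension up to the next multiple of `tp_size` before the reduce-scatter. The arithmetic can be sanity-checked in plain Python, independently of the PR's tensors; the helper name below is hypothetical.

```python
def pad_size_to_multiple(num_tokens, tp_size):
    """Rows to append so num_tokens becomes divisible by tp_size
    (0 when it already is)."""
    return (tp_size - num_tokens % tp_size) % tp_size

assert pad_size_to_multiple(8, 4) == 0   # already divisible, no padding
assert pad_size_to_multiple(9, 4) == 3   # 9 tokens padded up to 12
assert pad_size_to_multiple(1, 4) == 3   # 1 token padded up to 4
```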

Comment on lines 769 to 771
if num_tokens % tp_size:
residual = nn.functional.pad(residual,
(0, 0, 0, -num_tokens % tp_size))

Severity: critical

The padding logic is incorrect. nn.functional.pad with a negative padding value will either raise an error or crop the tensor, which is not the intended behavior. The goal is to pad residual so its number of tokens is divisible by tp_size. The correct padding amount should be calculated as (tp_size - num_tokens % tp_size) % tp_size.

            if num_tokens % tp_size:
                pad_size = (tp_size - num_tokens % tp_size) % tp_size
                residual = nn.functional.pad(residual,
                                             (0, 0, 0, pad_size))

    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: torch.Tensor,
Severity: high

The type hint for kv_caches is torch.Tensor, but it is used as a list (kv_caches[current_step_idx]). This contradicts its usage, the base class DeepSeekMultiTokenPredictor (which defines it as List[torch.Tensor]), and the caller TorchairDeepSeekMTP (which passes an Optional[List[torch.Tensor]]). To prevent potential runtime errors and improve code clarity, the type hint should be corrected to Optional[List[torch.Tensor]].

Suggested change
kv_caches: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]],
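
The distinction matters because a per-step cache is indexed with `[current_step_idx]`, which requires a list-like value, and `None` must also be accepted. A minimal sketch of that access pattern, with plain lists standing in for `torch.Tensor` and a hypothetical helper name:

```python
from typing import List, Optional


def select_step_cache(kv_caches: Optional[List[list]],
                      current_step_idx: int):
    """Pick the cache for the current MTP step; tolerate a missing cache."""
    if kv_caches is None:
        return None  # e.g. profiling or graph-capture runs may pass None
    return kv_caches[current_step_idx]  # requires a list, not one tensor


caches = [["k0", "v0"], ["k1", "v1"]]
print(select_step_cache(caches, 1))  # ['k1', 'v1']
print(select_step_cache(None, 0))    # None
```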

@NicholasTao NicholasTao force-pushed the tg816raw branch 6 times, most recently from 56ac826 to 3819a90 Compare August 18, 2025 23:10

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@NicholasTao NicholasTao changed the title Tg816raw qwen3_moe/qwen25 support torchair graph Aug 19, 2025
@NicholasTao NicholasTao force-pushed the tg816raw branch 2 times, most recently from c25c9a3 to dcf0383 Compare August 20, 2025 02:09
@wangxiyuan
Collaborator

I'll merge this quickly once the lint CI passes.

@wangxiyuan wangxiyuan merged commit 7bec1a9 into vllm-project:main Aug 20, 2025
18 of 19 checks passed