qwen3_moe/qwen25 support torchair graph #2403
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for DeepSeek models on the torchair backend, involving significant refactoring and the addition of new model implementation files. My review has identified two critical bugs in tensor padding logic which could lead to runtime errors in distributed environments. Additionally, there is a high-severity issue with an incorrect type hint in a method signature that could cause confusion and runtime errors. The remaining refactoring of attention backends and model registration appears to be in order.
    if is_force_scatter and num_tokens % self.tp_size:
        output_parallel = nn.functional.pad(
            output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
The padding logic here is incorrect. nn.functional.pad does not support negative values for adding padding; negative values are for cropping. The expression -num_tokens % self.tp_size will be negative if num_tokens is not divisible by self.tp_size, which will likely cause a runtime error. The intention is to pad the tensor to have a number of tokens divisible by self.tp_size for the reduce_scatter operation. The correct padding amount should be (self.tp_size - num_tokens % self.tp_size) % self.tp_size.
if is_force_scatter and num_tokens % self.tp_size:
pad_size = (self.tp_size - num_tokens % self.tp_size) % self.tp_size
output_parallel = nn.functional.pad(
output_parallel, (0, 0, 0, pad_size))
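As a standalone check of the suggested padding arithmetic (plain Python, no torch; the values are hypothetical, not taken from the PR):

```python
def pad_size(num_tokens: int, tp_size: int) -> int:
    # Padding needed so num_tokens becomes a multiple of tp_size.
    # The outer % tp_size makes the result 0 when num_tokens is
    # already divisible, instead of a full extra tp_size of padding.
    return (tp_size - num_tokens % tp_size) % tp_size

# The padded token count is always divisible by tp_size,
# and the padding itself is always in [0, tp_size).
for num_tokens in range(1, 20):
    for tp_size in (2, 4, 8):
        p = pad_size(num_tokens, tp_size)
        assert 0 <= p < tp_size
        assert (num_tokens + p) % tp_size == 0
```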
    if num_tokens % tp_size:
        residual = nn.functional.pad(residual,
                                     (0, 0, 0, -num_tokens % tp_size))
The padding logic is incorrect. nn.functional.pad with a negative padding value will either raise an error or crop the tensor, which is not the intended behavior. The goal is to pad residual so its number of tokens is divisible by tp_size. The correct padding amount should be calculated as (tp_size - num_tokens % tp_size) % tp_size.
    if num_tokens % tp_size:
        pad_size = (tp_size - num_tokens % tp_size) % tp_size
        residual = nn.functional.pad(residual,
                                     (0, 0, 0, pad_size))
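For context on the pad tuple: in nn.functional.pad, the pairs apply to trailing dimensions first, so (0, 0, 0, pad_size) leaves the last (hidden) dimension untouched and appends pad_size zero rows along the token dimension. A torch-free sketch of the same effect on a 2-D nested list (shapes are illustrative):

```python
def pad_rows(t, pad):
    # Mimic nn.functional.pad(t, (0, 0, 0, pad)) for a 2-D list:
    # the first pair (0, 0) means no change along the hidden dim,
    # the second pair (0, pad) appends `pad` zero rows of tokens.
    width = len(t[0]) if t else 0
    return t + [[0] * width for _ in range(pad)]

x = [[1, 2], [3, 4], [5, 6]]  # 3 tokens, hidden size 2
y = pad_rows(x, 1)            # 4 tokens: now divisible by tp_size=4
```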
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: torch.Tensor,
The type hint for kv_caches is torch.Tensor, but it is used as a list (kv_caches[current_step_idx]). This contradicts the usage, the base class DeepSeekMultiTokenPredictor, which defines it as List[torch.Tensor], and the caller TorchairDeepSeekMTP, which passes an Optional[List[torch.Tensor]]. To prevent potential runtime errors and improve code clarity, the type hint should be corrected to Optional[List[torch.Tensor]].
    - kv_caches: torch.Tensor,
    + kv_caches: Optional[List[torch.Tensor]],
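A minimal sketch of why the corrected hint matches the usage (class and parameter names here are illustrative placeholders, not the actual vLLM Ascend code):

```python
from typing import List, Optional


class Predictor:
    def forward(self, kv_caches: Optional[List[object]],
                current_step_idx: int):
        # Indexing is only valid when kv_caches is a list, which the
        # Optional[List[...]] hint documents; the None case from the
        # caller must be guarded explicitly.
        if kv_caches is None:
            return None
        return kv_caches[current_step_idx]


p = Predictor()
p.forward(None, 0)          # allowed by the Optional[...] hint
p.forward(["a", "b"], 1)    # list indexing, as the body expects
```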
Force-pushed from 56ac826 to 3819a90
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from c25c9a3 to dcf0383
Signed-off-by: taoyuxiang <[email protected]>
I'll merge this quickly once the lint CI passes.
What this PR does / why we need it?
Added support for the TorchAir graph mode in qwen3_moe and qwen2.5
Does this PR introduce any user-facing change?
No
How was this patch tested?