[main] Fix AddRMSNormW8A8Quant init bug #2440
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for the Gemma3 model on Ascend NPUs, including performance optimizations for the GemmaRMSNorm operator and a fix for an initialization bug in AddRMSNormW8A8Quant. The changes are well-structured, adding the necessary model definitions, registration, and a test case.
However, I've found a critical issue in the new AscendGemma3DecoderLayer implementation where an incorrect layer is used for quantization fusion in the MLP block, likely due to a copy-paste error. This would lead to incorrect model execution. A fix is suggested in the detailed comments.
vllm_ascend/models/gemma3.py
Outdated
                AscendW8A8LinearMethod):
            self.pre_feedforward_layernorm = AddRMSNormW8A8Quant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
There appears to be a copy-paste error. The pre_feedforward_layernorm is applied just before the MLP block. For the AddRMSNormW8A8Quant fusion to work correctly, it needs to be linked with the subsequent linear layer, which is self.mlp.gate_up_proj, not self.self_attn.qkv_proj.
-                layer=self.self_attn.qkv_proj,
+                layer=self.mlp.gate_up_proj,
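For context, a hedged sketch of what the corrected conditional setup might look like inside the decoder layer's __init__, extrapolated from the diff above; the quant_method.quant_method attribute path, the eps keyword, and the isinstance guards are assumptions rather than code taken from this PR:

if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
              AscendW8A8LinearMethod):
    # The norm feeding attention fuses with the QKV projection's quant params.
    self.input_layernorm = AddRMSNormW8A8Quant(
        config.hidden_size,
        layer=self.self_attn.qkv_proj,
        eps=config.rms_norm_eps)
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
              AscendW8A8LinearMethod):
    # The norm feeding the MLP must fuse with the gate/up projection instead.
    self.pre_feedforward_layernorm = AddRMSNormW8A8Quant(
        config.hidden_size,
        layer=self.mlp.gate_up_proj,
        eps=config.rms_norm_eps)

Each fused norm presumably folds in the input quantization parameters of the layer it feeds, which is why pointing pre_feedforward_layernorm at qkv_proj would quantize the MLP input with the wrong scale.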
fixed
Please briefly describe the optimization principles, solutions, and the optimization results.
Before optimizing, the rmsnorm time in one decoding step was 531.5 µs; after optimizing, it is 105 µs.
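For readers looking for the principle behind the numbers: GemmaRMSNorm differs from plain RMSNorm by its (1 + weight) scaling, and in eager mode it runs as a chain of elementwise ops per token. The optimization presumably collapses this chain into a single fused NPU kernel call; the exact fused op is not shown in this thread. A reference (unfused) sketch of the computation being replaced:

import torch

def gemma_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Unfused reference: square, mean, rsqrt, two multiplies, plus dtype casts,
    # each launched as a separate op when run eagerly.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma-specific detail: scale by (1 + weight) rather than weight.
    return (x * (1.0 + weight.float())).to(orig_dtype)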
vllm_ascend/ops/layernorm.py
Outdated
                 dtype: Optional[torch.dtype] = None,
             ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        super().__init__(hidden_size=hidden_size, eps=eps, var_hidden_size=var_hidden_size, has_weight=has_weight, dtype=dtype)
I suggest splitting this PR into two PRs: one focused on the bugfix, and the other focused on the new model.
OK, I deleted the new model part.
Maybe we should modify wrapper_rmsnorm_init instead:

 def wrapper_rmsnorm_init(func):
-    def init(self, hidden_size: int, **extra_args) -> None:
-        func(self, hidden_size, **extra_args)
+    def init(self, hidden_size: int, *args, **kwargs) -> None:
+        func(self, hidden_size, *args, **kwargs)
         self.ignore_anti = True
         self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                        requires_grad=False)
     return init
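For context, a minimal, self-contained reproduction of why the original wrapper breaks and why either fix works; the Base class and names here are simplified stand-ins, not the actual vllm-ascend classes:

import torch

def wrapper_rmsnorm_init(func):
    # Original form: only hidden_size may be passed positionally; every other
    # argument must arrive as a keyword to be captured by **extra_args.
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
    return init

class Base:
    def __init__(self, hidden_size, eps=1e-6, var_hidden_size=None,
                 has_weight=True, dtype=None):
        self.hidden_size = hidden_size

class Wrapped(Base):
    __init__ = wrapper_rmsnorm_init(Base.__init__)

# Positional call, as the GemmaRMSNorm subclass did before this PR:
#   Wrapped(4096, 1e-6, None, True, None)
# -> TypeError: init() takes 2 positional arguments but 6 were given
# Either passing keywords (the layernorm.py change above) or widening the
# wrapper signature to *args, **kwargs (the suggestion above) avoids it:
w = Wrapped(4096, eps=1e-6, var_hidden_size=None, has_weight=True, dtype=None)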
Codecov Report
❌ Patch coverage is 50.00%. Your patch check has failed because the patch coverage (50.00%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2440      +/-   ##
==========================================
+ Coverage   76.18%   77.71%    +1.52%
==========================================
  Files         120      132       +12
  Lines       13532    17520     +3988
==========================================
+ Hits        10310    13615     +3305
- Misses       3222     3905      +683
This pull request has conflicts, please resolve those before we can evaluate the pull request.
…gemmarmsnorm operator of the gemma3 model on NPU
What this PR does / why we need it?
Fix the AddRMSNormW8A8Quant init bug and optimize the performance of the GemmaRMSNorm operator of the gemma3 model on NPU.
Before the fix, it raised "TypeError: wrapper_rmsnorm_init.<locals>.init() takes 2 positional arguments but 6 were given". After the fix, it runs smoothly.
Does this PR introduce any user-facing change?
No
How was this patch tested?
Tested by running the gemma3 model.
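As a concrete example of such a smoke test, a hedged sketch using the standard vLLM offline API (the model name, length limit, and sampling settings are illustrative, not taken from this PR):

from vllm import LLM, SamplingParams

# Load a Gemma3 checkpoint on the NPU-backed vLLM build and generate a few tokens.
llm = LLM(model="google/gemma-3-4b-it", max_model_len=4096)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)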