Commit 0b0946c

socrahow committed:
fix AddRMSNormW8A8Quant init error: 'TypeError: wrapper_rmsnorm_init.<locals>.init() takes 2 positional arguments but 6 were given'
Signed-off-by: socrahow <[email protected]>
1 parent: 527a3af
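
For context, a minimal sketch of the failure mode, reconstructed from the quoted error message alone; the wrapper below is an illustrative stand-in, not the actual vllm_ascend patch. A decorator that swaps RMSNorm.__init__ for an inner init(self, hidden_size, **extra_args) accepts only two positional arguments, so a subclass that forwards eps, var_hidden_size, has_weight and dtype positionally triggers exactly this TypeError.

    # Illustrative stand-in; the real wrapper lives in vllm_ascend's
    # quantization patching, and its exact body is an assumption here.
    def wrapper_rmsnorm_init(func):
        def init(self, hidden_size, **extra_args) -> None:
            func(self, hidden_size, **extra_args)
        return init

    class RMSNorm:
        @wrapper_rmsnorm_init
        def __init__(self, hidden_size, eps=1e-6, var_hidden_size=None,
                     has_weight=True, dtype=None):
            self.hidden_size = hidden_size
            self.eps = eps

    # Reproduces the error: six positional arguments (including self) reach
    # init(), which accepts only two.
    #   RMSNorm(4096, 1e-6, None, True, None)
    # Works once everything past hidden_size is passed by keyword:
    norm = RMSNorm(4096, eps=1e-6, var_hidden_size=None, has_weight=True,
                   dtype=None)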

2 files changed: 6 additions, 4 deletions

vllm_ascend/models/gemma3.py (1 addition, 1 deletion)

@@ -79,7 +79,7 @@ def __init__(
                            AscendW8A8LinearMethod):
             self.pre_feedforward_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
-                layer=self.self_attn.qkv_proj,
+                layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps
             )
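
Why this line changed: AddRMSNormW8A8Quant fuses the residual add and RMSNorm with the static W8A8 activation quantization of the layer that consumes the normalized output, so its layer argument must point at that consumer. pre_feedforward_layernorm feeds the MLP, hence mlp.gate_up_proj; the old value, self_attn.qkv_proj, carries the input scale of the attention block. A hedged sketch of the intended pairing (the input_layernorm line is the assumed pre-existing counterpart, not part of this diff):

    # Each fused norm is paired with the quantized linear it feeds.
    self.input_layernorm = AddRMSNormW8A8Quant(
        config.hidden_size,
        layer=self.self_attn.qkv_proj,  # attention input -> qkv_proj's scale
        eps=config.rms_norm_eps)
    self.pre_feedforward_layernorm = AddRMSNormW8A8Quant(
        config.hidden_size,
        layer=self.mlp.gate_up_proj,    # MLP input -> gate_up_proj's scale
        eps=config.rms_norm_eps)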

vllm_ascend/ops/layernorm.py (5 additions, 3 deletions)

@@ -33,7 +33,11 @@ def __init__(
         has_weight: bool = True,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        super().__init__(hidden_size=hidden_size,
+                         eps=eps,
+                         var_hidden_size=var_hidden_size,
+                         has_weight=has_weight,
+                         dtype=dtype)
         self.layer = layer

     def forward(
@@ -59,14 +63,12 @@ def forward(


 class AscendRMSNorm(RMSNorm):
-
     def forward_oot(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
-
         from vllm_ascend.utils import is_310p
         if residual is not None:
             if is_310p():
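
With the base-class arguments forwarded by keyword, the constructor works whether or not the quantization patch has replaced RMSNorm.__init__, since the patched init only accepts hidden_size positionally. A minimal usage sketch under that assumption, with a plain nn.Linear standing in for the W8A8-quantized projection the norm would normally feed:

    import torch
    from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant

    # Stand-in for a quantized gate_up projection (shapes are hypothetical).
    gate_up_proj = torch.nn.Linear(4096, 28672)
    norm = AddRMSNormW8A8Quant(4096, layer=gate_up_proj, eps=1e-6)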
