Skip to content

Commit e97ad15

Browse files
committed
refatcor torchair fused_moe 1/N
Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]> Signed-off-by: hust17yixuan <[email protected]>
1 parent 3629bc4 commit e97ad15

File tree

5 files changed

+1974
-6
lines changed

5 files changed

+1974
-6
lines changed

tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def mock_distributed():
112112
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
113113
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
114114
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
115-
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
115+
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
116116
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
117117
_PP=pp_group), \
118118
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
@@ -227,8 +227,9 @@ def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
227227

228228
x = torch.randn(2, 4, 128)
229229
attn_metadata = Mock(num_prefills=1)
230-
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
231-
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
230+
with patch(
231+
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
232+
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
232233
output = moe(x, attn_metadata)
233234
assert output.shape == (2, 4, 128)
234235

0 commit comments

Comments
 (0)