Commit 0f81e03

[1/N][refactor] torchair fused_moe refactor (#2438)
### What this PR does / why we need it?
Move the torchair-related fused_moe section into torchair_fused_moe to make the code clearer. As a next step, we'll remove all torchair-related code outside of torchair_fused_moe.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: v0.10.0
vLLM main: vllm-project/vllm@08d5f71

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@170e8ea

Signed-off-by: hust17yixuan <[email protected]>

1 parent 334c446 commit 0f81e03

5 files changed: +1974 -6 lines changed

tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 4 additions & 3 deletions
@@ -112,7 +112,7 @@ def mock_distributed():
         patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
         patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
               return_value=Mock(is_first_rank=False, is_last_rank=False)), \
-        patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+        patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
         patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                    _PP=pp_group), \
         patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
@@ -227,8 +227,9 @@ def test_torchair_deepseek_v2_moe(mock_distributed, base_config,

     x = torch.randn(2, 4, 128)
     attn_metadata = Mock(num_prefills=1)
-    with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
-               return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
+    with patch(
+            "vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
+            return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
         output = moe(x, attn_metadata)
         assert output.shape == (2, 4, 128)
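
Both hunks only change the dotted string passed to `patch`. A likely reason is that `unittest.mock.patch` replaces a name in the namespace where it is looked up, so once the code under test reads the name from a new module, a patch aimed at the old module has no effect. The sketch below illustrates that behaviour with two throwaway modules. The names `demo_old_ops`, `demo_new_ops`, `get_config`, and `build_layer` are hypothetical stand-ins for illustration and are not part of vllm-ascend.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical stand-ins for the "old" and "new" module locations.
old_mod = types.ModuleType("demo_old_ops")
new_mod = types.ModuleType("demo_new_ops")


def get_config():
    return "real-config"


# Both modules expose the same helper, as happens when code is copied
# into a new module during a refactor.
old_mod.get_config = get_config
new_mod.get_config = get_config

# Define build_layer inside new_mod's namespace so that its lookup of
# get_config resolves against new_mod, like code that lives in that module.
exec("def build_layer():\n    return get_config()\n", new_mod.__dict__)

# Register the modules so patch() can resolve the dotted target strings.
sys.modules["demo_old_ops"] = old_mod
sys.modules["demo_new_ops"] = new_mod

# Patching the old location does not affect code that lives in the new module.
with patch("demo_old_ops.get_config", return_value="mock-config"):
    stale = new_mod.build_layer()

# Patching the namespace where the name is actually looked up takes effect.
with patch("demo_new_ops.get_config", return_value="mock-config"):
    fresh = new_mod.build_layer()

print(stale, fresh)
```

In this sketch the first call still returns the real value, while the second returns the mocked one, which mirrors why the test targets were repointed at `vllm_ascend.torchair.ops.torchair_fused_moe`.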
