@@ -112,7 +112,7 @@ def mock_distributed():
     patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
     patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
           return_value=Mock(is_first_rank=False, is_last_rank=False)), \
-    patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+    patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
     patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                _PP=pp_group), \
     patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
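Aside (not part of the diff): the fixture above leans on patch.dict to swap module-level globals such as _TP and _EP for the duration of a test. A minimal, self-contained sketch of that technique, using a hypothetical fake_parallel_state module in place of vllm.distributed.parallel_state:

# Sketch of the patch.dict technique used by mock_distributed(), shown on
# a hypothetical module; patch.dict temporarily overwrites entries in a
# module's __dict__ and restores them when the with-block exits.
import sys
import types
from unittest.mock import Mock, patch

# Hypothetical stand-in for vllm.distributed.parallel_state.
state = types.ModuleType("fake_parallel_state")
state._TP = None
sys.modules["fake_parallel_state"] = state

with patch.dict("fake_parallel_state.__dict__", _TP=Mock(rank_in_group=0)):
    # Code reading the module-level global now sees the mock group.
    assert state._TP.rank_in_group == 0
assert state._TP is None  # original value restored on exit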
@@ -227,8 +227,9 @@ def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
 
     x = torch.randn(2, 4, 128)
     attn_metadata = Mock(num_prefills=1)
-    with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
-               return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
+    with patch(
+            "vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
+            return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
         output = moe(x, attn_metadata)
     assert output.shape == (2, 4, 128)
 
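Aside (not part of the diff): the test mocks TorchairAscendFusedMoE.__call__ at the class level, which works because Python resolves __call__ on the type rather than the instance, so every existing instance returns the canned tuple and no real MoE kernel runs. A minimal sketch of the same pattern, with a hypothetical FakeFusedMoE standing in for the real class:

# Sketch of patching a class's __call__ with a canned return value, the
# technique the updated test uses; FakeFusedMoE is a hypothetical
# stand-in for TorchairAscendFusedMoE.
from unittest.mock import patch

import torch


class FakeFusedMoE:
    def __call__(self, x):
        raise RuntimeError("real kernel, unavailable off-device")


moe = FakeFusedMoE()
canned = (torch.randn(2, 4, 128), torch.randn(2, 4, 128))
with patch(f"{__name__}.FakeFusedMoE.__call__", return_value=canned):
    out = moe(torch.randn(2, 4, 128))  # hits the mock, not the kernel
assert out[0].shape == (2, 4, 128)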