|  | 
| 15 | 15 | # limitations under the License. | 
| 16 | 16 | # This file is a part of the vllm-ascend project. | 
| 17 | 17 | 
 | 
|  | 18 | +import unittest | 
|  | 19 | +from unittest import mock | 
|  | 20 | + | 
| 18 | 21 | import pytest | 
|  | 22 | +import torch | 
| 19 | 23 | from pytest_mock import MockerFixture | 
| 20 | 24 | 
 | 
| 21 | 25 | from tests.ut.base import PytestBase | 
| 22 | 26 | from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( | 
| 23 |  | -    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) | 
|  | 27 | +    AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig, | 
|  | 28 | +    TokenDispatcherWithAllGather, TokenDispatcherWithMC2) | 
| 24 | 29 | from vllm_ascend.utils import adapt_patch  # noqa E402 | 
| 25 | 30 | 
 | 
| 26 | 31 | 
 | 
| @@ -63,3 +68,289 @@ def test_initialization(self, dispatcher, config): | 
| 63 | 68 |         assert dispatcher.ep_rank == 0 | 
| 64 | 69 |         assert dispatcher.ep_size == 2 | 
| 65 | 70 |         assert dispatcher.overlap_stream is not None | 
|  | 71 | + | 
|  | 72 | + | 
|  | 73 | +class TestTokenDispatcherWithMC2(unittest.TestCase): | 
|  | 74 | + | 
|  | 75 | +    def setUp(self): | 
|  | 76 | +        self.mc2_group = mock.MagicMock() | 
|  | 77 | +        self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123" | 
|  | 78 | +        self.mc2_group.rank_in_group = 0 | 
|  | 79 | +        self.mc2_group.world_size = 8 | 
|  | 80 | +        self.mc2_group_patch = mock.patch( | 
|  | 81 | +            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group", | 
|  | 82 | +            return_value=self.mc2_group) | 
|  | 83 | +        self.mc2_group_patch.start() | 
|  | 84 | + | 
|  | 85 | +        self.rank_group_patch = mock.patch("torch.distributed.get_rank", | 
|  | 86 | +                                           return_value=0) | 
|  | 87 | +        self.rank_group_patch.start() | 
|  | 88 | + | 
|  | 89 | +        # Mock get_forward_context().mc2_mask | 
|  | 90 | +        self.forward_context = mock.MagicMock() | 
|  | 91 | +        self.forward_context.mc2_mask = torch.tensor([1, 0, 1]) | 
|  | 92 | +        self.forward_context_patch = mock.patch( | 
|  | 93 | +            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context", | 
|  | 94 | +            return_value=self.forward_context) | 
|  | 95 | +        self.forward_context_patch.start() | 
|  | 96 | + | 
|  | 97 | +        # Mock get_ascend_soc_version() | 
|  | 98 | +        self.ascend_soc_version_patch = mock.patch( | 
|  | 99 | +            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version", | 
|  | 100 | +            return_value=AscendSocVersion.A3) | 
|  | 101 | +        self.ascend_soc_version_patch.start() | 
|  | 102 | + | 
|  | 103 | +        # Mock get_ascend_config() | 
|  | 104 | +        self.ascend_config = mock.MagicMock() | 
|  | 105 | +        self.ascend_config.torchair_graph_config.enabled = False | 
|  | 106 | +        self.ascend_config_patch = mock.patch( | 
|  | 107 | +            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config", | 
|  | 108 | +            return_value=self.ascend_config) | 
|  | 109 | +        self.ascend_config_patch.start() | 
|  | 110 | + | 
|  | 111 | +        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128} | 
|  | 112 | +        self.dispatcher = TokenDispatcherWithMC2(**kwargs) | 
|  | 113 | + | 
|  | 114 | +    def tearDown(self): | 
|  | 115 | +        self.mc2_group_patch.stop() | 
|  | 116 | +        self.forward_context_patch.stop() | 
|  | 117 | +        self.ascend_soc_version_patch.stop() | 
|  | 118 | +        self.ascend_config_patch.stop() | 
|  | 119 | + | 
|  | 120 | +    def test_init(self): | 
|  | 121 | +        # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123") | 
|  | 122 | +        self.assertEqual(self.dispatcher.ep_rank_id, 0) | 
|  | 123 | +        self.assertEqual(self.dispatcher.ep_world_size, 8) | 
|  | 124 | +        self.assertFalse(self.dispatcher.torchair_graph_enabled) | 
|  | 125 | +        self.assertFalse(self.dispatcher.with_quant) | 
|  | 126 | +        self.assertTrue(self.dispatcher.enable_dispatch_v2) | 
|  | 127 | +        self.assertTrue(self.dispatcher.need_extra_args) | 
|  | 128 | +        self.assertTrue(self.dispatcher.a3_need_extra_args) | 
|  | 129 | + | 
|  | 130 | +    def test_get_permute_mc2_kwargs_without_quant(self): | 
|  | 131 | +        hidden_states = torch.randn(10, 128) | 
|  | 132 | +        topk_ids = torch.randint(0, 8, (10, 1)) | 
|  | 133 | +        topk_weights = torch.randn(10, 1) | 
|  | 134 | +        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 135 | + | 
|  | 136 | +        kwargs = self.dispatcher.get_permute_mc2_kwargs( | 
|  | 137 | +            hidden_states, topk_weights, topk_ids, expert_map) | 
|  | 138 | +        self.assertIn("x", kwargs) | 
|  | 139 | +        self.assertIn("expert_ids", kwargs) | 
|  | 140 | +        self.assertEqual(kwargs["moe_expert_num"], 8) | 
|  | 141 | + | 
|  | 142 | +    def test_token_permutation_dispatch(self): | 
|  | 143 | +        hidden_states = torch.randn(10, 128) | 
|  | 144 | +        topk_weights = torch.randn(10, 1) | 
|  | 145 | +        topk_ids = torch.randint(0, 8, (10, 1)) | 
|  | 146 | +        expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 147 | + | 
|  | 148 | +        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", | 
|  | 149 | +                        return_value=(torch.randn(10, 128), ) * | 
|  | 150 | +                        5) as mock_dispatch: | 
|  | 151 | +            output = self.dispatcher.token_permutation(hidden_states, | 
|  | 152 | +                                                       topk_weights, topk_ids, | 
|  | 153 | +                                                       expert_map) | 
|  | 154 | +            mock_dispatch.assert_called_once() | 
|  | 155 | +            self.assertEqual(output[0], 1)  # group_list_type == 1 | 
|  | 156 | + | 
|  | 157 | +    def test_token_permutation_with_shared_experts_and_quant(self): | 
|  | 158 | +        self.shared_experts = mock.MagicMock() | 
|  | 159 | +        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128), | 
|  | 160 | +                                                         torch.tensor(1.0)) | 
|  | 161 | +        self.shared_experts.act_fn.return_value = torch.randn(10, 128) | 
|  | 162 | +        self.dispatcher.with_quant = False | 
|  | 163 | +        self.dispatcher.shared_act = torch.randn(10, 128) | 
|  | 164 | +        self.dispatcher.swiglu_out_scale = torch.tensor(1.0) | 
|  | 165 | +        self.hidden_states = torch.randn(10, 128) | 
|  | 166 | +        self.topk_weights = torch.randn(10, 1) | 
|  | 167 | + | 
|  | 168 | +        with mock.patch("torch_npu.npu_moe_distribute_dispatch_v2", | 
|  | 169 | +                        return_value=(torch.randn(10, 128), ) * 5): | 
|  | 170 | +            with mock.patch( | 
|  | 171 | +                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", | 
|  | 172 | +                    autospec=True): | 
|  | 173 | +                with mock.patch( | 
|  | 174 | +                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", | 
|  | 175 | +                        autospec=True) as mock_wait: | 
|  | 176 | +                    self.dispatcher.token_permutation( | 
|  | 177 | +                        self.hidden_states, | 
|  | 178 | +                        self.topk_weights, | 
|  | 179 | +                        torch.randint(0, 8, (10, 1)), | 
|  | 180 | +                        torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), | 
|  | 181 | +                        shared_experts=self.shared_experts) | 
|  | 182 | +                    mock_wait.assert_any_call(self.hidden_states, | 
|  | 183 | +                                              self.topk_weights) | 
|  | 184 | + | 
|  | 185 | +    def test_get_unpermute_mc_kwargs_with_quant(self): | 
|  | 186 | +        self.dispatcher.with_quant = True | 
|  | 187 | +        hidden_states = torch.randn(10, 128) | 
|  | 188 | +        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) | 
|  | 189 | +        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) | 
|  | 190 | +        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 191 | +        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 192 | +        self.dispatcher.need_extra_args = True | 
|  | 193 | +        self.dispatcher.enable_dispatch_v2 = True | 
|  | 194 | +        self.dispatcher.output = torch.randint(0, 8, (10, 1)) | 
|  | 195 | + | 
|  | 196 | +        kwargs = self.dispatcher.get_unpermute_mc_kwargs(hidden_states) | 
|  | 197 | +        self.assertIn("tp_send_counts", kwargs) | 
|  | 198 | + | 
|  | 199 | +    def test_token_unpermutation_with_shared_experts(self): | 
|  | 200 | +        self.dispatcher.shared_experts = mock.MagicMock() | 
|  | 201 | +        self.dispatcher.shared_experts.down_proj.return_value = (torch.randn( | 
|  | 202 | +            10, 128), torch.tensor(1.0)) | 
|  | 203 | +        self.dispatcher.shared_act = torch.randn(10, 128) | 
|  | 204 | +        self.dispatcher.with_quant = True | 
|  | 205 | +        self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1)) | 
|  | 206 | +        self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1)) | 
|  | 207 | +        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 208 | +        self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | 
|  | 209 | +        self.dispatcher.need_extra_args = True | 
|  | 210 | +        self.dispatcher.enable_dispatch_v2 = True | 
|  | 211 | +        self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1)) | 
|  | 212 | +        self.dispatcher.output = torch.randint(0, 8, (10, 1)) | 
|  | 213 | +        self.hidden_states = torch.randn(10, 128) | 
|  | 214 | + | 
|  | 215 | +        with mock.patch("torch_npu.npu_moe_distribute_combine_v2", | 
|  | 216 | +                        return_value=torch.randn(10, 128)): | 
|  | 217 | +            with mock.patch( | 
|  | 218 | +                    "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", | 
|  | 219 | +                    autospec=True): | 
|  | 220 | +                with mock.patch( | 
|  | 221 | +                        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", | 
|  | 222 | +                        autospec=True): | 
|  | 223 | +                    self.dispatcher.token_unpermutation(self.hidden_states) | 
|  | 224 | + | 
|  | 225 | + | 
|  | 226 | +class TestTokenDispatcherWithAllGather(unittest.TestCase): | 
|  | 227 | + | 
|  | 228 | +    def setUp(self): | 
|  | 229 | +        # Mock dependencies | 
|  | 230 | +        kwargs = { | 
|  | 231 | +            "apply_router_weight_on_input": False, | 
|  | 232 | +            "top_k": 2, | 
|  | 233 | +            "max_num_tokens": 100, | 
|  | 234 | +            "ep_size": 2, | 
|  | 235 | +            "num_experts": 128, | 
|  | 236 | +            "with_quant": False, | 
|  | 237 | +        } | 
|  | 238 | +        self.dispatcher = TokenDispatcherWithAllGather(**kwargs) | 
|  | 239 | + | 
|  | 240 | +        # Mock NPU functions | 
|  | 241 | +        self.patcher_moe_init_routing = mock.patch( | 
|  | 242 | +            'torch_npu.npu_moe_init_routing') | 
|  | 243 | +        self.mock_moe_init_routing = self.patcher_moe_init_routing.start() | 
|  | 244 | +        self.mock_moe_init_routing.return_value = ( | 
|  | 245 | +            torch.randn(6, 128),  # sorted_hidden_states | 
|  | 246 | +            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx | 
|  | 247 | +            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx | 
|  | 248 | +        ) | 
|  | 249 | + | 
|  | 250 | +        self.patcher_moe_compute_expert_tokens = mock.patch( | 
|  | 251 | +            'torch_npu.npu_moe_compute_expert_tokens') | 
|  | 252 | +        self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start( | 
|  | 253 | +        ) | 
|  | 254 | +        self.mock_moe_compute_expert_tokens.return_value = torch.tensor( | 
|  | 255 | +            [3, 3])  # expert_tokens | 
|  | 256 | + | 
|  | 257 | +        self.patcher_moe_finalize_routing = mock.patch( | 
|  | 258 | +            'torch_npu.npu_moe_finalize_routing') | 
|  | 259 | +        self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( | 
|  | 260 | +        ) | 
|  | 261 | +        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128) | 
|  | 262 | + | 
|  | 263 | +    def tearDown(self): | 
|  | 264 | +        self.patcher_moe_init_routing.stop() | 
|  | 265 | +        self.patcher_moe_compute_expert_tokens.stop() | 
|  | 266 | +        self.patcher_moe_finalize_routing.stop() | 
|  | 267 | + | 
|  | 268 | +    def test_token_permutation_without_expert_map(self): | 
|  | 269 | +        hidden_states = torch.randn(3, 128) | 
|  | 270 | +        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) | 
|  | 271 | +        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) | 
|  | 272 | + | 
|  | 273 | +        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( | 
|  | 274 | +            hidden_states, topk_weights, topk_ids, None) | 
|  | 275 | + | 
|  | 276 | +        # Verify npu_moe_init_routing is called | 
|  | 277 | +        self.mock_moe_init_routing.assert_called_once() | 
|  | 278 | +        args, kwargs = self.mock_moe_init_routing.call_args | 
|  | 279 | + | 
|  | 280 | +        self.assertEqual(group_list_type, 0) | 
|  | 281 | + | 
|  | 282 | +    def test_token_permutation_with_quant(self): | 
|  | 283 | +        kwargs = { | 
|  | 284 | +            "apply_router_weight_on_input": False, | 
|  | 285 | +            "top_k": 2, | 
|  | 286 | +            "max_num_tokens": 100, | 
|  | 287 | +            "ep_size": 2, | 
|  | 288 | +            "num_experts": 128, | 
|  | 289 | +            "with_quant": True, | 
|  | 290 | +        } | 
|  | 291 | +        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs) | 
|  | 292 | + | 
|  | 293 | +        hidden_states = torch.randn(3, 128) | 
|  | 294 | +        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) | 
|  | 295 | +        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) | 
|  | 296 | + | 
|  | 297 | +        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_permutation( | 
|  | 298 | +            hidden_states, topk_weights, topk_ids, None) | 
|  | 299 | + | 
|  | 300 | +        # Verify quant mode returns group_list_type=1 | 
|  | 301 | +        self.assertEqual(group_list_type, 0) | 
|  | 302 | + | 
|  | 303 | +    def test_token_unpermutation_with_expert_map(self): | 
|  | 304 | +        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) | 
|  | 305 | +        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) | 
|  | 306 | +        self.dispatcher.sorted_weights = torch.tensor( | 
|  | 307 | +            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) | 
|  | 308 | +        self.dispatcher.original_shape = (3, 128) | 
|  | 309 | +        self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) | 
|  | 310 | +        hidden_states = torch.randn(6, 128) | 
|  | 311 | + | 
|  | 312 | +        final_hidden_states = self.dispatcher.token_unpermutation( | 
|  | 313 | +            hidden_states) | 
|  | 314 | + | 
|  | 315 | +        # Verify index_add_ is applied correctly | 
|  | 316 | +        self.assertEqual(final_hidden_states.shape, (3, 128)) | 
|  | 317 | + | 
|  | 318 | +    def test_token_unpermutation_without_expert_map(self): | 
|  | 319 | +        self.dispatcher.with_quant = False | 
|  | 320 | +        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1]) | 
|  | 321 | +        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) | 
|  | 322 | +        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1]) | 
|  | 323 | +        self.dispatcher.sorted_weights = torch.tensor( | 
|  | 324 | +            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) | 
|  | 325 | +        self.dispatcher.original_shape = (3, 128) | 
|  | 326 | +        self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) | 
|  | 327 | +        hidden_states = torch.randn(6, 128) | 
|  | 328 | + | 
|  | 329 | +        final_hidden_states = self.dispatcher.token_unpermutation( | 
|  | 330 | +            hidden_states) | 
|  | 331 | + | 
|  | 332 | +        # Verify npu_moe_finalize_routing is called | 
|  | 333 | +        self.mock_moe_finalize_routing.assert_called_once() | 
|  | 334 | +        args, kwargs = self.mock_moe_finalize_routing.call_args | 
|  | 335 | + | 
|  | 336 | +        self.assertEqual(final_hidden_states.shape, (3, 128)) | 
|  | 337 | + | 
|  | 338 | +    def test_token_permutation_with_router_weight(self): | 
|  | 339 | +        self.dispatcher.apply_router_weight_on_input = True | 
|  | 340 | +        hidden_states = torch.randn(3, 128) | 
|  | 341 | +        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1 | 
|  | 342 | +        topk_ids = torch.tensor([[0], [1], [2]]) | 
|  | 343 | + | 
|  | 344 | +        group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_permutation( | 
|  | 345 | +            hidden_states, topk_weights, topk_ids, None) | 
|  | 346 | +        self.assertEqual(sorted_hidden_states.shape, (6, 128)) | 
|  | 347 | + | 
|  | 348 | +    def test_token_permutation_invalid_topk_when_router_weight(self): | 
|  | 349 | +        self.dispatcher.apply_router_weight_on_input = True | 
|  | 350 | +        hidden_states = torch.randn(3, 128) | 
|  | 351 | +        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) | 
|  | 352 | + | 
|  | 353 | +        with self.assertRaises(AssertionError): | 
|  | 354 | +            self.dispatcher.token_permutation( | 
|  | 355 | +                hidden_states, topk_weights, | 
|  | 356 | +                torch.tensor([[0, 1], [1, 2], [2, 3]]), None) | 
0 commit comments