 from unittest.mock import MagicMock, patch
 
 import torch
+from torch import nn
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.torchair.torchair_mla import (
     AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
     AscendMLATorchairImpl, AscendMLATorchairMetadata,
@@ -398,6 +400,68 @@ def test_build_dummy(self, mock_ascend_config):
         assert torch.equal(sin_golden, metadata.decode.sin)
         assert torch.equal(cos_golden, metadata.decode.cos)
 
+    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
+    def test_build_decode(self, mock_ascend_config):
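+        # Decode-only build: three requests with one new token each
+        # (max_query_len=1), so every request should be classified as decode.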
+        ascend_config = MagicMock()
+        mock_ascend_config.return_value = ascend_config
+        ascend_config.torchair_graph_config.enabled = False
+
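+        # Minimal engine config stub: fp16 model, 16-token KV-cache blocks,
+        # chunked prefill disabled.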
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_device = 'cpu'
+        model = MagicMock(spec=nn.Module)
+        model.model = MagicMock(spec=nn.Module)
+
+        builder = AscendMLATorchairMetadataBuilder(
+            mock_vllm_config,
+            mock_device,
+            metadata_cls=AscendMLATorchairMetadata)
+        builder.rope_dim = 64
+
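+        # Placeholder rope caches; build() reads decode sin/cos from these
+        # (test_build_dummy above asserts metadata.decode.sin/cos).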
+        builder.sin_cache = torch.tensor([10, 10])
+        builder.cos_cache = torch.tensor([10, 10])
+
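+        # Patch _get_graph_runner_block_tables to return its block-table
+        # argument unchanged, keeping the test independent of graph-mode
+        # padding.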
+        with patch.object(builder,
+                          "_get_graph_runner_block_tables",
+                          side_effect=lambda x, y: y):
+            common_attn_metadata = AscendCommonAttentionMetadata(
+                query_start_loc=torch.tensor([0, 1, 2, 3]),
+                query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
+                seq_lens_cpu=torch.tensor([1, 1, 1]),
+                num_reqs=3,
+                num_actual_tokens=3,
+                max_query_len=1,
+                decode_token_per_req=torch.tensor([1, 1, 1]),
+                block_table_tensor=torch.zeros((10, 10)),
+                slot_mapping_cpu=torch.tensor(range(20)),
+                actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+                positions=torch.tensor([1, 1]),
+                attn_mask=torch.ones((15, 15)),
+                spec_attn_mask=None,
+                attn_state=AscendAttentionState.ChunkedPrefill)
+
+            metadata = builder.build(common_attn_metadata, model)
+
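+        # All three requests land in the decode bucket; no prefill metadata
+        # should be produced.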
+        self.assertIsInstance(metadata, AscendMLATorchairMetadata)
+        self.assertEqual(metadata.num_input_tokens, 0)
+        self.assertEqual(metadata.num_actual_tokens, 3)
+        self.assertEqual(metadata.num_decodes, 3)
+        self.assertEqual(metadata.num_decode_tokens, 3)
+        self.assertEqual(metadata.num_prefills, 0)
+        self.assertEqual(metadata.attn_state,
+                         AscendAttentionState.ChunkedPrefill)
+        self.assertIsNone(metadata.prefill)
+        self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
+        self.assertEqual(metadata.block_tables.shape[0], 3)
+        self.assertEqual(metadata.block_tables.shape[1], 10)
+        self.assertEqual(metadata.seq_lens.shape[0], 3)
+        self.assertEqual(metadata.slot_mapping.shape[0], 3)
+        self.assertEqual(metadata.query_start_loc.shape[0], 4)
+
 
 class TestAscendMLATorchairImpl(TestBase):
 