From 2effecce8d34329d987d47e7b616cb6e15d40047 Mon Sep 17 00:00:00 2001
From: rjg-lyh <1318825571@qq.com>
Date: Tue, 19 Aug 2025 18:28:37 +0800
Subject: [PATCH] [main][bugfix] Fix bugs and refactor cached mask generation
 logic

Signed-off-by: rjg-lyh <1318825571@qq.com>
---
 tests/ut/attention/test_attention_mask.py | 121 +++++++++-------------
 vllm_ascend/attention/attention_mask.py   |  81 +++++++--------
 vllm_ascend/worker/eagle_proposer_v1.py   |   4 +-
 vllm_ascend/worker/model_runner_v1.py     |  27 +++--
 4 files changed, 97 insertions(+), 136 deletions(-)

diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py
index 312604f7ff..a87d21bd74 100644
--- a/tests/ut/attention/test_attention_mask.py
+++ b/tests/ut/attention/test_attention_mask.py
@@ -28,23 +28,32 @@ def test_init_attention_mask_builder(self):
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                          torch.float16)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                          (1024, 1024))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
-        # generate attention_mask_builder with int8
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
-                                                      dtype=torch.int8)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
+        # generate attention_mask_builder with bfloat16
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.int8)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
+                         torch.bfloat16)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (512, 512))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.int8))
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+    def test_get_mask_scale_factor(self):
+        # supported data types
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
+        # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
+        # Otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
 
     def test_get_attn_mask(self):
         # if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ def test_get_splitfuse_attn_mask(self):
         attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                       dtype=torch.float16)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 20, 100]),
+            position=torch.tensor([7, 8, 9, 18, 19, 99]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (6, 100))
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
 
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[2048],
-            query_lens=[1024],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 3000, 2000]),
+            position=torch.tensor([7, 8, 9, 2999, 1999]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1024, 2048))
+        self.assertEqual(attn_mask.shape, (5, 3000))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
+
+        # splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
+        # otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+                seq_lens=torch.tensor([10, 20, 100]),
+                position=torch.tensor([7, 8, 9, 18, 19, 99]),
+                dtype=torch.int8,
+                device=torch.device("cpu"),
+            )
+
+    def test_mask_value_cleanliness(self):
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
 
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.int8)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
-            dtype=torch.int8,
+            seq_lens=torch.tensor([6]),
+            position=torch.tensor([3, 4, 5]),
+            dtype=torch.bfloat16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
-
-    def test_use_multiple_masks(self):
-        max_seq_lens = [128, 512, 1024]
-        dtypes = [torch.float16, torch.bfloat16, torch.int8]
-        for max_seq_len, dtype in zip(max_seq_lens, dtypes):
-            with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
-                self._test_use_multiple_masks(max_seq_len, dtype)
-
-    def _test_use_multiple_masks(self, max_seq_len, dtype):
-        expected_mask_value = torch.finfo(
-            torch.float32).min if dtype == torch.float16 else 1
-        if dtype == torch.float16:
-            expected_splitfuse_mask_value = expected_mask_value
-        elif dtype == torch.bfloat16:
-            expected_splitfuse_mask_value = -10000
-        else:
-            assert dtype == torch.int8, "Unsupported dtype for attention mask"
-            expected_splitfuse_mask_value = -16
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
-                                                      dtype=dtype)
-
-        splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[max_seq_len],
-            query_lens=[max_seq_len],
-            position=torch.tensor([0]),
-            dtype=dtype,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
         self.assertEqual(
-            splitfuse_attn_mask[0][-1],
-            torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
-        self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
-        self.assertEqual(attn_mask[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
+            attn_mask[-2][-1],
+            torch.tensor(-10000, dtype=torch.bfloat16,
+                         device=attn_mask.device))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index 11f1115825..a0e63349b1 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -44,61 +44,50 @@ def __init__(
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
+
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        if dtype == torch.float16:
+            mask_scale_factor = 1
+        elif dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        else:
+            raise ValueError(
+                "The current operation now only supports data types: torch.float16 and "
+                "torch.bfloat16. Please ensure the input is of one of these types."
+            )
+        return mask_scale_factor
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self._update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self._update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)
 
     def get_splitfuse_attn_mask(
         self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
+        seq_lens: torch.Tensor,
+        position: torch.Tensor,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> torch.Tensor:
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self._update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
-                attn_mask = attn_mask * -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)
 
-    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                           device: torch.device):
+    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py
index 895649327c..479ef1ddf2 100644
--- a/vllm_ascend/worker/eagle_proposer_v1.py
+++ b/vllm_ascend/worker/eagle_proposer_v1.py
@@ -79,11 +79,10 @@ def __init__(self,
     def _make_attention_mask(
         self,
         seq_lens,
-        query_lens,
         position,
     ) -> torch.Tensor:
         return self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, query_lens, position, self.dtype, self.device)
+            seq_lens, position, self.dtype, self.device)
 
     def propose(
         self,
@@ -247,7 +246,6 @@ def propose(
         positions = positions_cpu.to(device)
         attn_mask = self._make_attention_mask(
             seq_lens=attn_metadata.seq_lens,
-            query_lens=attn_metadata.max_query_len,
             position=positions,
         )
         attn_metadata.attn_mask = attn_mask
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 03486b0555..33fcb93620 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -233,8 +232,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
-            min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+            self.model_config.max_model_len, self.dtype)
 
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -817,12 +815,12 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
 
         return tuple(tasks)
 
-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
@@ -1123,16 +1121,17 @@ def _prepare_inputs(
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
 
-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1147,11 +1146,9 @@ def _prepare_inputs(
 
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
-            query_lens=num_scheduled_tokens,
-            position=self.positions[:num_input_tokens],
-            attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                   position=positions_cpu,
+                                                   attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
         self.query_start_loc_np[0] = 0
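
Below is a minimal usage sketch of the refactored AttentionMaskBuilder interface, assuming the patched vllm_ascend package is importable. The calls, arguments, and expected values mirror the updated unit tests in this patch; the standalone script framing (CPU device, printed checks) is illustrative only, not part of the change.

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

# Build the cached causal mask once. With float16 the cache keeps -inf above
# the diagonal; with bfloat16 it keeps 1 and is scaled on the way out.
builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)

# After this patch, seq_lens and position are tensors and query_lens is gone.
attn_mask = builder.get_splitfuse_attn_mask(
    seq_lens=torch.tensor([10, 20, 100]),
    position=torch.tensor([7, 8, 9, 18, 19, 99]),
    dtype=torch.float16,
    device=torch.device("cpu"),
)
print(attn_mask.shape)  # torch.Size([6, 100]): one row per query token

# dtype handling is explicit: fp16 masks are scaled by 1, bf16 masks by
# -10000; any other dtype raises ValueError.
print(AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16))  # -10000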