
Commit c0c5517

[main][bugfix] Fix bugs and refactor cached mask generation logic

Signed-off-by: rjg-lyh <[email protected]>
1 parent 3f867ee

File tree: 4 files changed, +77 −91 lines


tests/ut/attention/test_attention_mask.py

Lines changed: 33 additions & 28 deletions
```diff
@@ -28,23 +28,30 @@ def test_init_attention_mask_builder(self):
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                          torch.float16)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                          (1024, 1024))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
-        # generate attention_mask_builder with int8
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
-                                                      dtype=torch.int8)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
+        # generate attention_mask_builder with bfloat16
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.int8)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
+                         torch.bfloat16)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (512, 512))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.int8))
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+    def test_get_mask_scale_factor(self):
+        # supported data types
+        self.assertEqual(AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
+        self.assertEqual(AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
+        # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
+        # Otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
 
     def test_get_attn_mask(self):
         # if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,34 +84,32 @@ def test_get_splitfuse_attn_mask(self):
         attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                       dtype=torch.float16)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 20, 100]),
+            position=torch.tensor([7, 8, 9, 18, 19, 99]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (6, 100))
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
 
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[2048],
-            query_lens=[1024],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 3000, 2000]),
+            position=torch.tensor([7, 8, 9, 2999, 1999]),
            dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1024, 2048))
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.int8)
-        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
-            dtype=torch.int8,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (5, 3000))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
+
+        # splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
+        # otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+                seq_lens=torch.tensor([10, 20, 100]),
+                position=torch.tensor([7, 8, 9, 18, 19, 99]),
+                dtype=torch.int8,
+                device=torch.device("cpu"),
+            )
 
     def test_use_multiple_masks(self):
         max_seq_lens = [128, 512, 1024]
```
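
The updated assertions encode a simple shape contract: the split-fuse mask has one row per scheduled query position and one column per token of the longest sequence in the batch. The snippet below is a minimal sketch of that contract, not code from the repository; the helper name is hypothetical.

```python
import torch


def expected_splitfuse_mask_shape(seq_lens: torch.Tensor,
                                  position: torch.Tensor) -> tuple:
    # One row per scheduled query position, one column per token of the
    # longest sequence in the batch.
    return (position.numel(), int(seq_lens.max()))


# Mirrors the two positive cases asserted in test_get_splitfuse_attn_mask.
assert expected_splitfuse_mask_shape(
    torch.tensor([10, 20, 100]),
    torch.tensor([7, 8, 9, 18, 19, 99])) == (6, 100)
assert expected_splitfuse_mask_shape(
    torch.tensor([10, 3000, 2000]),
    torch.tensor([7, 8, 9, 2999, 1999])) == (5, 3000)
```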

vllm_ascend/attention/attention_mask.py

Lines changed: 32 additions & 46 deletions
```diff
@@ -44,61 +44,47 @@ def __init__(
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
+
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        if dtype == torch.float16:
+            mask_scale_factor = 1
+        elif dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        else:
+            raise ValueError("The current operation now only supports data types: torch.float16 and "
+                             "torch.bfloat16. Please ensure the input is of one of these types.")
+        return mask_scale_factor
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self._update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self._update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)
 
     def get_splitfuse_attn_mask(
         self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
+        seq_lens: torch.Tensor,
+        position: torch.Tensor,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> torch.Tensor:
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError("splitfuse_attn_mask now only supports bf16 and fp16")
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self._update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
-                attn_mask = attn_mask * -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)
 
-    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                           device: torch.device):
+    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
```
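
After the refactor, the cached-mask path reduces to: keep one triangular mask cache, grow it on demand, select the rows for the current positions, and rescale per dtype. The sketch below reproduces that flow standalone. It assumes the cache fills the strict upper triangle with -inf for float16 and 1 for bfloat16, consistent with the unit tests above; that fill rule and both function names are assumptions, not the repository's `_generate_attn_mask` implementation.

```python
import torch


def generate_attn_mask_sketch(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Assumed cache contents: -inf above the diagonal for float16, 1 for
    # bfloat16 (matching the asserts in test_init_attention_mask_builder).
    fill = float("-inf") if dtype == torch.float16 else 1.0
    upper = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill_(upper, fill)


def splitfuse_lookup_sketch(cache: torch.Tensor, position: torch.Tensor,
                            max_seq_len: int, scale: float) -> torch.Tensor:
    # Select one cached row per query position, truncate columns to the
    # longest sequence in the batch, then rescale for the target dtype.
    rows = torch.index_select(cache, dim=0, index=position)[:, :max_seq_len]
    return (rows * scale).contiguous()


cache = generate_attn_mask_sketch(8, torch.bfloat16)
mask = splitfuse_lookup_sketch(cache, torch.tensor([2, 3, 7]), 8, -10000)
print(mask.shape)  # torch.Size([3, 8])
```

Since the cache is now tracked by dtype only rather than by device, the trailing `.to(device, non_blocking=True)` in the real method is the single host-to-device transfer per call.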

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -79,11 +79,10 @@ def __init__(self,
     def _make_attention_mask(
         self,
         seq_lens,
-        query_lens,
         position,
     ) -> torch.Tensor:
         return self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, query_lens, position, self.dtype, self.device)
+            seq_lens, position, self.dtype, self.device)
 
     def propose(
         self,
@@ -247,7 +246,6 @@ def propose(
         positions = positions_cpu.to(device)
         attn_mask = self._make_attention_mask(
             seq_lens=attn_metadata.seq_lens,
-            query_lens=attn_metadata.max_query_len,
             position=positions,
         )
         attn_metadata.attn_mask = attn_mask
```
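
With `query_lens` gone, the proposer's wrapper is a thin pass-through to the builder. As a usage sketch, the values below come from the updated unit test and are run directly against the builder rather than through the proposer; the import path simply follows the file path in this commit.

```python
import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)
mask = builder.get_splitfuse_attn_mask(
    seq_lens=torch.tensor([10, 20, 100]),          # per-request sequence lengths
    position=torch.tensor([7, 8, 9, 18, 19, 99]),  # flat positions of scheduled tokens
    dtype=torch.float16,
    device=torch.device("cpu"),
)
print(mask.shape)  # (6, 100), as asserted in test_get_splitfuse_attn_mask
```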

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 14 deletions
```diff
@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 import types
 from contextlib import contextmanager, nullcontext
@@ -228,8 +227,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
-            min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+            self.model_config.max_model_len, self.dtype)
 
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -847,12 +845,12 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
 
         return tuple(tasks)
 
-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
@@ -1124,16 +1122,17 @@ def _process_reqs(
             self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
             non_blocking=True)
 
-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
         positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1169,11 +1168,9 @@ def _process_reqs(
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
 
-        self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
-            query_lens=num_scheduled_tokens,
-            position=positions,
-            attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                   position=positions_cpu,
+                                                   attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
         self.query_start_loc_np[0] = 0
```
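
The `_process_reqs` change has two visible effects: the padded tail is now zeroed on the CPU positions buffer before a single host-to-device copy, and the mask is built from the CPU views (`positions_cpu`, `seq_lens_cpu`) so only the finished mask crosses to the device. A minimal sketch of that buffer handling follows, with hypothetical sizes and plain CPU tensors standing in for the runner's buffers.

```python
import torch

num_scheduled_tokens, num_input_tokens = 6, 8    # hypothetical token counts
positions_cpu = torch.arange(num_input_tokens)   # stand-in for self.positions_cpu
positions_dev = torch.empty_like(positions_cpu)  # stand-in for the device buffer

# Zero the padding on the host buffer first, then do one contiguous copy.
positions_cpu[num_scheduled_tokens:num_input_tokens].zero_()
positions_dev[:num_input_tokens].copy_(positions_cpu[:num_input_tokens])

print(positions_cpu.tolist())  # [0, 1, 2, 3, 4, 5, 0, 0]
# The attention mask is then built from positions_cpu / seq_lens_cpu and moved
# to the device inside get_splitfuse_attn_mask.
```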
