Commit 4f01f49

[main][bugfix] Fix bugs and refactor cached mask generation logic
Signed-off-by: rjg-lyh <[email protected]>
1 parent 0f81e03 commit 4f01f49

4 files changed: +97 -136 lines changed

tests/ut/attention/test_attention_mask.py

Lines changed: 49 additions & 72 deletions
@@ -28,23 +28,32 @@ def test_init_attention_mask_builder(self):
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                          torch.float16)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                          (1024, 1024))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))

-        # generate attention_mask_builder with int8
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
-                                                      dtype=torch.int8)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
+        # generate attention_mask_builder with bfloat16
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.int8)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
+                         torch.bfloat16)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (512, 512))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.int8))
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+    def test_get_mask_scale_factor(self):
+        # supported data types
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
+        # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
+        # Otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)

     def test_get_attn_mask(self):
         # if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ def test_get_splitfuse_attn_mask(self):
         attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                       dtype=torch.float16)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 20, 100]),
+            position=torch.tensor([7, 8, 9, 18, 19, 99]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (6, 100))
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)

         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[2048],
-            query_lens=[1024],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 3000, 2000]),
+            position=torch.tensor([7, 8, 9, 2999, 1999]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1024, 2048))
+        self.assertEqual(attn_mask.shape, (5, 3000))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
+
+        # splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
+        # otherwise raise ValueError
+        with self.assertRaises(ValueError):
+            attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+                seq_lens=torch.tensor([10, 20, 100]),
+                position=torch.tensor([7, 8, 9, 18, 19, 99]),
+                dtype=torch.int8,
+                device=torch.device("cpu"),
+            )
+
+    def test_mask_value_cleanliness(self):
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))

-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.int8)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
-            dtype=torch.int8,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(attn_mask.shape, (1, 512))
-
-    def test_use_multiple_masks(self):
-        max_seq_lens = [128, 512, 1024]
-        dtypes = [torch.float16, torch.bfloat16, torch.int8]
-        for max_seq_len, dtype in zip(max_seq_lens, dtypes):
-            with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
-                self._test_use_multiple_masks(max_seq_len, dtype)
-
-    def _test_use_multiple_masks(self, max_seq_len, dtype):
-        expected_mask_value = torch.finfo(
-            torch.float32).min if dtype == torch.float16 else 1
-        if dtype == torch.float16:
-            expected_splitfuse_mask_value = expected_mask_value
-        elif dtype == torch.bfloat16:
-            expected_splitfuse_mask_value = -10000
-        else:
-            assert dtype == torch.int8, "Unsupported dtype for attention mask"
-            expected_splitfuse_mask_value = -16
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
-                                                      dtype=dtype)
-
-        splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[max_seq_len],
-            query_lens=[max_seq_len],
-            position=torch.tensor([0]),
-            dtype=dtype,
+            seq_lens=torch.tensor([6]),
+            position=torch.tensor([3, 4, 5]),
+            dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
         self.assertEqual(
-            splitfuse_attn_mask[0][-1],
-            torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
-        self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
-        self.assertEqual(attn_mask[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
+            attn_mask[-2][-1],
+            torch.tensor(-10000, dtype=torch.bfloat16,
+                         device=attn_mask.device))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
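
Note (reviewer sketch, not part of the diff): the dtype convention these tests pin down is that a float16 cache bakes the large negative value in directly (it saturates to -inf) and uses a scale factor of 1, while a bfloat16 cache stores 1 and relies on get_mask_scale_factor() returning -10000 at lookup time. A minimal standalone illustration:

import torch

# Hedged illustration of the mask-value convention exercised by the tests above.
fp16_cached = torch.tensor(torch.finfo(torch.float32).min,
                           dtype=torch.float16)   # saturates to -inf
bf16_cached = torch.tensor(1, dtype=torch.bfloat16)

print(fp16_cached * 1)       # tensor(-inf, dtype=torch.float16)
print(bf16_cached * -10000)  # tensor(-10000., dtype=torch.bfloat16)
print(bf16_cached)           # still 1: out-of-place scaling keeps the cache clean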

vllm_ascend/attention/attention_mask.py

Lines changed: 35 additions & 46 deletions
@@ -44,61 +44,50 @@ def __init__(

         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
+
+    @staticmethod
+    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
+        if dtype == torch.float16:
+            mask_scale_factor = 1
+        elif dtype == torch.bfloat16:
+            mask_scale_factor = -10000
+        else:
+            raise ValueError(
+                "The current operation now only supports data types: torch.float16 and "
+                "torch.bfloat16. Please ensure the input is of one of these types."
+            )
+        return mask_scale_factor

     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        self._update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+        self._update_attn_cache(max_seq_len, dtype)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
+        ).to(device)

     def get_splitfuse_attn_mask(
         self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
+        seq_lens: torch.Tensor,
+        position: torch.Tensor,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> torch.Tensor:
+        if dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                "splitfuse_attn_mask now only supports bf16 and fp16")
         max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self._update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
-                attn_mask = attn_mask * -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
+        self._update_attn_cache(max_seq_len, dtype)
+        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+        # is not the same. Fix this in the future when kernel is ready.
+        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
+        attn_mask = torch.index_select(self.attn_mask_cache,
+                                       dim=0,
+                                       index=position)[:, :max_seq_len]
+        attn_mask *= mask_scale_factor
+        return attn_mask.contiguous().to(device, non_blocking=True)

-    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                           device: torch.device):
+    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
             self._seq_len_cached = seqlen
             self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
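
For context, a self-contained sketch of the lookup path this refactor converges on: update the cached causal mask for the requested dtype, select one cached row per query position, truncate columns to the longest sequence in the batch, and scale out of place. _toy_generate_attn_mask below is a simplified stand-in for this module's _generate_attn_mask, not the real helper:

import torch

# Simplified stand-in for _generate_attn_mask (assumption: the real helper masks
# strictly-future positions; fp16 bakes in -inf, bf16 stores 1 and is scaled later).
def _toy_generate_attn_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    value = torch.finfo(torch.float32).min if dtype == torch.float16 else 1
    mask = torch.zeros((max_seq_len, max_seq_len))
    future = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), 1)
    return mask.masked_fill(future, value).to(dtype)

def toy_splitfuse_lookup(cache: torch.Tensor, seq_lens: torch.Tensor,
                         position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the refactored get_splitfuse_attn_mask: one cached row per query
    # token, columns cut at the longest sequence, out-of-place scaling.
    scale = 1 if dtype == torch.float16 else -10000
    max_seq_len = int(max(seq_lens))
    mask = torch.index_select(cache, dim=0, index=position)[:, :max_seq_len]
    return (mask * scale).contiguous()

cache = _toy_generate_attn_mask(128, torch.bfloat16)
out = toy_splitfuse_lookup(cache,
                           seq_lens=torch.tensor([10, 20, 100]),
                           position=torch.tensor([7, 8, 9, 18, 19, 99]),
                           dtype=torch.bfloat16)
print(out.shape)  # torch.Size([6, 100]), matching the shape asserted in the unit test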

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 1 addition & 3 deletions
@@ -79,11 +79,10 @@ def __init__(self,
     def _make_attention_mask(
         self,
         seq_lens,
-        query_lens,
         position,
     ) -> torch.Tensor:
         return self.attn_mask_builder.get_splitfuse_attn_mask(
-            seq_lens, query_lens, position, self.dtype, self.device)
+            seq_lens, position, self.dtype, self.device)

     def propose(
         self,
@@ -247,7 +246,6 @@ def propose(
         positions = positions_cpu.to(device)
         attn_mask = self._make_attention_mask(
             seq_lens=attn_metadata.seq_lens,
-            query_lens=attn_metadata.max_query_len,
             position=positions,
         )
         attn_metadata.attn_mask = attn_mask

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 15 deletions
@@ -20,7 +20,6 @@
 import copy
 import gc
 import math
-import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
@@ -232,8 +231,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             vllm_config, device)
         self.attn_mask_builder = AttentionMaskBuilder(
-            min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+            self.model_config.max_model_len, self.dtype)

         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -808,12 +806,12 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":

         return tuple(tasks)

-    def _make_attention_mask(self, seq_lens, query_lens, position,
+    def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
             return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, query_lens, position, self.dtype, self.device)
+                seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
@@ -1082,16 +1080,17 @@ def _prepare_inputs(
             self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
             non_blocking=True)

-        self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:total_num_scheduled_tokens].copy_(
-            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens = self.seq_lens_cpu[:num_reqs]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]

         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
@@ -1106,11 +1105,9 @@ def _prepare_inputs(
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)

-        self.attn_mask = self._make_attention_mask(
-            seq_lens=seq_lens,
-            query_lens=num_scheduled_tokens,
-            position=self.positions[:num_input_tokens],
-            attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
+                                                   position=positions_cpu,
+                                                   attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore

         self.query_start_loc_np[0] = 0
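
Side note (not from the commit): the _prepare_inputs hunk switches to staging positions entirely through the CPU buffer — zero the padded tail on the CPU side, issue a single non-blocking copy covering all num_input_tokens, and keep CPU views (positions_cpu, seq_lens_cpu) so the attention mask can be built without touching device tensors. A rough sketch of that staging pattern, with toy CPU-only buffers standing in for the runner's persistent tensors:

import torch

def stage_positions(positions_cpu: torch.Tensor, positions_dev: torch.Tensor,
                    total_num_scheduled_tokens: int, num_input_tokens: int):
    # Zero the padded tail of the CPU staging buffer, then copy the whole
    # [0, num_input_tokens) range to the device buffer in one transfer.
    positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
    positions_dev[:num_input_tokens].copy_(positions_cpu[:num_input_tokens],
                                           non_blocking=True)
    # Hand back both views; the CPU view is what mask construction consumes.
    return positions_cpu[:num_input_tokens], positions_dev[:num_input_tokens]

positions_cpu = torch.arange(8)                   # pretend pinned host buffer
positions_dev = torch.zeros(8, dtype=torch.long)  # would live on the NPU in practice
cpu_view, dev_view = stage_positions(positions_cpu, positions_dev,
                                     total_num_scheduled_tokens=5,
                                     num_input_tokens=8)
print(cpu_view)  # tensor([0, 1, 2, 3, 4, 0, 0, 0]) -- tail zeroed before the copy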
