Merged
121 changes: 49 additions & 72 deletions tests/ut/attention/test_attention_mask.py
@@ -28,23 +28,32 @@ def test_init_attention_mask_builder(self):
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.float16)
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))

# generate attention_mask_builder with int8
attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
dtype=torch.int8)
self.assertEqual(attention_mask_builder._seq_len_cached, 512)
# generate attention_mask_builder with bfloat16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.int8)
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(512, 512))
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(1, dtype=torch.int8))
torch.tensor(1, dtype=torch.bfloat16))

def test_get_mask_scale_factor(self):
# supported data types
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
# Otherwise raise ValueError
with self.assertRaises(ValueError):
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)

def test_get_attn_mask(self):
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ def test_get_splitfuse_attn_mask(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[512],
query_lens=[512],
position=torch.tensor([0]),
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1, 512))
self.assertEqual(attn_mask.shape, (6, 100))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)

attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[2048],
query_lens=[1024],
position=torch.tensor([0]),
seq_lens=torch.tensor([10, 3000, 2000]),
position=torch.tensor([7, 8, 9, 2999, 1999]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1024, 2048))
self.assertEqual(attn_mask.shape, (5, 3000))
self.assertEqual(attention_mask_builder._seq_len_cached, 3000)

# splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
# otherwise raise ValueError
with self.assertRaises(ValueError):
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.int8,
device=torch.device("cpu"),
)

def test_mask_value_cleanliness(self):
Collaborator Author: I have added this test here. @ApsarasX

Collaborator: OK, I see.
attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))

attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.int8)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[512],
query_lens=[512],
position=torch.tensor([0]),
dtype=torch.int8,
seq_lens=torch.tensor([6]),
position=torch.tensor([3, 4, 5]),
dtype=torch.bfloat16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (1, 512))

def test_use_multiple_masks(self):
max_seq_lens = [128, 512, 1024]
dtypes = [torch.float16, torch.bfloat16, torch.int8]
for max_seq_len, dtype in zip(max_seq_lens, dtypes):
with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
self._test_use_multiple_masks(max_seq_len, dtype)

def _test_use_multiple_masks(self, max_seq_len, dtype):
expected_mask_value = torch.finfo(
torch.float32).min if dtype == torch.float16 else 1
if dtype == torch.float16:
expected_splitfuse_mask_value = expected_mask_value
elif dtype == torch.bfloat16:
expected_splitfuse_mask_value = -10000
else:
assert dtype == torch.int8, "Unsupported dtype for attention mask"
expected_splitfuse_mask_value = -16

attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
dtype=dtype)

splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=[max_seq_len],
query_lens=[max_seq_len],
position=torch.tensor([0]),
dtype=dtype,
device=torch.device("cpu"),
)
self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
self.assertEqual(
splitfuse_attn_mask[0][-1],
torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(max_seq_len, max_seq_len))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))

attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
self.assertEqual(attn_mask[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))
self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(max_seq_len, max_seq_len))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(expected_mask_value, dtype=dtype))
attn_mask[-2][-1],
torch.tensor(-10000, dtype=torch.bfloat16,
device=attn_mask.device))
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))
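
Note: the rewritten tests above pin down two properties of the new mask path: the bfloat16 cache stores plain 1s for masked positions, and only the tensor returned by `get_splitfuse_attn_mask` carries the -10000 value, so the cache itself is never mutated. A minimal sketch of that cache-plus-gather pattern, assuming a plain triu cache and illustrative helper names (not the module's API):

```python
import torch

def build_causal_cache(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # 1 marks a future (masked) position, 0 a visible one -- the bf16 layout
    # the tests above assert on.
    return torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=dtype), diagonal=1)

def gather_splitfuse_mask(cache: torch.Tensor, position: torch.Tensor,
                          max_seq_len: int, scale: int) -> torch.Tensor:
    # One cached row per query token, cropped to the longest sequence.
    rows = torch.index_select(cache, dim=0, index=position)[:, :max_seq_len]
    # Out-of-place scale keeps the cache clean (1 -> -10000 for bf16).
    return rows * scale

cache = build_causal_cache(6, torch.bfloat16)
mask = gather_splitfuse_mask(cache, torch.tensor([3, 4, 5]), 6, -10000)
assert mask.shape == (3, 6)
assert cache[-2][-1] == torch.tensor(1, dtype=torch.bfloat16)      # cache untouched
assert mask[-2][-1] == torch.tensor(-10000, dtype=torch.bfloat16)  # returned mask scaled
```

The shape assertions follow the same rule: one row per entry in `position`, cropped to `max(seq_lens)` columns — e.g. 6 positions against seq_lens [10, 20, 100] give a (6, 100) mask.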
81 changes: 35 additions & 46 deletions vllm_ascend/attention/attention_mask.py
@@ -44,61 +44,50 @@ def __init__(

self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
self.splitfuse_mask_value = -10000

@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
if dtype == torch.float16:
mask_scale_factor = 1
elif dtype == torch.bfloat16:
mask_scale_factor = -10000
else:
raise ValueError(
"The current operation now only supports data types: torch.float16 and "
"torch.bfloat16. Please ensure the input is of one of these types."
)
return mask_scale_factor

def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
device: torch.device):
self._update_attn_cache(max_seq_len, dtype, device)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
self._update_attn_cache(max_seq_len, dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(device)

def get_splitfuse_attn_mask(
self,
seq_lens,
query_lens,
position,
dtype,
device,
seq_lens: torch.Tensor,
position: torch.Tensor,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self._seq_len_cached:
self._update_attn_cache(max_seq_len, dtype, device)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
if self.attn_mask_cache.numel(
) > 1 and self.attn_mask_cache[0][1] > 0:
attn_mask = self.get_attn_mask( # type: ignore
max_seq_len, dtype, device)
# Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
attn_mask = attn_mask * -10000
else:
attn_mask = self.attn_mask_cache
return torch.index_select(attn_mask, dim=0,
index=position)[:, :max_seq_len]
total_q_len = sum(query_lens)
attn_mask = torch.zeros((total_q_len, max_seq_len),
dtype=dtype,
device="cpu")
current_row = 0
for i in range(len(query_lens)):
seq_len = seq_lens[i]
q_len = query_lens[i]
context_len = seq_len - q_len

assert context_len >= 0
attn_mask[current_row:current_row + q_len,
context_len:] = self.splitfuse_mask_value
right_tensor = attn_mask[current_row:current_row + q_len,
context_len:seq_len]
right_tensor.masked_fill_(
right_tensor.tril() == self.splitfuse_mask_value, 0)
current_row += q_len

return attn_mask.to(device, non_blocking=True)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)

def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
device: torch.device):
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:
self._seq_len_cached = seqlen
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
if self.attn_mask_cache.device != device:
self.attn_mask_cache = self.attn_mask_cache.to(device)
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
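
Note: with the device argument dropped from `_update_attn_cache`, the cache stays device-agnostic and only grows or re-casts; the getters handle the final `.to(device)` copy. A hedged sketch of that growth rule, with `_generate_attn_mask` replaced by a plain causal stub (the real generator also encodes the dtype-specific mask value):

```python
import torch

def _generate_attn_mask(seqlen: int, dtype: torch.dtype) -> torch.Tensor:
    # Stand-in for the real generator: a plain causal upper-triangular mask.
    return torch.triu(torch.ones(seqlen, seqlen, dtype=dtype), diagonal=1)

class TinyMaskCache:
    """Toy version of the builder's caching rule, for illustration only."""

    def __init__(self, max_seq_len: int, dtype: torch.dtype):
        self._seq_len_cached = max_seq_len
        self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype)

    def update(self, seqlen: int, dtype: torch.dtype) -> None:
        # Regenerate only when a longer sequence shows up...
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
        # ...and re-cast only when the requested dtype differs.
        if self.attn_mask_cache.dtype != dtype:
            self.attn_mask_cache = self.attn_mask_cache.to(dtype)

cache = TinyMaskCache(1024, torch.float16)
cache.update(512, torch.float16)            # shorter request: nothing changes
assert cache._seq_len_cached == 1024
cache.update(3000, torch.bfloat16)          # longer request: grow and switch dtype
assert cache.attn_mask_cache.shape == (3000, 3000)
assert cache.attn_mask_cache.dtype == torch.bfloat16
```

This matches the `_seq_len_cached` values the tests expect: 1024 stays 1024 for shorter requests, then jumps to 3000 once a longer sequence arrives.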
4 changes: 1 addition & 3 deletions vllm_ascend/worker/eagle_proposer_v1.py
@@ -79,11 +79,10 @@ def __init__(self,
def _make_attention_mask(
self,
seq_lens,
query_lens,
position,
) -> torch.Tensor:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
seq_lens, position, self.dtype, self.device)

def propose(
self,
@@ -247,7 +246,6 @@ def propose(
positions = positions_cpu.to(device)
attn_mask = self._make_attention_mask(
seq_lens=attn_metadata.seq_lens,
query_lens=attn_metadata.max_query_len,
position=positions,
)
attn_metadata.attn_mask = attn_mask
27 changes: 12 additions & 15 deletions vllm_ascend/worker/model_runner_v1.py
@@ -20,7 +20,6 @@
import copy
import gc
import math
import os
import time
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -233,8 +232,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
vllm_config, device)
self.attn_mask_builder = AttentionMaskBuilder(
min(self.model_config.max_model_len,
int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
self.model_config.max_model_len, self.dtype)

# Set up speculative decoding.
self.use_aux_hidden_state_outputs = False
@@ -817,12 +815,12 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":

return tuple(tasks)

def _make_attention_mask(self, seq_lens, query_lens, position,
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, query_lens, position, self.dtype, self.device)
seq_lens, position, self.dtype, self.device)
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
@@ -1123,16 +1121,17 @@ def _prepare_inputs(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)

self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)

self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens)

self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
seq_lens = self.seq_lens_cpu[:num_reqs]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]

block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
@@ -1147,11 +1146,9 @@
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)

self.attn_mask = self._make_attention_mask(
seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=self.positions[:num_input_tokens],
attn_state=attn_state)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.attn_state = attn_state # type: ignore

self.query_start_loc_np[0] = 0
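
Note: the `_prepare_inputs` change feeds the mask builder CPU-side tensors (`seq_lens_cpu`, `positions_cpu`), so the chunked-prefill mask is gathered on the host and shipped to the NPU in one non-blocking copy. A small sketch of that bookkeeping, with illustrative values and names (not the runner's real attributes):

```python
import numpy as np
import torch

# Per-request counts the scheduler already tracks (illustrative values).
num_computed_tokens_cpu = np.array([5, 15], dtype=np.int64)   # tokens already processed
num_scheduled_tokens = np.array([3, 4], dtype=np.int64)       # tokens scheduled this step

# seq_lens is assembled on the CPU from those counts...
seq_lens_cpu = torch.from_numpy(num_computed_tokens_cpu + num_scheduled_tokens)
# ...and positions for the newly scheduled tokens also stay on the CPU.
positions_cpu = torch.tensor([5, 6, 7, 15, 16, 17, 18])

# These are the tensors _make_attention_mask(seq_lens=..., position=...) now
# receives for the chunked-prefill branch; the resulting mask is moved to the
# device inside the builder via a single non_blocking copy.
assert seq_lens_cpu.tolist() == [8, 19]
assert positions_cpu.numel() == int(num_scheduled_tokens.sum())
```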