1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/attention.py
@@ -93,6 +93,7 @@ def build(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
block_sparse_size: int = 1,
**kwargs,
) -> AttentionImpl[T]:
"""build."""
10 changes: 8 additions & 2 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -62,6 +62,7 @@ def __init__(
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
block_sparse_size: int = 1,
**kwargs,
):
super().__init__(
@@ -91,6 +92,7 @@ def __init__(
world_size, rank = get_tp_world_rank()
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
self.block_sparse_size = block_sparse_size

def forward(
self,
@@ -116,7 +118,7 @@ def forward(
kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
if attn_metadata.is_decoding:
max_q_seqlen = 1
max_q_seqlen = self.block_sparse_size
else:
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
@@ -213,6 +215,7 @@ def forward(
logit_softcapping=self.logit_softcapping,
sinks=learnable_sink,
causal=self.causal,
block_sparse_size=self.block_sparse_size,
)

return attn_output
@@ -528,9 +531,11 @@ def build(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
block_sparse_size: int = 1,
**kwargs,
) -> TritonAttentionImpl:
"""build."""
enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1
if use_flash_mla is True:
return FlashMLAImpl(num_heads,
head_size,
@@ -542,7 +547,7 @@
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
elif use_fa3 and not alibi and not learnable_sink:
elif enable_fa3:
return FA3Impl(num_heads,
head_size,
scale=scale,
@@ -563,4 +568,5 @@
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
block_sparse_size=block_sparse_size,
**kwargs)
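
Note on the dispatch above: FA3 has no block-sparse path, so `enable_fa3` now additionally requires `block_sparse_size == 1`; block-sparse decoding falls through to the Triton implementation, and FlashMLA is still chosen first when requested. A minimal sketch of that selection logic, written against the flags visible in this diff (the helper name `pick_impl` is made up for illustration, not part of lmdeploy):

```python
# Rough sketch of the builder's implementation selection (not the actual lmdeploy API).
def pick_impl(use_fa3: bool, alibi: bool, learnable_sink: bool,
              use_flash_mla: bool, block_sparse_size: int) -> str:
    """Return the attention implementation TritonAttentionBuilder.build would choose."""
    if use_flash_mla:
        return 'FlashMLAImpl'
    # FA3 has no block-sparse support, so block_sparse_size > 1 falls back to Triton.
    if use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1:
        return 'FA3Impl'
    return 'TritonAttentionImpl'

assert pick_impl(True, False, False, False, block_sparse_size=1) == 'FA3Impl'
assert pick_impl(True, False, False, False, block_sparse_size=4) == 'TritonAttentionImpl'
```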
14 changes: 12 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/flashattention.py
@@ -56,7 +56,8 @@ def _load_kv(ptrs, boundary_check: tl.constexpr):
def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask,
kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr,
logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr,
shared_kv: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr):
shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_DK1: tl.constexpr):
k_ptrs = tl.advance(k_ptrs, (0, loop_start))
v_ptrs = tl.advance(v_ptrs, (loop_start, 0))
if BLOCK_DK1:
@@ -77,7 +78,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
qk *= sm_scale
qk = softcapping(qk, logit_softcapping)
qk = qk * tl_log2(math.e)
qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
if block_sparse_size > 1:
offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size
qk_mask = (history_mask[:, None]) >= offs_mask[None, :]
else:
qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
if window_size > 0:
qk_mask = qk_mask and ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
qk = tl.where(
@@ -180,6 +185,7 @@ def _flash_prefill_fwd_kernel(
window_size: tl.constexpr,
logit_softcapping: tl.constexpr,
shared_kv: tl.constexpr,
block_sparse_size: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DK: tl.constexpr,
@@ -295,6 +301,7 @@ def _flash_prefill_fwd_kernel(
k_bound=k_bound0,
v_bound=v_bound0,
shared_kv=shared_kv,
block_sparse_size=block_sparse_size,
BLOCK_N=BLOCK_N,
BLOCK_DK1=BLOCK_DK1)

@@ -322,6 +329,7 @@ def _flash_prefill_fwd_kernel(
k_bound=k_bound1,
v_bound=v_bound1,
shared_kv=shared_kv,
block_sparse_size=block_sparse_size,
BLOCK_N=BLOCK_N,
BLOCK_DK1=BLOCK_DK1)
# epilogue
@@ -440,6 +448,7 @@ def flash_attention_fwd(
logit_softcapping: float = None,
sinks: Tensor = None,
causal: bool = True,
block_sparse_size: int = 1,
kv_layout: str = 'hsd',
):
"""Varlen flash Attention forward.
@@ -534,6 +543,7 @@ def grid(args):
window_size=window_size,
logit_softcapping=logit_softcapping,
shared_kv=shared_kv,
block_sparse_size=block_sparse_size,
BLOCK_DK=BLOCK_DK,
BLOCK_DK1=BLOCK_DK1,
BLOCK_DV=BLOCK_DV,
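
The kernel change coarsens the causal mask to block granularity: each kv offset is rounded down to the start of its `block_sparse_size` block before comparing against `history_mask`, so a query may attend to every token of any block whose first token it has reached (including up to `block_sparse_size - 1` positions ahead of itself inside its own block). A standalone PyTorch sketch of the same comparison, with made-up sizes:

```python
# Plain-PyTorch sketch of the block-sparse mask built in _prefill_fwd_inner (not Triton).
import torch

history_len = 4          # tokens already in the KV cache for this sequence
q_len = 4                # new query tokens
block_sparse_size = 2

kv_pos = torch.arange(history_len + q_len)     # absolute kv positions
q_pos = history_len + torch.arange(q_len)      # absolute query positions
history_mask = q_pos                           # plays the role of `history_mask` in the kernel

# Dense causal mask: each query sees kv positions up to its own index.
dense = history_mask[:, None] >= kv_pos[None, :]

# Block-sparse variant: round each kv position down to its block start,
# so a whole block becomes visible as soon as its first token is visible.
kv_block_start = kv_pos // block_sparse_size * block_sparse_size
sparse = history_mask[:, None] >= kv_block_start[None, :]

print(dense.int())
print(sparse.int())   # visibility now ends on block boundaries of size `block_sparse_size`
```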
52 changes: 30 additions & 22 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -59,6 +59,7 @@ def _fwd_grouped_split_kernel(
stride_od: tl.constexpr,
stride_boffb,
kv_group_num: tl.constexpr,
seq_len: tl.constexpr,
window_size: tl.constexpr,
head_size: tl.constexpr,
head_size_v: tl.constexpr,
@@ -74,18 +75,20 @@
):
"""First step kernel of split k attention."""
cur_batch = tl.program_id(2)
cur_kv_head = tl.program_id(0)
tile_id = tl.program_id(0)
split_k_id = tl.program_id(1)

if BLOCK_H < kv_group_num:
HEAD_PER_CTA: tl.constexpr = BLOCK_H
else:
HEAD_PER_CTA: tl.constexpr = kv_group_num
cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len
TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H)
subtile_id = tile_id % TILES_PER_GROUP
cur_kv_head = tile_id // TILES_PER_GROUP
offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H)
cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
cur_token = cur_batch * seq_len + offs_h // kv_group_num

mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num
mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len)
mask_h = mask_h & (cur_head < num_heads_q)
if BLOCK_H < kv_group_num:
cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

q_seqlen = 1
kv_seqlen = tl.load(KV_seqlens + cur_batch)
@@ -104,7 +107,7 @@ def _fwd_grouped_split_kernel(
off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)

off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)

k_ptrs = K + off_k
@@ -114,7 +117,7 @@ def _fwd_grouped_split_kernel(
offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)
mask_d1 = offs_d1 < head_size
offs_d1 = offs_d1 % head_size
off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
q1 = tl.load(Q + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)
off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
k1_ptrs = K + off_k1
@@ -196,11 +199,11 @@ def _fwd_grouped_split_kernel(

# initialize pointers to output
if loop_end > loop_start:
off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
offs_dv[None, :] * stride_od)
tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])

off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
tl.store(Acc_out + off_meta, m_i, mask=mask_h)
tl.store(Acc_out + off_meta + 1, l_i, mask=mask_h)

@@ -588,7 +591,9 @@ def _get_block_d(Lk):
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = kv_seqlens.shape[0], q.shape[-2]
kv_group_num = q.shape[-2] // k.shape[h_dim]
num_tokens = q.shape[-3]
num_kv_heads = k.shape[h_dim]
kv_group_num = head // num_kv_heads

if sinks is not None:
assert sinks.is_contiguous()
@@ -601,20 +606,22 @@ def _get_block_d(Lk):
'might leads to bad performance. '
'Please reduce `block_size`.')

is_decoding = q.shape[-3] == kv_seqlens.size(0)
assert is_decoding, 'we only support decoding paged attention.'
valid = num_tokens % batch == 0
assert valid, 'we only support decoding paged attention.'
seq_len = num_tokens // batch

BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
p2_kv_group_num = triton.next_power_of_2(kv_group_num)
BLOCK_H = max(16, min(BLOCK, p2_kv_group_num))
grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num))
HEADS_PER_REQ = kv_group_num * seq_len
BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ)))
TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H)
grid_1 = TILES_PER_GROUP * num_kv_heads

SPLIT_K = _get_split_k(q.device.index, grid_1, batch)

if quant_policy != 4:
acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32)
else:
acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)

grid = (
grid_1,
@@ -704,6 +711,7 @@ def _get_block_d(Lk):
stride_od=acc.stride(-1),
stride_boffb=block_offsets.stride(0),
kv_group_num=kv_group_num,
seq_len=seq_len,
window_size=window_size,
head_size=Lk,
head_size_v=Lv,
@@ -720,7 +728,7 @@ def _get_block_d(Lk):
num_stages=num_stages)

num_warps = 4
grid = (batch, head)
grid = (num_tokens, head)
if quant_policy == 4:
Lv *= 2
BLOCK_DV *= 2
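
The paged-attention change generalizes split-K decoding from one query token per request to `seq_len = num_tokens // batch` tokens (the block-sparse decode case). Program axis 0 now enumerates tiles of `(query head, token)` pairs per kv head instead of kv heads directly. A host-side sketch of the resulting index math with small made-up sizes (in the real launch, `BLOCK_H` comes from `next_power_of_2(HEADS_PER_REQ)` clamped to at least 16, and out-of-range lanes are dropped by `mask_h`):

```python
# Host-side sketch of the new tile -> (query head, token) mapping used by
# _fwd_grouped_split_kernel (illustrative values, not the kernel itself).
from math import ceil

num_kv_heads = 2
kv_group_num = 4                 # query heads per kv head
seq_len = 2                      # decode tokens per request (block_sparse_size)
BLOCK_H = 16

HEADS_PER_REQ = kv_group_num * seq_len            # (head, token) pairs per kv head
TILES_PER_GROUP = ceil(HEADS_PER_REQ / BLOCK_H)
grid_1 = TILES_PER_GROUP * num_kv_heads           # size of program axis 0

for tile_id in range(grid_1):
    subtile_id = tile_id % TILES_PER_GROUP
    cur_kv_head = tile_id // TILES_PER_GROUP
    for lane in range(BLOCK_H):
        offs_h = subtile_id * BLOCK_H + lane
        if offs_h >= HEADS_PER_REQ:
            continue                              # roughly what mask_h removes in the kernel
        cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
        token_in_req = offs_h // kv_group_num     # added to `cur_batch * seq_len`
        print(tile_id, cur_kv_head, cur_head, token_in_req)
```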
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/nn/attention.py
@@ -37,6 +37,7 @@ def __init__(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
block_sparse_size: int = 1,
**kwargs,
):
super().__init__()
@@ -61,6 +62,7 @@ def __init__(
causal=causal,
use_flash_mla=use_flash_mla,
learnable_sink=learnable_sink,
block_sparse_size=block_sparse_size,
**kwargs,
)

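
The nn-layer change only threads the new argument through to the backend builder. A hypothetical usage sketch (the keyword values are placeholders and the surrounding arguments are assumed to follow the existing `Attention` signature):

```python
# Hypothetical usage sketch: passing the new flag through the nn layer.
from lmdeploy.pytorch.nn.attention import Attention

attn = Attention(
    num_heads=32,
    head_size=128,
    num_kv_heads=8,
    causal=True,
    block_sparse_size=4,   # new: forwarded to the backend's AttentionBuilder.build()
)
```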
74 changes: 68 additions & 6 deletions tests/pytorch/kernel/test_flash_attention.py
@@ -11,16 +11,18 @@ def _conti_input(data, q_seqlens):


def _make_bias(q_seqlens, history_lens, neg_val, causal):
batch_size = q_seqlens.shape[0]
kv_seqlens = q_seqlens + history_lens
max_seq_len = q_seqlens.max().item()
max_kv_len = kv_seqlens.max().item()
if causal:
seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens]
for r, l in zip(seq_ranges, q_seqlens):
r[l:] = -max_kv_len
seq_ranges = torch.stack(seq_ranges, dim=0).cuda()
kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens]
kv_ranges = torch.stack(kv_ranges, 0).cuda()
seq_ranges = torch.arange(max_seq_len).cuda()
seq_ranges = seq_ranges.repeat(batch_size, 1)
seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

kv_ranges = torch.arange(max_kv_len).cuda()
kv_ranges = kv_ranges.repeat(batch_size, 1)

mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
return mask.float() * neg_val
else:
@@ -31,6 +33,27 @@ def _make_bias(q_seqlens, history_lens, neg_val, causal):
return (~mask).float() * neg_val


def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,
block_sparse_size: int):
"""Make block sparse bias."""
batch_size = q_seqlens.shape[0]
kv_seqlens = q_seqlens + history_lens
max_seq_len = q_seqlens.max().item()
max_kv_len = kv_seqlens.max().item()

seq_ranges = torch.arange(max_seq_len).cuda()
seq_ranges = seq_ranges // block_sparse_size * block_sparse_size
seq_ranges = seq_ranges.repeat(batch_size, 1)
seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)

kv_ranges = torch.arange(max_kv_len).cuda()
kv_ranges = kv_ranges // block_sparse_size * block_sparse_size
kv_ranges = kv_ranges.repeat(batch_size, 1)

mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
return mask.float() * neg_val


def _naive_attention(batched_q, batched_kv, bias, sinks=None):
batched_k, batched_v = batched_kv

@@ -283,3 +306,42 @@ def test_sinks(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv
sinks=sinks,
causal=causal)
torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5)

# block sparse attention
@pytest.fixture
def block_sparse_size(self):
yield 4

@pytest.fixture
def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size):
neg_val = -1e30
yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size)

@pytest.fixture
def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask):
yield _naive_attention(batched_q, batched_kv, block_sparse_mask)

@pytest.mark.parametrize('head_dim_k', [32], indirect=True)
@pytest.mark.parametrize('head_dim_v', [32], indirect=True)
@pytest.mark.parametrize('num_heads_q', [8], indirect=True)
@pytest.mark.parametrize('num_heads_k', [2], indirect=True)
@pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True)
def test_block_sparse_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv_seqlens,
head_dim_v, block_sparse_size, block_sparse_gt):
from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attention_fwd
max_seq_len = q_seqlens.max().item()

conti_k, conti_v = conti_kv
out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v)
flash_attention_fwd(conti_q,
conti_k,
conti_v,
out,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_seqlen=max_seq_len,
block_sparse_size=block_sparse_size)
gt = _conti_input(block_sparse_gt, q_seqlens)
torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5)
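
For a quick look at what the new reference mask produces, a CPU-only sketch that mirrors `_make_block_sparse_bias` for a single sequence (padding and `.cuda()` handling omitted; sizes are made up):

```python
# Standalone sketch of the block-sparse reference mask, single sequence, CPU only.
import torch

q_seqlens = torch.tensor([4])
history_lens = torch.tensor([4])
block_sparse_size = 2
neg_val = -1e30

kv_len = (q_seqlens + history_lens).max().item()
seq = torch.arange(q_seqlens.max().item()) // block_sparse_size * block_sparse_size
kv = torch.arange(kv_len) // block_sparse_size * block_sparse_size

# True (-> neg_val) where a query token may NOT attend to a kv token.
mask = kv[None, None, :] - seq[None, :, None] > history_lens[:, None, None]
print((mask.float() * neg_val)[0])   # visibility changes only on block boundaries
```

The new case itself can be run on a CUDA machine with `pytest tests/pytorch/kernel/test_flash_attention.py -k block_sparse`.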