Commit a640eeb

Support some ops for lightllm and lmdeploy.
1 parent af6dbbe commit a640eeb

File tree

csrc/extensions.cpp
deeplink_ext/patch_lightllm.py

2 files changed: +10 -78 lines

csrc/extensions.cpp

Lines changed: 5 additions & 49 deletions
@@ -341,38 +341,6 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
             b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
@@ -412,11 +380,11 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1, int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {}, int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
             numHeads, numKeyValueHeads, dim,
             block_table, block_size);
 }
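
Note: this hunk is the substantive C++ change. extPagedAttention now accepts an optional atten_mask ahead of actual_seq_lengths, block_table becomes optional, the trailing scalars get defaults, and the mask is forwarded to diopiPagedAttention. Below is a rough Python stand-in for the new argument order; parameter names are copied from the diff, but this is an illustration only, since how the defaults surface through the pybind11 binding is not shown in this commit.

    # Illustration only: a Python stand-in mirroring the new argument order and
    # defaults of extPagedAttention (not the real binding).
    def paged_attention_signature(out, q, k, v, atten_mask=None, actual_seq_lengths=(),
                                  numHeads=1, numKeyValueHeads=1, dim=1,
                                  block_table=None, block_size=1):
        # Old call sites passed actual_seq_lengths right after v; they now have
        # to fill the mask slot first (None when no mask is used).
        return (atten_mask, tuple(actual_seq_lengths), numHeads,
                numKeyValueHeads, dim, block_table, block_size)

    # e.g. no mask, two requests with sequence lengths 5 and 9:
    print(paged_attention_signature("out", "q", "k", "v", None, [5, 9],
                                    numHeads=32, numKeyValueHeads=32, dim=128,
                                    block_table="block_table", block_size=16))
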
@@ -501,18 +469,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");

deeplink_ext/patch_lightllm.py

Lines changed: 5 additions & 29 deletions
@@ -71,39 +71,15 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
         ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
     return out
 
-    # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-    #     batch = b_start_loc.shape[0]
-    #     scale = 1 / math.sqrt(dim)
-    #     mask_key_str = str(batch) + ":" + str(max_input_len)
-    #     if mask_key_str not in mask_cache:
-    #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-    #         mask = mask.repeat(batch, 1, 1)
-    #         mask = torch.logical_not(mask)
-    #         mask_cache[mask_key_str] = mask
-    #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
-
-    #     mask = mask_cache[mask_key_str]
-    #     ext.prompt_flash_attention(out, q, k, v,
-    #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
-    #     return out
-
-    # context_attention_pack.context_attention_fwd = (
-    #     # flash_context_attention
-    #     fused_context_attention
-    # )
     context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
 def patch_paged_token_attention_inference():
-    # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-    #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
-    #                         b_seq_len, block_table, q_head_num, kv_head_num,
-    #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-    #                         None, None, None, None, None, None, None, None
-    #     )
-    #     return out
+    def paged_token_attention(out, q, k_cache, v_cache, b_seq_len, q_head_num,
+                              kv_head_num, head_dim, block_table, block_size):
+        ext.paged_attention(out, q, k_cache, v_cache, None, b_seq_len, q_head_num,
+                            kv_head_num, head_dim, block_table, block_size)
 
-    token_attention_pack.paged_token_attention = ext.paged_attention
-
+    token_attention_pack.paged_token_attention = paged_token_attention
 
 def patch_token_attention_inference():
     token_attention_pack.token_att_fwd = ext.token_attention_inference
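
Note: the rebinding above swaps the raw extension op for a thin Python adapter whose only job is to fill the new optional atten_mask slot with None before forwarding to ext.paged_attention. The following is a minimal, self-contained sketch of that pattern; it uses stand-in namespaces in place of the compiled ext module and lightllm's token_attention_pack, so it runs without a DIOPI device.

    from types import SimpleNamespace

    # Stand-ins for the compiled extension and lightllm's module; the real op
    # only runs on a DIOPI-backed device.
    def fake_paged_attention(out, q, k_cache, v_cache, atten_mask, actual_seq_lengths,
                             num_heads, num_kv_heads, dim, block_table, block_size):
        print("paged_attention received atten_mask =", atten_mask)

    ext = SimpleNamespace(paged_attention=fake_paged_attention)
    token_attention_pack = SimpleNamespace()

    def paged_token_attention(out, q, k_cache, v_cache, b_seq_len, q_head_num,
                              kv_head_num, head_dim, block_table, block_size):
        # Same shape as the patch: insert None for the optional attention mask.
        ext.paged_attention(out, q, k_cache, v_cache, None, b_seq_len, q_head_num,
                            kv_head_num, head_dim, block_table, block_size)

    token_attention_pack.paged_token_attention = paged_token_attention
    token_attention_pack.paged_token_attention("out", "q", "kc", "vc", [7], 8, 8, 128, "tbl", 16)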
