Commit 90e1450

feat: initialize flashinfer planinfo when forward first step.
Signed-off-by: pengtao.156 <[email protected]>
1 parent 8a00ad7 commit 90e1450

22 files changed: +183 -99 lines changed

xllm/core/kernels/cuda/batch_decode.cpp

Lines changed: 3 additions & 36 deletions
@@ -18,7 +18,9 @@ limitations under the License.
 
 namespace xllm::kernel::cuda {
 
-void batch_decode(torch::Tensor float_workspace_buffer,
+void batch_decode(const std::string& uri,
+                  torch::Tensor plan_info,
+                  torch::Tensor float_workspace_buffer,
                   torch::Tensor int_workspace_buffer,
                   torch::Tensor page_locked_int_workspace_buffer,
                   torch::Tensor query,
@@ -32,41 +34,6 @@ void batch_decode(torch::Tensor float_workspace_buffer,
                   torch::Tensor output,
                   std::optional<torch::Tensor>& output_lse,
                   bool enable_cuda_graph) {
-  std::string uri = get_batch_decode_uri(query.scalar_type(),
-                                         k_cache.scalar_type(),
-                                         output.scalar_type(),
-                                         paged_kv_indptr.scalar_type(),
-                                         query.size(-1),
-                                         v_cache.size(-1),
-                                         /*pos_encoding_mode=*/0,
-                                         /*use_sliding_window=*/false,
-                                         /*use_logits_soft_cap=*/false);
-
-  torch::Tensor paged_kv_indptr_host = paged_kv_indptr.to(torch::kCPU);
-  const int64_t batch_size = paged_kv_last_page_len.size(0);
-
-  torch::Tensor empty_q_data =
-      torch::empty({0}, torch::TensorOptions().dtype(query.scalar_type()));
-  torch::Tensor empty_kv_data =
-      torch::empty({0}, torch::TensorOptions().dtype(k_cache.scalar_type()));
-
-  auto plan_info = FunctionFactory::get_instance().decode_plan_func(uri).call(
-      float_workspace_buffer,
-      int_workspace_buffer,
-      page_locked_int_workspace_buffer,
-      paged_kv_indptr_host,
-      batch_size,
-      query.size(1),     // num_qo_heads
-      k_cache.size(2),   // num_kv_heads
-      k_cache.size(1),   // block_size
-      enable_cuda_graph,
-      window_left,
-      /*logits_soft_cap=*/0.0,
-      query.size(-1),    // head_dim_qk
-      v_cache.size(-1),  // head_dim_vo
-      empty_q_data,
-      empty_kv_data);
-
   FunctionFactory::get_instance().decode_run_func(uri).call(
       float_workspace_buffer,
       int_workspace_buffer,
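
The plan work deleted above does not disappear; per the commit title it now runs once, on the first forward step, and its results are handed back into batch_decode through the new uri/plan_info parameters on every call. Below is a minimal caller-side sketch of that one-time step. The helper name plan_batch_decode_once and the int64_t type of window_left are illustrative assumptions (the actual initialization site is in files not shown in this excerpt); the body simply reuses the get_batch_decode_uri and FunctionFactory::get_instance().decode_plan_func(uri) calls removed from batch_decode above, and treats the plan result as a torch::Tensor to match the plan_info field added to AttentionParams further down.

// Hypothetical one-time planning helper: mirrors the code removed from
// cuda::batch_decode(). Run it on the first forward step, cache the result,
// and pass uri/plan_info into every subsequent batch_decode() call.
std::pair<std::string, torch::Tensor> plan_batch_decode_once(
    torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer,
    torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor query,
    torch::Tensor k_cache,
    torch::Tensor v_cache,
    torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_last_page_len,
    torch::Tensor output,
    int64_t window_left,  // type assumed for illustration
    bool enable_cuda_graph) {
  // The kernel uri depends only on dtypes and head dims, so it is stable
  // across decode steps and only needs to be derived once.
  std::string uri = get_batch_decode_uri(query.scalar_type(),
                                         k_cache.scalar_type(),
                                         output.scalar_type(),
                                         paged_kv_indptr.scalar_type(),
                                         query.size(-1),
                                         v_cache.size(-1),
                                         /*pos_encoding_mode=*/0,
                                         /*use_sliding_window=*/false,
                                         /*use_logits_soft_cap=*/false);

  torch::Tensor paged_kv_indptr_host = paged_kv_indptr.to(torch::kCPU);
  const int64_t batch_size = paged_kv_last_page_len.size(0);

  torch::Tensor empty_q_data =
      torch::empty({0}, torch::TensorOptions().dtype(query.scalar_type()));
  torch::Tensor empty_kv_data =
      torch::empty({0}, torch::TensorOptions().dtype(k_cache.scalar_type()));

  // Same flashinfer plan call that previously ran inside batch_decode() on
  // every invocation; here it runs once per batch layout.
  torch::Tensor plan_info =
      FunctionFactory::get_instance().decode_plan_func(uri).call(
          float_workspace_buffer,
          int_workspace_buffer,
          page_locked_int_workspace_buffer,
          paged_kv_indptr_host,
          batch_size,
          query.size(1),     // num_qo_heads
          k_cache.size(2),   // num_kv_heads
          k_cache.size(1),   // block_size
          enable_cuda_graph,
          window_left,
          /*logits_soft_cap=*/0.0,
          query.size(-1),    // head_dim_qk
          v_cache.size(-1),  // head_dim_vo
          empty_q_data,
          empty_kv_data);
  return {uri, plan_info};
}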

xllm/core/kernels/cuda/batch_prefill.cpp

Lines changed: 3 additions & 37 deletions
@@ -18,7 +18,9 @@ limitations under the License.
 
 namespace xllm::kernel::cuda {
 
-void batch_prefill(torch::Tensor float_workspace_buffer,
+void batch_prefill(const std::string& uri,
+                   torch::Tensor plan_info,
+                   torch::Tensor float_workspace_buffer,
                    torch::Tensor int_workspace_buffer,
                    torch::Tensor page_locked_int_workspace_buffer,
                    torch::Tensor query,
@@ -31,42 +33,6 @@ void batch_prefill(torch::Tensor float_workspace_buffer,
                    torch::Tensor output,
                    std::optional<torch::Tensor>& output_lse,
                    bool enable_cuda_graph) {
-  std::string uri = get_batch_prefill_uri(/*backend=*/"fa2",
-                                          query.scalar_type(),
-                                          key.scalar_type(),
-                                          output.scalar_type(),
-                                          q_cu_seq_lens.scalar_type(),
-                                          query.size(-1),
-                                          value.size(-1),
-                                          /*pos_encoding_mode=*/0,
-                                          /*use_sliding_window=*/false,
-                                          /*use_logits_soft_cap=*/false,
-                                          /*use_fp16_qk_reduction=*/false);
-
-  torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU);
-  torch::Tensor kv_cu_seq_lens_host = kv_cu_seq_lens.to(torch::kCPU);
-  torch::Tensor kv_len_arr_host =
-      kv_cu_seq_lens_host.slice(0, 1) - kv_cu_seq_lens_host.slice(0, 0, -1);
-  const int64_t total_num_rows = qo_indptr_host[-1].item<int64_t>();
-  const int64_t batch_size = qo_indptr_host.size(0) - 1;
-
-  auto plan_info = FunctionFactory::get_instance().prefill_plan_func(uri).call(
-      float_workspace_buffer,
-      int_workspace_buffer,
-      page_locked_int_workspace_buffer,
-      qo_indptr_host,
-      kv_cu_seq_lens_host,
-      kv_len_arr_host,
-      total_num_rows,
-      batch_size,
-      query.size(1),   // num_qo_heads
-      key.size(1),     // num_kv_heads
-      /*page_size=*/1,
-      enable_cuda_graph,
-      query.size(-1),  // head_dim_qk
-      value.size(-1),  // head_dim_vo
-      /*causal=*/true);
-
   FunctionFactory::get_instance().prefill_ragged_run_func(uri).call(
       float_workspace_buffer,
       int_workspace_buffer,

xllm/core/kernels/cuda/cuda_ops_api.h

Lines changed: 6 additions & 2 deletions
@@ -45,7 +45,9 @@ void reshape_paged_cache(
     torch::Tensor key_cache,  // [n_blocks, block_size, n_heads, head_dim]
     torch::Tensor value_cache);
 
-void batch_prefill(torch::Tensor float_workspace_buffer,
+void batch_prefill(const std::string& uri,
+                   torch::Tensor plan_info,
+                   torch::Tensor float_workspace_buffer,
                    torch::Tensor int_workspace_buffer,
                    torch::Tensor page_locked_int_workspace_buffer,
                    torch::Tensor query,
@@ -59,7 +61,9 @@ void batch_prefill(torch::Tensor float_workspace_buffer,
                    std::optional<torch::Tensor>& output_lse,
                    bool enable_cuda_graph);
 
-void batch_decode(torch::Tensor float_workspace_buffer,
+void batch_decode(const std::string& uri,
+                  torch::Tensor plan_info,
+                  torch::Tensor float_workspace_buffer,
                   torch::Tensor int_workspace_buffer,
                   torch::Tensor page_locked_int_workspace_buffer,
                   torch::Tensor query,

xllm/core/kernels/cuda/utils.h

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@ limitations under the License.
 
 #pragma once
 
+#include <ATen/DynamicLibrary.h>
 #include <torch/torch.h>
 
 #include <string>
@@ -56,4 +57,4 @@ std::string get_batch_decode_uri(torch::ScalarType dtype_q,
                                  bool use_sliding_window,
                                  bool use_logits_soft_cap);
 
-} // namespace xllm::kernel::cuda
+} // namespace xllm::kernel::cuda

xllm/core/kernels/ops_api.cpp

Lines changed: 6 additions & 2 deletions
@@ -153,7 +153,9 @@ void batch_prefill(AttentionParams& params) {
                       params.scale,
                       params.output);
 #elif defined(USE_CUDA)
-  cuda::batch_prefill(params.float_workspace_buffer,
+  cuda::batch_prefill(params.uri,
+                      params.plan_info,
+                      params.float_workspace_buffer,
                       params.int_workspace_buffer,
                       params.page_locked_int_workspace_buffer,
                       params.query,
@@ -225,7 +227,9 @@ void batch_decode(AttentionParams& params) {
 #elif defined(USE_CUDA)
   params.query = params.query.squeeze(1);
   params.output = params.output.squeeze(1);
-  cuda::batch_decode(params.float_workspace_buffer,
+  cuda::batch_decode(params.uri,
+                     params.plan_info,
+                     params.float_workspace_buffer,
                      params.int_workspace_buffer,
                      params.page_locked_int_workspace_buffer,
                      params.query,

xllm/core/kernels/param.h

Lines changed: 2 additions & 0 deletions
@@ -209,6 +209,8 @@ struct AttentionParams {
   torch::Tensor kv_cu_seq_lens;
   torch::Tensor q_cu_seq_lens;
   bool enable_cuda_graph = false;
+  std::string uri;
+  torch::Tensor plan_info;
 
   // ========== Prefill-specific parameters ==========
   // Key tensor. Shape: [num_tokens, num_kv_heads, head_dim_qk] (packed) or
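
With the two new AttentionParams fields in place, a per-step caller only has to copy the cached plan into the params before calling the dispatcher shown in ops_api.cpp above, which forwards them to cuda::batch_decode. A minimal sketch follows; the run_decode_step wrapper and the cached_* arguments are illustrative assumptions, and the population of the remaining params fields (tensors, workspace buffers, sequence lengths) is taken as already done.

// Illustrative only: subsequent decode steps skip planning entirely and just
// thread the cached uri/plan_info through the new AttentionParams fields.
void run_decode_step(AttentionParams& params,
                     const std::string& cached_uri,
                     const torch::Tensor& cached_plan_info) {
  params.uri = cached_uri;              // field added by this commit
  params.plan_info = cached_plan_info;  // field added by this commit
  batch_decode(params);  // ops_api.cpp forwards uri/plan_info to cuda::batch_decode
}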

xllm/core/layers/common/qwen2_attention.cpp

Lines changed: 4 additions & 5 deletions
@@ -110,11 +110,10 @@ Qwen2AttentionImpl::Qwen2AttentionImpl(const ModelContext& context) {
                                         args.sliding_window()));
 }
 
-torch::Tensor Qwen2AttentionImpl::forward(
-    const torch::Tensor& positions,
-    const torch::Tensor& hidden_states,
-    const AttentionMetadata& attn_metadata,
-    KVCache& kv_cache) {
+torch::Tensor Qwen2AttentionImpl::forward(const torch::Tensor& positions,
+                                          const torch::Tensor& hidden_states,
+                                          AttentionMetadata& attn_metadata,
+                                          KVCache& kv_cache) {
   // 1. qkv projection
   auto qkv = qkv_proj_->forward(hidden_states);

xllm/core/layers/common/qwen2_attention.h

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class Qwen2AttentionImpl : public torch::nn::Module {
 
   torch::Tensor forward(const torch::Tensor& positions,
                         const torch::Tensor& hidden_states,
-                        const AttentionMetadata& attn_metadata,
+                        AttentionMetadata& attn_metadata,
                         KVCache& kv_cache);
 
   void load_state_dict(const StateDict& state_dict);

xllm/core/layers/common/qwen2_decoder_layer.cpp

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ torch::Tensor Qwen2DecoderLayerImpl::forward(
     torch::Tensor& x,
     std::optional<torch::Tensor>& residual,
     torch::Tensor& positions,
-    const AttentionMetadata& attn_metadata,
+    AttentionMetadata& attn_metadata,
     KVCache& kv_cache,
     const ModelInputParams& input_params) {
   // Pre-attention norm

xllm/core/layers/common/qwen2_decoder_layer.h

Lines changed: 1 addition & 3 deletions
@@ -51,7 +51,7 @@ class Qwen2DecoderLayerImpl : public torch::nn::Module {
   torch::Tensor forward(torch::Tensor& x,
                         std::optional<torch::Tensor>& residual,
                         torch::Tensor& positions,
-                        const AttentionMetadata& attn_metadata,
+                        AttentionMetadata& attn_metadata,
                         KVCache& kv_cache,
                         const ModelInputParams& input_params);
 
@@ -64,7 +64,5 @@ class Qwen2DecoderLayerImpl : public torch::nn::Module {
   ParallelArgs parallel_args_;
 };
 
-using Qwen3DecoderLayerImpl = Qwen2DecoderLayerImpl;
-
 } // namespace layer
 } // namespace xllm
