@@ -18,7 +18,9 @@ limitations under the License.
1818
1919namespace xllm ::kernel::cuda {
2020
21- void batch_decode (torch::Tensor float_workspace_buffer,
21+ void batch_decode (const std::string& uri,
22+ torch::Tensor plan_info,
23+ torch::Tensor float_workspace_buffer,
2224 torch::Tensor int_workspace_buffer,
2325 torch::Tensor page_locked_int_workspace_buffer,
2426 torch::Tensor query,
@@ -32,41 +34,6 @@ void batch_decode(torch::Tensor float_workspace_buffer,
3234 torch::Tensor output,
3335 std::optional<torch::Tensor>& output_lse,
3436 bool enable_cuda_graph) {
35- std::string uri = get_batch_decode_uri (query.scalar_type (),
36- k_cache.scalar_type (),
37- output.scalar_type (),
38- paged_kv_indptr.scalar_type (),
39- query.size (-1 ),
40- v_cache.size (-1 ),
41- /* pos_encoding_mode=*/ 0 ,
42- /* use_sliding_window=*/ false ,
43- /* use_logits_soft_cap=*/ false );
44-
45- torch::Tensor paged_kv_indptr_host = paged_kv_indptr.to (torch::kCPU );
46- const int64_t batch_size = paged_kv_last_page_len.size (0 );
47-
48- torch::Tensor empty_q_data =
49- torch::empty ({0 }, torch::TensorOptions ().dtype (query.scalar_type ()));
50- torch::Tensor empty_kv_data =
51- torch::empty ({0 }, torch::TensorOptions ().dtype (k_cache.scalar_type ()));
52-
53- auto plan_info = FunctionFactory::get_instance ().decode_plan_func (uri).call (
54- float_workspace_buffer,
55- int_workspace_buffer,
56- page_locked_int_workspace_buffer,
57- paged_kv_indptr_host,
58- batch_size,
59- query.size (1 ), // num_qo_heads
60- k_cache.size (2 ), // num_kv_heads
61- k_cache.size (1 ), // block_size
62- enable_cuda_graph,
63- window_left,
64- /* logits_soft_cap=*/ 0.0 ,
65- query.size (-1 ), // head_dim_qk
66- v_cache.size (-1 ), // head_dim_vo
67- empty_q_data,
68- empty_kv_data);
69-
7037 FunctionFactory::get_instance ().decode_run_func (uri).call (
7138 float_workspace_buffer,
7239 int_workspace_buffer,
0 commit comments