Commit f238031

Merge branch 'main' into kv-cache-dce
2 parents 5b27f3c + 907c180

49 files changed (+2028, −559 lines)


.github/CODEOWNERS

Lines changed: 126 additions & 0 deletions
@@ -14,6 +14,103 @@
/tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs
/tensorrt_llm/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs

+## TensorRT-LLM Pytorch - Speculative Decoding
+/tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding
+
+## TensorRT-LLM Pytorch - Graph Compiler
+/tensorrt_llm/_torch/compilation @NVIDIA/trt-llm-torch-graph-compiler
+/tensorrt_llm/_torch/custom_ops @NVIDIA/trt-llm-torch-graph-compiler
+/tensorrt_llm/_torch/autotuner.py @NVIDIA/trt-llm-torch-graph-compiler
+/tests/unittest/_torch/compilation @NVIDIA/trt-llm-torch-graph-compiler
+/tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py @NVIDIA/trt-llm-torch-graph-compiler
+/tests/unittest/_torch/multi_gpu/test_user_buffers.py @NVIDIA/trt-llm-torch-graph-compiler
+/tests/unittest/_torch/test_custom_ops.py @NVIDIA/trt-llm-torch-graph-compiler
+/tests/unittest/_torch/test_autotuner.py @NVIDIA/trt-llm-torch-graph-compiler
+
+## TensorRT-LLM Pytorch - Attention
+/tensorrt_llm/_torch/attention_backend @NVIDIA/trt-llm-torch-attention-devs
+/tensorrt_llm/_torch/modules/attention.py @NVIDIA/trt-llm-torch-attention-devs
+
+## TensorRT-LLM Pytorch - Modules
+/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules
+
+
+## TensorRT-LLM Pytorch Models
+/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Gemma
+/tensorrt_llm/_torch/models/modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_gemma3vl.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Mistral & Mixtral
+/tensorrt_llm/_torch/models/modeling_mistral.py @NVIDIA/trt-llm-torch-models-mistral-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_pixtral.py @NVIDIA/trt-llm-torch-models-mistral-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_mistral.py @NVIDIA/trt-llm-torch-models-mistral-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_mixtral.py @NVIDIA/trt-llm-torch-models-mistral-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - CLIP
+/tensorrt_llm/_torch/models/modeling_clip.py @NVIDIA/trt-llm-torch-models-clip-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_clip.py @NVIDIA/trt-llm-torch-models-clip-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Phi
+/tensorrt_llm/_torch/models/modeling_phi3.py @NVIDIA/trt-llm-torch-models-phi-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_phi4mm.py @NVIDIA/trt-llm-torch-models-phi-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_phi3.py @NVIDIA/trt-llm-torch-models-phi-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/integration/defs/examples/test_multimodal.py @NVIDIA/trt-llm-torch-models-phi-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Deepseek
+/tensorrt_llm/_torch/models/modeling_deepseekv3.py @NVIDIA/trt-llm-torch-models-deepseek-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_deepseek.py @NVIDIA/trt-llm-torch-models-deepseek-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Llama
+/tensorrt_llm/_torch/models/modeling_mllama.py @NVIDIA/trt-llm-torch-models-llama-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_llama.py @NVIDIA/trt-llm-torch-models-llama-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_llama_min_latency.py @NVIDIA/trt-llm-torch-models-llama-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_llama.py @NVIDIA/trt-llm-torch-models-llama-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @NVIDIA/trt-llm-torch-models-llama-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Qwen
+/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @NVIDIA/trt-llm-torch-models-qwen-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_qwen3.py @NVIDIA/trt-llm-torch-models-qwen-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_qwen2vl.py @NVIDIA/trt-llm-torch-models-qwen-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_qwen.py @NVIDIA/trt-llm-torch-models-qwen-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_qwen_moe.py @NVIDIA/trt-llm-torch-models-qwen-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - VLMs
+/tensorrt_llm/_torch/models/modeling_vila.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_vila.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_pixtral.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_pixtral.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
+
+### TensorRT-LLM Pytorch - Models - Nemotron
+/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/pyexecutor/resource_manager.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-runtime-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/modules/mamba @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_nemotron.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tests/unittest/_torch/modeling/test_modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+
+## TensorRT-LLM - PEFT
+/tensorrt_llm/_torch/peft @NVIDIA/trt-llm-torch-peft
+/tensorrt_llm/lora_manager.py @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp @NVIDIA/trt-llm-torch-peft
+/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraCache.cpp @NVIDIA/trt-llm-torch-peft
+/cpp/include/tensorrt_llm/runtime/loraCache.h @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraModule.cpp @NVIDIA/trt-llm-torch-peft
+/cpp/include/tensorrt_llm/runtime/loraModule.h @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraManager.cpp @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraManager.h @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraUtils.cpp @NVIDIA/trt-llm-torch-peft
+/cpp/tensorrt_llm/runtime/loraUtils.h @NVIDIA/trt-llm-torch-peft
+
+## TensorRT-LLM - Triton backend
+/triton_backend @NVIDIA/trt-llm-triton-backend-devs
+
## TensorRT-LLM trtllm-bench Reviewers
/tensorrt_llm/bench @NVIDIA/trtllm-bench-reviewers
/tensorrt_llm/commands/bench.py @NVIDIA/trtllm-bench-reviewers
@@ -23,6 +120,35 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
/tensorrt_llm/llmapi @NVIDIA/trt-llm-llmapi-devs
/tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs

+## TensorRT-LLM LLM Disaggregated
+/examples/disaggregated @NVIDIA/trt-llm-disagg-devs
+/tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs
+/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs
+/tensorrt_llm/_torch/pyexecutor/py_executor.py @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @NVIDIA/trt-llm-disagg-devs
+/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @NVIDIA/trt-llm-disagg-devs
+
+## TensorRT-LLM Infra
+
+### CI
+/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs
+### Setup
+/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs
+### Github workflows
+/tensorrt_llm/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
+/tensorrt_llm/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs
+
+## TensorRT-LLM - Docs
+/docs @NVIDIA/trt-llm-doc-owners
+/examples @NVIDIA/trt-llm-doc-owners
+
# The rule below requires that any PR modifying public APIs must be approved by at least one member
# of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team.
# This approval is mandatory regardless of other approvals the PR may have received. Without approval

README.md

Lines changed: 4 additions & 1 deletion
@@ -18,12 +18,15 @@ TensorRT-LLM
<div align="left">

## Tech Blogs
+* [08/06] Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM
+[➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)
+

* [08/01] Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization)
[➡️ link](./docs/source/blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.md)

* [07/26] N-Gram Speculative Decoding in TensorRT‑LLM
-[➡️ link](./docs/source/blogs/tech_blog/blog_7_NGram_performance_Analysis_And_Auto_Enablement.md)
+[➡️ link](./docs/source/blogs/tech_blog/blog7_NGram_performance_Analysis_And_Auto_Enablement.md)

* [06/19] Disaggregated Serving in TensorRT-LLM
[➡️ link](./docs/source/blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.md)

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 5 additions & 0 deletions
@@ -2334,6 +2334,11 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
    void createSerializedResult(
        std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);

+    /// @brief Check if the (user-provided) tokens fall within the vocabulary range.
+    /// @details Currently only supports invocation before context phase is completed.
+    /// @return True if tokens are within range.
+    bool checkTokenIdRange(SizeType32 vocabSize);
+
    void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
        std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);
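The new `checkTokenIdRange` method appears here only as a declaration. As a rough illustration of the kind of check the doc comment describes, the sketch below validates a list of prompt tokens against the vocabulary size; it is not the commit's implementation, and the free-function form and type aliases are assumptions made for the example.

```cpp
#include <cstdint>
#include <vector>

using SizeType32 = std::int32_t;
using TokenIdType = std::int32_t;

// Illustrative sketch only (not the commit's implementation): return true if every
// user-provided token id lies in [0, vocabSize).
bool checkTokenIdRange(std::vector<TokenIdType> const& tokens, SizeType32 vocabSize)
{
    for (auto const tokenId : tokens)
    {
        if (tokenId < 0 || tokenId >= vocabSize)
        {
            return false; // out-of-vocabulary token found
        }
    }
    return true;
}
```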

cpp/kernels/xqa/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -23,6 +23,10 @@ set(CMAKE_CUDA_ARCHITECTURES 89-real 90a-real)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_XQA_TESTS "Build XQA tests" OFF)
+set(PAGED_KV_CACHE_LAYOUT
+    "0"
+    CACHE STRING "Paged KV cache format (0 for XQA Original, 1 for VLLM)")
+add_definitions(-DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT})

# todo: remove include_directories link_directories and link libs like
# CUDA::cuda_driver CUDA::cudart CUDA::nvrtc
@@ -37,7 +41,7 @@ set(CMAKE_CXX_FLAGS
    "${CMAKE_CXX_FLAGS} -march=haswell -Wfatal-errors -Wreturn-type -Wall -Wextra -Wno-unknown-pragmas"
)
set(CMAKE_CUDA_FLAGS
-    "${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage"
+    "${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage -DPAGED_KV_CACHE_LAYOUT=${PAGED_KV_CACHE_LAYOUT}"
)
set(CUDA_PTXAS_FLAGS "-warn-lmem-usage -warn-double-usage -warn-spills"
)# -Werror -v
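The layout choice is made once at configure time and reaches both host C++ and CUDA sources as an ordinary preprocessor symbol: `add_definitions` covers the C++ compiler, and the extra `-D` appended to `CMAKE_CUDA_FLAGS` covers `nvcc`. A minimal sketch of how a source file can branch on it (the layout descriptions in the comments follow the `defines.h` hunk below):

```cpp
// Sketch: PAGED_KV_CACHE_LAYOUT is injected at configure time by the CMake hunks above.
#if PAGED_KV_CACHE_LAYOUT == 1
// VLLM-style paged KV cache: separate K and V pools with sequence-first pages.
void handleVllmLayout() { /* ... */ }
#else
// XQA original paged KV cache: a single global pool of pages.
void handleXqaLayout() { /* ... */ }
#endif
```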

cpp/kernels/xqa/README.md

Lines changed: 11 additions & 1 deletion
@@ -16,7 +16,7 @@ You need to install libgtest-dev and libeigen3-dev before building. To build, us

- ```mkdir build```
- ```cd build```
-- ```cmake .. -DCMAKE_BUILD_TYPE=Release```
+- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON```
- ```cmake --build . -j```

To run unit tests, run `./unitTests`. There are a few runtime options that can be controlled with environment variables:
@@ -25,6 +25,16 @@ To run unit tests, run `./unitTests`. There are a few runtime options that can b
- XQA_USE_QGMMA: On Hopper, we try to use the TMA+QGMMA kernel (mha_sm90.cu) by default if possible. To force using mha.cu, set this to 0.
- XQA_NB_SUB_SEQ: The number of CUDA thread blocks used to handle one K/V head. We have a reasonable default, but if you want to change it manually, use this variable.

+## Support for VLLM Paged KV-Cache
+When `PAGED_KV_CACHE_LAYOUT=1` is enabled, XQA supports VLLM-style paged KV-cache input, with separate K and V pools and a sequence-first memory layout.
+To build and test with this feature enabled, run the following commands:
+
+- ```mkdir build```
+- ```cd build```
+- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON -DPAGED_KV_CACHE_LAYOUT=1```
+- ```cmake --build . -j```
+- ```./unitTests```
+
## Generation cubins used in TensorRT-LLM

Run `gen_cubin.py` in the repo workspace.

cpp/kernels/xqa/defines.h

Lines changed: 9 additions & 0 deletions
@@ -97,6 +97,15 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
#endif

+// Paged KV Cache Format
+// 0 - XQA Original
+// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for VLLM/SGLang
+#ifdef USE_PAGED_KV_CACHE
+#ifndef PAGED_KV_CACHE_LAYOUT
+#define PAGED_KV_CACHE_LAYOUT 0
+#endif
+#endif
+
// don't modify
#define USE_BEAM_SEARCH (BEAM_WIDTH > 1)
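The two layouts differ in where a token's K/V head lives inside a page. The helper below is illustrative only (it is not part of the commit); it mirrors the `idxHeadBeg` arithmetic in the `mha.cu` hunks that follow, with the names `seqOffset`, `tokensPerPage`, `nbKHeads`, and `idxHeadGrp` taken from that diff.

```cpp
#include <cstdint>

// Illustrative helper (not from the commit): index of the first cache head for a token
// within its page, for the two PAGED_KV_CACHE_LAYOUT settings.
inline uint32_t idxHeadBegInPage(uint32_t seqOffset, uint32_t tokensPerPage,
                                 uint32_t nbKHeads, uint32_t idxHeadGrp)
{
#if PAGED_KV_CACHE_LAYOUT == 1
    // Layout 1 (VLLM/SGLang): the heads of one token sit next to each other; tokens are sequence-first.
    return (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
#else
    // Layout 0 (XQA original): all tokens of one head group sit next to each other within the page.
    return tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
#endif
}
```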

cpp/kernels/xqa/mha.cu

Lines changed: 50 additions & 10 deletions
@@ -1671,17 +1671,33 @@ CUBIN_EXPORT __global__
    uint32_t const dstHeadOffset = 0;
    uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x;
#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+    uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
+
+#else
    uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
+#endif
#if BEAM_WIDTH == 1
+#if PAGED_KV_CACHE_LAYOUT == 1
+    HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
+        cacheList.kCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
+#else
    HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
        cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
+#endif
#else
-    IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src{
+    IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerWarpTile> const src
+    {
        /*indices=*/smem.gemm0CacheIndir[warpIdx.x].data,
-        /*pool=*/cacheList.pool,
-        /*pageIndices=*/smem.kCachePages[warpIdx.x].data,
-        /*nbKHeads=*/nbKHeads,
-        /*offset=*/idxHeadBeg};
+#if PAGED_KV_CACHE_LAYOUT == 1
+        /*pool=*/cacheList.kCacheVLLM,
+#else
+        /*pool=*/cacheList.pool,
+#endif
+        /*pageIndices=*/smem.kCachePages[warpIdx.x].data,
+        /*nbKHeads=*/nbKHeads,
+        /*offset=*/idxHeadBeg
+    };
#endif
#else
    uint32_t const idxHeadBeg = cacheKSeqBaseOffset + seqOffset;
@@ -1990,17 +2006,33 @@ CUBIN_EXPORT __global__
    uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter
        + cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx;
#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+    uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp;
+
+#else
    uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage;
+#endif
#if BEAM_WIDTH == 1
+#if PAGED_KV_CACHE_LAYOUT == 1
+    HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
+        cacheList.vCacheVLLM, pageIdx, nbKHeads, idxHeadBeg};
+#else
    HeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
        cacheList.pool, pageIdx, nbKHeads, idxHeadBeg};
+#endif
#else
-    IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src{
+    IndexedHeadPtr<GMemCacheHead const, tokensPerPage, nbPagesPerVTile> const src
+    {
        /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data,
-        /*pool=*/cacheList.pool,
-        /*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
-        /*nbKHeads=*/nbKHeads,
-        /*offset=*/idxHeadBeg};
+#if PAGED_KV_CACHE_LAYOUT == 1
+        /*pool=*/cacheList.vCacheVLLM,
+#else
+        /*pool=*/cacheList.pool,
+#endif
+        /*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data,
+        /*nbKHeads=*/nbKHeads,
+        /*offset=*/idxHeadBeg
+    };
#endif
#else
    uint32_t const idxHeadBeg = cacheVSeqBaseOffset + seqOffset;
@@ -2636,7 +2668,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
    InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
    GMemCacheHead* pool, // global pool of pages
+#endif
    KVCachePageIndex const*
        kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
#else
@@ -2702,7 +2738,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
    auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
#if USE_PAGED_KV_CACHE
    uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
+#if PAGED_KV_CACHE_LAYOUT == 1
+    KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
+#else
    KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
+#endif
    cudaLaunchKernelEx(&launchCfg, kernel_mha,
#if SPEC_DEC
        qSeqLen, nbKHeads, headGrpSize, qCuSeqLens,

cpp/kernels/xqa/mha.h

Lines changed: 8 additions & 0 deletions
@@ -102,7 +102,11 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
    InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
    GMemCacheHead* pool, // global pool of pages
+#endif
    KVCachePageIndex const*
        kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
#else
@@ -137,7 +141,11 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
    InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
+#if PAGED_KV_CACHE_LAYOUT == 1
+    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
+#else
    GMemCacheHead* pool, // global pool of pages
+#endif
    KVCachePageIndex const*
        kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
#else
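Because the pool arguments are selected by the preprocessor, a caller of `launchMHA` or `launchHopperF8MHA` supplies either the two VLLM-style pools or the single page pool, never both. The struct below is a caller-side sketch, not code from the commit; it only groups the KV-cache arguments under the same `#if` as the signatures above, using the `GMemCacheHead` and `KVCachePageIndex` types those headers declare.

```cpp
// Illustrative caller-side grouping (not from the commit) of the paged-KV-cache
// arguments that launchMHA expects, mirroring the conditional signature above.
struct PagedKvCacheArgs
{
#if PAGED_KV_CACHE_LAYOUT == 1
    GMemCacheHead* kCacheVLLM;               // separate K pool (VLLM layout)
    GMemCacheHead* vCacheVLLM;               // separate V pool (VLLM layout)
#else
    GMemCacheHead* pool;                     // single global pool of pages (XQA original)
#endif
    KVCachePageIndex const* kvCachePageList; // page table, shape as documented above
};
```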
