Skip to content

Commit 0cf28c0

Browse files
authored
Merge branch 'main' into syr/attn_tp_config
2 parents 6488677 + 2e43753 commit 0cf28c0

36 files changed

+1973
-94
lines changed

.coderabbit.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ reviews:
2929
suggested_labels: true
3030
suggested_reviewers: true
3131
poem: false
32+
review_status: false
3233
auto_review:
33-
drafts: true
34+
auto_incremental_review: false
35+
drafts: false
3436
base_branches: ["main", "release/.+"]
3537
knowledge_base:
3638
code_guidelines:
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "tensorrt_llm/batch_manager/common.h"
20+
#include "tensorrt_llm/batch_manager/llmRequest.h"
21+
#include "tensorrt_llm/runtime/common.h"
22+
23+
#include <utility>
24+
#include <vector>
25+
26+
using SizeType32 = tensorrt_llm::runtime::SizeType32;
27+
using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;
28+
29+
/// See tensorrt_llm/_torch/pyexecutor/connector.py for details on the Connector API.
30+
31+
namespace tensorrt_llm::batch_manager::kv_connector
32+
{
33+
34+
/// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences.
35+
class KvCacheConnectorManager
36+
{
37+
public:
38+
KvCacheConnectorManager() = default;
39+
virtual ~KvCacheConnectorManager() = default;
40+
41+
/// @brief Handle the getNumNewMatchedTokens call inside the C++ KV Cache Manager.
42+
/// @return The number of tokens that can be loaded from remote KV cache.
43+
virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0;
44+
};
45+
46+
} // namespace tensorrt_llm::batch_manager::kv_connector

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818

19+
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
1920
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
2021
#include "tensorrt_llm/batch_manager/kvCacheType.h"
2122
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
@@ -538,7 +539,8 @@ class WindowBlockManager
538539
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
539540
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
540541
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
541-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
542+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
543+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
542544

543545
~WindowBlockManager();
544546

@@ -835,6 +837,8 @@ class WindowBlockManager
835837
bool mEnablePartialReuse;
836838
// Whether partially matched blocks that are already in use should be copied and reused.
837839
bool mCopyOnPartialReuse;
840+
// The kv cache connector manager
841+
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
838842
};
839843

840844
class BlockManager
@@ -852,7 +856,8 @@ class BlockManager
852856
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
853857
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
854858
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
855-
bool copyOnPartialReuse = true);
859+
bool copyOnPartialReuse = true,
860+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
856861

857862
BlockManager(BlockManager const&) = delete;
858863
BlockManager& operator=(BlockManager const&) = delete;
@@ -1287,6 +1292,7 @@ class BaseKVCacheManager
12871292
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
12881293
= 0;
12891294

1295+
[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
12901296
[[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0;
12911297
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
12921298

@@ -1373,7 +1379,8 @@ class KVCacheManager : public BaseKVCacheManager
13731379
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13741380
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13751381
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1376-
bool copyOnpartialReuse = true);
1382+
bool copyOnpartialReuse = true,
1383+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13771384

13781385
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13791386
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1383,7 +1390,8 @@ class KVCacheManager : public BaseKVCacheManager
13831390
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13841391
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13851392
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1386-
bool copyOnpartialReuse = true);
1393+
bool copyOnpartialReuse = true,
1394+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13871395

13881396
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13891397
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1393,7 +1401,8 @@ class KVCacheManager : public BaseKVCacheManager
13931401
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13941402
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13951403
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1396-
bool copyOnpartialReuse = true);
1404+
bool copyOnpartialReuse = true,
1405+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13971406

13981407
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13991408
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1624,6 +1633,7 @@ class KVCacheManager : public BaseKVCacheManager
16241633
std::vector<SizeType32> getNewlyAllocatedBlockIds(
16251634
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;
16261635

1636+
runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
16271637
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;
16281638

16291639
SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
504504
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
505505
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
506506
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
507-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
507+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
508+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
508509
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
509510
, mTokensPerBlock{tokensPerBlock}
510511
, mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
513514
{
514515
auto const uniqueWindowSizeToLayers
515516
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
517+
518+
TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
519+
"KV Cache Connector is not supported with multiple window sizes");
520+
516521
auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
517522

518523
mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
530535
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
531536
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
532537
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
533-
copyOnPartialReuse);
538+
copyOnPartialReuse, kvCacheConnectorManager);
534539
}
535540

536541
auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
572577
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
573578
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
574579
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
575-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
580+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
581+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
576582
: mDataType{dtype}
577583
, mWindowSize{windowSize}
578584
, mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
596602
, mTotalInputTokens{0.0}
597603
, mEnablePartialReuse{enablePartialReuse}
598604
, mCopyOnPartialReuse{copyOnPartialReuse}
605+
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
599606
{
600607
std::map<SizeType32, SizeType32> numLayersPerPool;
601608

@@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
11881195
auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
11891196
mReusedTokens += static_cast<double>(prepopulatedPromptLen);
11901197
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
1191-
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
1192-
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
1193-
inputLength, prepopulatedPromptLen);
1198+
1199+
SizeType32 numConnectorMatchedTokens = 0;
1200+
1201+
// If we're using a KV cache connector, check if any additional blocks can be loaded.
1202+
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
1203+
{
1204+
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
1205+
}
1206+
1207+
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
1208+
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
1209+
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
11941210
}
11951211

11961212
// There are two versions of BlockManager::addSequence function.
@@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
12061222
void WindowBlockManager::addSequence(
12071223
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
12081224
{
1225+
if (mKvCacheConnectorManager)
1226+
{
1227+
TLLM_LOG_WARNING(
1228+
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
1229+
"ignored.");
1230+
}
1231+
12091232
auto const requestId = sequence.getRequestId();
12101233
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
12111234
TLLM_CHECK(emplaceDone);
@@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16181641
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
16191642
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16201643
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1621-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1644+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1645+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16221646
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
16231647
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16241648
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
16251649
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
1626-
copyOnPartialReuse)
1650+
copyOnPartialReuse, kvCacheConnectorManager)
16271651
{
16281652
}
16291653

@@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16341658
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
16351659
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16361660
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1637-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1661+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1662+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16381663
: mMaxBeamWidth(maxBeamWidth)
16391664
, mDataType(dtype)
16401665
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16441669
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
16451670
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
16461671
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
1647-
enablePartialReuse, copyOnPartialReuse)
1672+
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
16481673
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
16491674
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
16501675
{
@@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
16681693
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
16691694
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16701695
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1671-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1696+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1697+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16721698
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
16731699
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16741700
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
1675-
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
1701+
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
16761702
{
16771703
}
16781704

@@ -2383,6 +2409,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
23832409
return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
23842410
}
23852411

2412+
/// Return the single primary KV pool. Only valid when exactly one attention
/// window size (and therefore exactly one pool) is configured.
runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
{
    auto const numWindowSizes = mBlockManager.getWindowSizesMetadata().size();
    TLLM_CHECK_WITH_INFO(numWindowSizes == 1, "getUniquePrimaryPool is only supported for a single window size");
    // With a unique window size there is a single pool at index 0.
    return mBlockManager.getPrimaryPool(0);
}
2418+
23862419
runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
23872420
{
23882421
return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
@@ -2462,4 +2495,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
24622495
auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
24632496
return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
24642497
}
2498+
24652499
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/nanobind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(SRCS
77
batch_manager/algorithms.cpp
88
batch_manager/bindings.cpp
99
batch_manager/cacheTransceiver.cpp
10+
batch_manager/kvCacheConnector.cpp
1011
batch_manager/kvCacheManager.cpp
1112
batch_manager/llmRequest.cpp
1213
executor/bindings.cpp
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"
19+
20+
#include <nanobind/trampoline.h>
21+
#include <torch/extension.h>
22+
23+
namespace
24+
{
25+
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;
26+
27+
namespace tb = tensorrt_llm::batch_manager;
28+
29+
class PyKvCacheConnectorManager : KvCacheConnectorManager
30+
{
31+
public:
32+
NB_TRAMPOLINE(KvCacheConnectorManager, 1);
33+
34+
SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
35+
{
36+
NB_OVERRIDE_PURE_NAME("get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens);
37+
}
38+
};
39+
40+
} // namespace
41+
42+
/// Register the KvCacheConnectorManager class (with its Python trampoline) on module m.
void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(nb::module_& m)
{
    using Manager = tb::kv_connector::KvCacheConnectorManager;
    auto managerCls = nb::class_<Manager, PyKvCacheConnectorManager>(m, "KvCacheConnectorManager");
    managerCls.def(nb::init<>());
    managerCls.def("get_num_new_matched_tokens", &Manager::getNumNewMatchedTokens, nb::arg("request"),
        nb::arg("num_computed_tokens"));
}

0 commit comments

Comments
 (0)