Skip to content

Commit b27a7b5

Browse files
committed
[None][chore] Remove executor config in KV cache connector
Signed-off-by: leslie-fang25 <[email protected]>
1 parent 31b0f0f commit b27a7b5

File tree

3 files changed

+9
-14
lines changed

3 files changed

+9
-14
lines changed

examples/llm-api/llm_kv_cache_connector.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from tensorrt_llm import LLM, SamplingParams, logger
1515
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
1616
KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
17-
from tensorrt_llm.bindings.executor import ExecutorConfig
1817
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
1918
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
2019

@@ -34,8 +33,8 @@ class PersistentKvCacheConnectorMetadata:
3433

3534
class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
3635

37-
def __init__(self, executor_config: ExecutorConfig):
38-
super().__init__(executor_config)
36+
def __init__(self):
37+
super().__init__()
3938

4039
self.kv_cache_tensor = None
4140

@@ -81,10 +80,10 @@ def get_finished(
8180

8281
class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
8382

84-
def __init__(self, executor_config: ExecutorConfig):
85-
super().__init__(executor_config)
83+
def __init__(self, tokens_per_block):
84+
super().__init__()
8685

87-
self.block_size = self._config.tokens_per_block
86+
self.block_size = tokens_per_block
8887
self.pending_loads = {}
8988

9089
self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,

tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444

4545
from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank
4646
from tensorrt_llm.bindings import LlmRequestState
47-
from tensorrt_llm.bindings.executor import ExecutorConfig
4847
from tensorrt_llm.bindings.internal.batch_manager import \
4948
KvCacheConnectorManager as KvCacheConnectorManagerCpp
5049
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
@@ -81,8 +80,7 @@ class SchedulerOutput:
8180

8281
class KvCacheConnectorWorker(ABC):
8382

84-
def __init__(self, config: ExecutorConfig):
85-
self._config = config
83+
def __init__(self):
8684
self._metadata = None
8785
super().__init__()
8886

@@ -162,8 +160,7 @@ def get_finished(
162160

163161
class KvCacheConnectorScheduler(ABC):
164162

165-
def __init__(self, executor_config: ExecutorConfig):
166-
self._config = executor_config
163+
def __init__(self):
167164
super().__init__()
168165

169166
@abstractmethod

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,12 +409,11 @@ def create_py_executor(
409409
# In this case, the worker may be dependent on the scheduler, or vice-versa.
410410
# To deal with cases like this, we instantiate them both concurrently.
411411
with ThreadPoolExecutor(max_workers=2) as executor:
412-
connector_worker_task = executor.submit(worker_cls,
413-
executor_config)
412+
connector_worker_task = executor.submit(worker_cls)
414413

415414
if scheduler_cls is not None and rank == 0:
416415
connector_scheduler_task = executor.submit(
417-
scheduler_cls, executor_config)
416+
scheduler_cls, executor_config.tokens_per_block)
418417
connector_scheduler = connector_scheduler_task.result()
419418
else:
420419
connector_scheduler = None

0 commit comments

Comments (0)