Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions examples/llm-api/llm_kv_cache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
-from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

Expand All @@ -34,8 +33,8 @@ class PersistentKvCacheConnectorMetadata:

class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

-def __init__(self, executor_config: ExecutorConfig):
-super().__init__(executor_config)
+def __init__(self):
+super().__init__()

self.kv_cache_tensor = None

Expand Down Expand Up @@ -81,10 +80,10 @@ def get_finished(

class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

-def __init__(self, executor_config: ExecutorConfig):
-super().__init__(executor_config)
+def __init__(self, tokens_per_block):
+super().__init__()

-self.block_size = self._config.tokens_per_block
+self.block_size = tokens_per_block
self.pending_loads = {}

self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
Expand Down
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@

from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank
from tensorrt_llm.bindings import LlmRequestState
-from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import \
KvCacheConnectorManager as KvCacheConnectorManagerCpp
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
Expand Down Expand Up @@ -81,8 +80,7 @@ class SchedulerOutput:

class KvCacheConnectorWorker(ABC):

-def __init__(self, config: ExecutorConfig):
-self._config = config
+def __init__(self):
self._metadata = None
super().__init__()

Expand Down Expand Up @@ -162,8 +160,7 @@ def get_finished(

class KvCacheConnectorScheduler(ABC):

-def __init__(self, executor_config: ExecutorConfig):
-self._config = executor_config
+def __init__(self):
super().__init__()

@abstractmethod
Expand Down
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,11 @@ def create_py_executor(
# In this case, the worker may be dependent on the scheduler, or vice-versa.
# To deal with cases like this, we instantiate them both concurrently.
with ThreadPoolExecutor(max_workers=2) as executor:
-connector_worker_task = executor.submit(worker_cls,
-executor_config)
+connector_worker_task = executor.submit(worker_cls)

if scheduler_cls is not None and rank == 0:
connector_scheduler_task = executor.submit(
-scheduler_cls, executor_config)
+scheduler_cls, executor_config.tokens_per_block)
connector_scheduler = connector_scheduler_task.result()
else:
connector_scheduler = None
Expand Down