5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -11,7 +11,8 @@
     MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
+from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
+                                          EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
                                           SamplerType, SchedulerConfig,
                                           SparseAttentionConfig,
@@ -666,7 +667,7 @@ def create_py_executor_instance(
         max_num_tokens: Optional[int] = None,
         peft_cache_config: Optional[PeftCacheConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
-        cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
+        cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
16 changes: 8 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -5,7 +5,7 @@
 from tensorrt_llm import logger
 from tensorrt_llm._torch.distributed.communicator import Distributed
 from tensorrt_llm.bindings import WorldConfig
-from tensorrt_llm.bindings.executor import CacheTransceiverConfig
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 
 from .llm_request import LlmRequest
@@ -36,13 +36,13 @@ def create_kv_cache_transceiver(
         logger.info("cache_transceiver is disabled")
         return None
 
-    if cache_transceiver_config.backend == BackendTypeCpp.DEFAULT:
+    if cache_transceiver_config.backend == "DEFAULT":
         # When cache_transceiver_config.backend is not set, fallback to env_vars settings
         # NIXL is the default backend
-        cache_transceiver_config.backend = BackendTypeCpp.NIXL
+        cache_transceiver_config.backend = "NIXL"
         # Ordered by priority
-        env_vars = [("TRTLLM_USE_UCX_KVCACHE", BackendTypeCpp.UCX),
-                    ("TRTLLM_USE_MPI_KVCACHE", BackendTypeCpp.MPI)]
+        env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
+                    ("TRTLLM_USE_MPI_KVCACHE", "MPI")]
         for env_var, be_type in env_vars:
             if getenv(env_var) == "1":
                 logger.warning(
@@ -51,10 +51,10 @@
                 cache_transceiver_config.backend = be_type
                 break
 
-    if cache_transceiver_config.backend == BackendTypeCpp.MPI:
+    if cache_transceiver_config.backend == "MPI":
         logger.warning(
             "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
-    elif cache_transceiver_config.backend == BackendTypeCpp.UCX:
+    elif cache_transceiver_config.backend == "UCX":
         logger.info(
             f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
             f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
@@ -116,7 +116,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
                                         tokens_per_block, world_config,
                                         pp_layer_num_per_pp_rank, dtype,
                                         attention_type,
-                                        cache_transceiver_config)
+                                        cache_transceiver_config._to_pybind())
 
     def respond_and_send_async(self, req: LlmRequest):
         return self.impl.respond_and_send_async(req)
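For reviewers, a minimal sketch of the call pattern this refactor enables. The class, the field names, and the string-valued `backend` are taken from the diff above and the updated unit test; the surrounding setup (mapping, distributed communicator, KV-cache manager) is omitted, and the accepted backend strings are assumed to mirror the old enum values.

```python
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

# The backend is now a plain string on the llmapi config ("DEFAULT", "NIXL",
# "UCX", "MPI") rather than a bindings enum. "DEFAULT" falls back to NIXL
# unless TRTLLM_USE_UCX_KVCACHE or TRTLLM_USE_MPI_KVCACHE is set to "1".
cache_transceiver_config = CacheTransceiverConfig(backend="NIXL",
                                                  max_tokens_in_buffer=512)

# create_kv_cache_transceiver() keeps this llmapi object as-is and only
# converts it at the C++ boundary via cache_transceiver_config._to_pybind().
```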
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -17,7 +17,7 @@
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy, LoadFormat,
-                                          PybindMirror, TorchLlmArgs)
+                                          TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -289,10 +289,7 @@ def create_py_executor(
     else:
         dist = MPIDist(mapping=mapping)
 
-    cache_transceiver_config = None
-    if llm_args.cache_transceiver_config is not None:
-        cache_transceiver_config = PybindMirror.maybe_to_pybind(
-            llm_args.cache_transceiver_config)
+    cache_transceiver_config = llm_args.cache_transceiver_config
 
     has_draft_model_engine = False
     has_spec_drafter = False
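Net effect of this hunk: `create_py_executor` no longer pre-converts the config with `PybindMirror.maybe_to_pybind`; the llmapi `CacheTransceiverConfig` (or `None`) flows through unchanged to `create_py_executor_instance`, and the single conversion to the pybind type now happens inside the transceiver constructor via `_to_pybind()`, as shown in the previous file.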
16 changes: 6 additions & 10 deletions tests/unittest/others/test_kv_cache_transceiver.py
@@ -12,6 +12,7 @@
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                         LlmRequestState)
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams
 
@@ -67,11 +68,7 @@ def ctx_gen_kv_cache_dtype(request):
 @pytest.mark.parametrize("attention_type",
                          [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                          ids=["mha", "mla"])
-@pytest.mark.parametrize("backend", [
-    trtllm.CacheTransceiverBackendType.NIXL,
-    trtllm.CacheTransceiverBackendType.UCX
-],
-                         ids=["NIXL", "UCX"])
+@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
 def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                              attention_type, backend):
     # Init kv_cache manager and cache transceiver
@@ -80,8 +77,8 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
 
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=backend, max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
+                                                      max_tokens_in_buffer=512)
     dist = MPIDist(mapping=mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
@@ -147,9 +144,8 @@ def test_cancel_request_in_transmission(attention_type):
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
 
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
-        max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
+                                                      max_tokens_in_buffer=512)
 
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,