From 8f16dc76a5031988263ea798bfa81b53a38f805f Mon Sep 17 00:00:00 2001
From: junq <22017000+QiJune@users.noreply.github.com>
Date: Wed, 22 Oct 2025 09:32:16 +0800
Subject: [PATCH 1/2] Limit the scope of pybind based CacheTransceiverConfig

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/_util.py       |  5 +++--
 .../_torch/pyexecutor/kv_cache_transceiver.py | 16 ++++++++--------
 .../_torch/pyexecutor/py_executor_creator.py  |  7 ++-----
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 4354495ef19..df27539b709 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -11,7 +11,8 @@
                                           MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
+from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
+                                          EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
                                           SamplerType, SchedulerConfig,
                                           SparseAttentionConfig,
@@ -666,7 +667,7 @@ def create_py_executor_instance(
     max_num_tokens: Optional[int] = None,
     peft_cache_config: Optional[PeftCacheConfig] = None,
     scheduler_config: Optional[SchedulerConfig] = None,
-    cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
+    cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
 ) -> PyExecutor:
 
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
index 852d0352cfd..171ecab891b 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -5,7 +5,7 @@
 from tensorrt_llm import logger
 from tensorrt_llm._torch.distributed.communicator import Distributed
 from tensorrt_llm.bindings import WorldConfig
-from tensorrt_llm.bindings.executor import CacheTransceiverConfig
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 
 from .llm_request import LlmRequest
@@ -36,13 +36,13 @@ def create_kv_cache_transceiver(
         logger.info("cache_transceiver is disabled")
         return None
 
-    if cache_transceiver_config.backend == BackendTypeCpp.DEFAULT:
+    if cache_transceiver_config.backend == "DEFAULT":
         # When cache_transceiver_config.backend is not set, fallback to env_vars settings
         # NIXL is the default backend
-        cache_transceiver_config.backend = BackendTypeCpp.NIXL
+        cache_transceiver_config.backend = "NIXL"
         # Ordered by priority
-        env_vars = [("TRTLLM_USE_UCX_KVCACHE", BackendTypeCpp.UCX),
-                    ("TRTLLM_USE_MPI_KVCACHE", BackendTypeCpp.MPI)]
+        env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
+                    ("TRTLLM_USE_MPI_KVCACHE", "MPI")]
         for env_var, be_type in env_vars:
             if getenv(env_var) == "1":
                 logger.warning(
@@ -51,10 +51,10 @@
                 cache_transceiver_config.backend = be_type
                 break
 
-    if cache_transceiver_config.backend == BackendTypeCpp.MPI:
+    if cache_transceiver_config.backend == "MPI":
         logger.warning(
             "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
-    elif cache_transceiver_config.backend == BackendTypeCpp.UCX:
+    elif cache_transceiver_config.backend == "UCX":
         logger.info(
             f"Using UCX kv-cache transceiver. \nIf your devices are not in the same domain, please consider setting "
             f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
@@ -114,7 +114,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
                                         tokens_per_block, world_config,
                                         pp_layer_num_per_pp_rank, dtype,
                                         attention_type,
-                                        cache_transceiver_config)
+                                        cache_transceiver_config._to_pybind())
 
     def respond_and_send_async(self, req: LlmRequest):
         return self.impl.respond_and_send_async(req)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index bab9f2354e4..f1d385bdf65 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -17,7 +17,7 @@
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy, LoadFormat,
-                                          PybindMirror, TorchLlmArgs)
+                                          TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -294,10 +294,7 @@ def create_py_executor(
     else:
         dist = MPIDist(mapping=mapping)
 
-    cache_transceiver_config = None
-    if llm_args.cache_transceiver_config is not None:
-        cache_transceiver_config = PybindMirror.maybe_to_pybind(
-            llm_args.cache_transceiver_config)
+    cache_transceiver_config = llm_args.cache_transceiver_config
 
     has_draft_model_engine = False
     has_spec_drafter = False

From ac9fff037e478e9338e3f1e4972d7aef40a0e13b Mon Sep 17 00:00:00 2001
From: junq <22017000+QiJune@users.noreply.github.com>
Date: Wed, 22 Oct 2025 11:07:55 +0800
Subject: [PATCH 2/2] fix test

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
---
 .../unittest/others/test_kv_cache_transceiver.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py
index 2a395679952..a8afc3f2cfb 100644
--- a/tests/unittest/others/test_kv_cache_transceiver.py
+++ b/tests/unittest/others/test_kv_cache_transceiver.py
@@ -12,6 +12,7 @@
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                         LlmRequestState)
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams
 
@@ -67,11 +68,7 @@ def ctx_gen_kv_cache_dtype(request):
 @pytest.mark.parametrize("attention_type",
                          [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                          ids=["mha", "mla"])
-@pytest.mark.parametrize("backend", [
-    trtllm.CacheTransceiverBackendType.NIXL,
-    trtllm.CacheTransceiverBackendType.UCX
-],
-                         ids=["NIXL", "UCX"])
+@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
 def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                              attention_type, backend):
     # Init kv_cache manager and cache transceiver
@@ -80,8 +77,8 @@
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=backend, max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
+                                                      max_tokens_in_buffer=512)
 
     dist = MPIDist(mapping=mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
@@ -147,9 +144,8 @@ def test_cancel_request_in_transmission(attention_type):
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
 
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
-        max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
+                                                      max_tokens_in_buffer=512)
 
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
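
Usage note (illustrative sketch, not part of the patches): after this series, code above
the C++ boundary handles only the pure-Python CacheTransceiverConfig from
tensorrt_llm.llmapi.llm_args, whose backend field is a plain string ("DEFAULT", "NIXL",
"UCX", "MPI"); the pybind object is created solely inside kv_cache_transceiver.py via
_to_pybind(). A minimal sketch of the new flow, assuming only the fields and helpers
visible in the diffs above:

    from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

    # Pure-Python config, constructed the same way as in the updated unit test.
    cache_transceiver_config = CacheTransceiverConfig(backend="NIXL",
                                                      max_tokens_in_buffer=512)

    # Python-side code now compares the backend against plain strings instead of
    # the BackendTypeCpp enum from the bindings.
    assert cache_transceiver_config.backend == "NIXL"

    # The pybind type appears only at the C++ boundary, mirroring the
    # cache_transceiver_config._to_pybind() call in kv_cache_transceiver.py.
    pybind_config = cache_transceiver_config._to_pybind()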