diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 4354495ef19..df27539b709 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -11,7 +11,8 @@
     MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
+from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
+                                          EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
                                           SamplerType, SchedulerConfig,
                                           SparseAttentionConfig,
@@ -666,7 +667,7 @@ def create_py_executor_instance(
         max_num_tokens: Optional[int] = None,
         peft_cache_config: Optional[PeftCacheConfig] = None,
         scheduler_config: Optional[SchedulerConfig] = None,
-        cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
+        cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER,
                                      None)
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
index bd13b2284fd..73a9df01301 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -5,7 +5,7 @@
 from tensorrt_llm import logger
 from tensorrt_llm._torch.distributed.communicator import Distributed
 from tensorrt_llm.bindings import WorldConfig
-from tensorrt_llm.bindings.executor import CacheTransceiverConfig
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 
 from .llm_request import LlmRequest
@@ -36,13 +36,13 @@ def create_kv_cache_transceiver(
         logger.info("cache_transceiver is disabled")
         return None
 
-    if cache_transceiver_config.backend == BackendTypeCpp.DEFAULT:
+    if cache_transceiver_config.backend == "DEFAULT":
         # When cache_transceiver_config.backend is not set, fallback to env_vars settings
         # NIXL is the default backend
-        cache_transceiver_config.backend = BackendTypeCpp.NIXL
+        cache_transceiver_config.backend = "NIXL"
         # Ordered by priority
-        env_vars = [("TRTLLM_USE_UCX_KVCACHE", BackendTypeCpp.UCX),
-                    ("TRTLLM_USE_MPI_KVCACHE", BackendTypeCpp.MPI)]
+        env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
+                    ("TRTLLM_USE_MPI_KVCACHE", "MPI")]
         for env_var, be_type in env_vars:
             if getenv(env_var) == "1":
                 logger.warning(
@@ -51,10 +51,10 @@
                 cache_transceiver_config.backend = be_type
                 break
 
-    if cache_transceiver_config.backend == BackendTypeCpp.MPI:
+    if cache_transceiver_config.backend == "MPI":
         logger.warning(
             "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
-    elif cache_transceiver_config.backend == BackendTypeCpp.UCX:
+    elif cache_transceiver_config.backend == "UCX":
         logger.info(
             f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
             f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
@@ -116,7 +116,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
                                         tokens_per_block, world_config,
                                         pp_layer_num_per_pp_rank, dtype,
                                         attention_type,
-                                        cache_transceiver_config)
+                                        cache_transceiver_config._to_pybind())
 
     def respond_and_send_async(self, req: LlmRequest):
         return self.impl.respond_and_send_async(req)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 01c94bbceed..e7c51b52fa6 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -17,7 +17,7 @@
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy, LoadFormat,
-                                          PybindMirror, TorchLlmArgs)
+                                          TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -289,10 +289,7 @@ def create_py_executor(
     else:
         dist = MPIDist(mapping=mapping)
 
-    cache_transceiver_config = None
-    if llm_args.cache_transceiver_config is not None:
-        cache_transceiver_config = PybindMirror.maybe_to_pybind(
-            llm_args.cache_transceiver_config)
+    cache_transceiver_config = llm_args.cache_transceiver_config
 
     has_draft_model_engine = False
     has_spec_drafter = False
diff --git a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py
index 2a395679952..a8afc3f2cfb 100644
--- a/tests/unittest/others/test_kv_cache_transceiver.py
+++ b/tests/unittest/others/test_kv_cache_transceiver.py
@@ -12,6 +12,7 @@
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                         LlmRequestState)
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams
 
@@ -67,11 +68,7 @@ def ctx_gen_kv_cache_dtype(request):
 @pytest.mark.parametrize("attention_type",
                          [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                          ids=["mha", "mla"])
-@pytest.mark.parametrize("backend", [
-    trtllm.CacheTransceiverBackendType.NIXL,
-    trtllm.CacheTransceiverBackendType.UCX
-],
-                         ids=["NIXL", "UCX"])
+@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
 def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                              attention_type, backend):
     # Init kv_cache manager and cache transceiver
@@ -80,8 +77,8 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
 
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=backend, max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
+                                                      max_tokens_in_buffer=512)
     dist = MPIDist(mapping=mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
@@ -147,9 +144,8 @@ def test_cancel_request_in_transmission(attention_type):
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)
 
-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
-        max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
+                                                      max_tokens_in_buffer=512)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
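
Note: a minimal sketch of the post-change call pattern, using only what this diff itself introduces. Callers now build the llmapi-level `CacheTransceiverConfig` with a string-valued `backend`, and the pybind object is materialized only at the C++ boundary via `_to_pybind()`.

```python
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

# Python-level config; backend is now a plain string ("DEFAULT", "NIXL",
# "UCX", or "MPI"). Per create_kv_cache_transceiver above, "DEFAULT" falls
# back to NIXL unless TRTLLM_USE_UCX_KVCACHE or TRTLLM_USE_MPI_KVCACHE is "1".
cache_transceiver_config = CacheTransceiverConfig(backend="NIXL",
                                                  max_tokens_in_buffer=512)

# Conversion to the C++ binding type happens only where the C++ transceiver
# is constructed (see kv_cache_transceiver.py above).
binding_config = cache_transceiver_config._to_pybind()
```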