Commit ec3ed20
[None][chore] Use llm args in create_py_executor
Signed-off-by: leslie-fang25 <[email protected]>
1 parent 31b0f0f commit ec3ed20

6 files changed: +134 -77 lines changed


tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 2 additions & 4 deletions
@@ -9,7 +9,6 @@
 from tensorrt_llm._utils import nvtx_range
 
 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
@@ -259,7 +258,7 @@ def forward(
         return {"logits": logits_flat}
 
 
-def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
+def create_autodeploy_executor(ad_config: LlmArgs):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.
 
     This is the entrypoint API to the _autodeploy backend.
@@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
 
     # some config
     msg = "pytorch_backend_config must be an AD LlmArgs object"
-    assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
-    ad_config: LlmArgs = executor_config.pytorch_backend_config
+    assert isinstance(ad_config, LlmArgs), msg
    assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
 
    max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
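
With this change, the AutoDeploy entrypoint no longer unwraps an ExecutorConfig: callers hand it the AD LlmArgs object directly. A minimal sketch of the new call site, under the assumption that the AutoDeploy LlmArgs can be constructed from a model path the way the LLM API args are (the checkpoint path below is purely illustrative):

from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import create_autodeploy_executor

# Build the AutoDeploy args directly; no ExecutorConfig wrapper is involved anymore.
ad_config = LlmArgs(model="/path/to/hf_checkpoint")  # illustrative path, not from this commit
executor = create_autodeploy_executor(ad_config)

In this commit the worker builds the same argument via self.llm_args.get_pytorch_backend_config() before calling the entrypoint (see the worker.py diff below).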

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 4 additions & 2 deletions
@@ -510,7 +510,8 @@ def create_py_executor_instance(
         guided_decoder: Optional[GuidedDecoder] = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None,
-        kv_connector_manager: Optional[KvCacheConnectorManager] = None
+        kv_connector_manager: Optional[KvCacheConnectorManager] = None,
+        max_seq_len: Optional[int] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
@@ -659,7 +660,8 @@ def create_py_executor_instance(
         guided_decoder=guided_decoder,
         start_worker=start_worker,
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
-        kv_connector_manager=kv_connector_manager)
+        kv_connector_manager=kv_connector_manager,
+        max_seq_len=max_seq_len)
 
 
 def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 20 additions & 19 deletions
@@ -139,25 +139,25 @@ class BatchStatePP(BatchState):
 
 class PyExecutor:
 
-    def __init__(
-        self,
-        resource_manager,
-        scheduler: RequestScheduler,
-        model_engine: ModelEngine,
-        sampler: Sampler,
-        dist: Distributed,
-        max_num_sequences: int,
-        drafter: Optional[Drafter] = None,
-        disable_overlap_scheduler: bool = False,
-        max_input_len: int = 2048,
-        max_batch_size: int = 8,
-        max_beam_width: int = 1,
-        max_draft_len: int = 0,
-        kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
-        guided_decoder: Optional[GuidedDecoder] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
-        start_worker: bool = True,
-        kv_connector_manager: Optional[KvCacheConnectorManager] = None):
+    def __init__(self,
+                 resource_manager,
+                 scheduler: RequestScheduler,
+                 model_engine: ModelEngine,
+                 sampler: Sampler,
+                 dist: Distributed,
+                 max_num_sequences: int,
+                 drafter: Optional[Drafter] = None,
+                 disable_overlap_scheduler: bool = False,
+                 max_input_len: int = 2048,
+                 max_batch_size: int = 8,
+                 max_beam_width: int = 1,
+                 max_draft_len: int = 0,
+                 kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
+                 guided_decoder: Optional[GuidedDecoder] = None,
+                 garbage_collection_gen0_threshold: Optional[int] = None,
+                 start_worker: bool = True,
+                 kv_connector_manager: Optional[KvCacheConnectorManager] = None,
+                 max_seq_len: Optional[int] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
         self.global_rank = global_mpi_rank()
@@ -271,6 +271,7 @@ def __init__(
         )
         self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
         self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.max_seq_len = max_seq_len
 
         self.worker_started = False
         self.worker_lock = threading.Lock()
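
The _util.py and py_executor.py hunks above thread one optional value, max_seq_len, from create_py_executor_instance into the PyExecutor constructor, which stores it so the worker can read it back after construction. A toy sketch of that plumbing pattern with placeholder names (these classes are illustrative stand-ins, not the library API):

from typing import Optional


class ToyExecutor:
    """Stand-in for PyExecutor: it simply stores the optional limit it receives."""

    def __init__(self, max_seq_len: Optional[int] = None):
        self.max_seq_len = max_seq_len  # mirrors `self.max_seq_len = max_seq_len` above


def create_toy_executor_instance(max_seq_len: Optional[int] = None) -> ToyExecutor:
    """Stand-in for create_py_executor_instance: forwards the new keyword unchanged."""
    return ToyExecutor(max_seq_len=max_seq_len)


executor = create_toy_executor_instance(max_seq_len=8192)
print(executor.max_seq_len)  # 8192; stays None if the caller never sets it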

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 21 additions & 7 deletions
@@ -14,9 +14,12 @@
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
                                             ContextChunkingPolicy,
-                                            ExecutorConfig)
+                                            ExecutorConfig,
+                                            LogitsPostProcessorConfig,
+                                            ParallelConfig)
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
-from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -209,12 +212,21 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
 
 
 def create_py_executor(
-    executor_config: ExecutorConfig,
-    checkpoint_dir: str = None,
-    lora_config: Optional[LoraConfig] = None,
-    garbage_collection_gen0_threshold: Optional[int] = None,
-    kv_connector_config: Optional[KvCacheConnectorConfig] = None
+    llm_args: TorchLlmArgs,
+    checkpoint_dir: str = None,
+    tokenizer: Optional[TokenizerBase] = None,
+    lora_config: Optional[LoraConfig] = None,
+    kv_connector_config: Optional[KvCacheConnectorConfig] = None,
+    logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
+    parallel_config: Optional[ParallelConfig] = None,
 ) -> PyExecutor:
+
+    executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
+    executor_config.logits_post_processor_config = logits_post_processor_config
+    executor_config.parallel_config = parallel_config
+
+    garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
+
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config
 
@@ -484,6 +496,7 @@ def create_py_executor(
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         kv_connector_manager=kv_connector_manager
         if not estimating_kv_cache else None,
+        max_seq_len=executor_config.max_seq_len,
     )
 
     if estimating_kv_cache:
@@ -528,6 +541,7 @@ def create_py_executor(
             garbage_collection_gen0_threshold=
            garbage_collection_gen0_threshold,
             kv_connector_manager=kv_connector_manager,
+            max_seq_len=executor_config.max_seq_len,
         )
 
     _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
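
After this hunk, create_py_executor derives the ExecutorConfig internally from the TorchLlmArgs (via get_executor_config) instead of receiving one, so a caller only supplies the args object plus a few optional pieces. A hedged sketch of a direct call based on the new signature; the values are illustrative, and in practice GenerationExecutorWorker assembles these arguments as shown in the worker.py diff below:

from tensorrt_llm._torch.pyexecutor.py_executor_creator import create_py_executor
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

# Assumed construction: TorchLlmArgs validated from a model path, as the LLM API does.
llm_args = TorchLlmArgs(model="/path/to/hf_checkpoint")
py_executor = create_py_executor(
    llm_args=llm_args,
    checkpoint_dir="/path/to/hf_checkpoint",
    tokenizer=None,                      # optional TokenizerBase
    lora_config=None,
    kv_connector_config=None,
    logits_post_processor_config=None,   # tllm.LogitsPostProcessorConfig when needed
    parallel_config=None,                # tllm.ParallelConfig when needed
)

Note that garbage_collection_gen0_threshold is no longer a parameter; it is now read off llm_args inside the function.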

tensorrt_llm/executor/executor.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
-from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
+from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
@@ -359,7 +359,7 @@ def create(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy

tensorrt_llm/executor/worker.py

Lines changed: 85 additions & 43 deletions
@@ -18,7 +18,8 @@
                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
+                               PybindMirror, TorchLlmArgs)
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -64,7 +65,7 @@ def __init__(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -107,40 +108,55 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids
 
-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
                 args["kv_connector_config"] = kv_connector_config
-            elif executor_config.backend == "_autodeploy":
+                args[
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor
 
         def _create_engine(executor_config):
             if executor_config is None:
@@ -164,8 +180,7 @@ def _create_engine(executor_config):
                 executor_config)
 
         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -188,8 +203,9 @@ def _create_engine(executor_config):
         if engine_config.build_config.max_prompt_embedding_table_size > 0:
             self._prompt_adapter_manager = PromptAdapterManager()
 
-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -471,26 +487,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None
 
         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: tllm.ExecutorConfig,
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             # deduce max_tokens when it's not set by user
             max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
+
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                max_seq_len = getattr(self, "max_seq_len", None)
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                max_seq_len = getattr(executor_config, "max_seq_len", None)
+            if max_seq_len is None:
                 logger.warning("`default_max_tokens` cannot be deduced")
                 if max_tokens is None:
                     raise ValueError(
                         "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                     )
+                else:
+                    # use max_tokens if can't deduce default_max_tokens
+                    return max_tokens
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
                 logger.warning(
                     f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
-                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
+                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                     f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                 )
                 if max_tokens is None:
@@ -512,7 +545,8 @@ def _deduce_max_tokens(request: GenerationRequest,
         executor_request = tllm.Request(
             client_id=request.id,
             input_token_ids=prompt_token_ids,
-            max_tokens=_deduce_max_tokens(request, self._executor_config),
+            max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                          self.llm_args),
             streaming=request.streaming,
             sampling_config=request.sampling_params._get_sampling_config(),
             end_id=-1 if request.sampling_params.ignore_eos else
@@ -638,11 +672,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None
 
-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None
 
         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
@@ -689,7 +731,7 @@ def worker_main(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
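
The largest behavioral change in this file is in _deduce_max_tokens: when llm_args is set, the worker now reads cp_size from self.mapping and the (possibly engine-adjusted) self.max_seq_len rather than from an ExecutorConfig, while the arithmetic itself is unchanged. A standalone sketch of that deduction with all inputs passed explicitly; the helper name and defaults are illustrative, and the final reconciliation with a user-provided max_tokens is omitted because it lies outside the changed hunk:

from typing import List, Optional


def deduce_default_max_tokens(prompt_token_ids: List[int],
                              max_seq_len: Optional[int],
                              cp_size: int = 1,
                              query_token_len: int = 0,
                              user_max_tokens: Optional[int] = None) -> int:
    """Mirrors the deduction shown in the diff above."""
    if max_seq_len is None:
        # `default_max_tokens` cannot be deduced; fall back to the user's value.
        if user_max_tokens is None:
            raise ValueError(
                "`max_tokens` must be set when `default_max_tokens` cannot be deduced")
        return user_max_tokens
    splited_prompt_len = int(len(prompt_token_ids) / cp_size)
    return max_seq_len - splited_prompt_len - query_token_len


# Example: an 8192-token window, a 1000-token prompt, no context parallelism.
print(deduce_default_max_tokens(list(range(1000)), max_seq_len=8192))  # 7192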
