Commit 37a6357

[None][chore] Use llm args in create_py_executor

Signed-off-by: leslie-fang25 <[email protected]>

1 parent 4f84a45 commit 37a6357

File tree

6 files changed: +119 −58 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 2 additions & 4 deletions

@@ -9,7 +9,6 @@
 from tensorrt_llm._utils import nvtx_range

 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
@@ -259,7 +258,7 @@ def forward(
     return {"logits": logits_flat}


-def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
+def create_autodeploy_executor(ad_config: LlmArgs):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.

     This is the entrypoint API to the _autodeploy backend.
@@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:

     # some config
     msg = "pytorch_backend_config must be an AD LlmArgs object"
-    assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
-    ad_config: LlmArgs = executor_config.pytorch_backend_config
+    assert isinstance(ad_config, LlmArgs), msg
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

     max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
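
Note (illustrative, not part of the diff): the AutoDeploy executor is now built directly from an AD LlmArgs-derived config instead of an ExecutorConfig whose pytorch_backend_config carried it. A minimal sketch of the new call shape, assuming ad_llm_args is an already-configured AutoDeploy LlmArgs instance; the get_pytorch_backend_config call mirrors the worker.py hunk later in this commit:

# Sketch only; `ad_llm_args` is assumed to be prepared by the caller.
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
    create_autodeploy_executor

# Previously the AD config travelled inside executor_config.pytorch_backend_config;
# now it is passed directly as ad_config.
executor = create_autodeploy_executor(
    ad_config=ad_llm_args.get_pytorch_backend_config())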

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 4 additions & 2 deletions

@@ -423,7 +423,8 @@ def create_py_executor_instance(
         drafter,
         guided_decoder: Optional[GuidedDecoder] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
+        garbage_collection_gen0_threshold: Optional[int] = None,
+        max_seq_len: Optional[int] = None) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

     spec_config = model_engine.spec_config
@@ -570,7 +571,8 @@ def create_py_executor_instance(
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
-        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+        max_seq_len=max_seq_len)


 def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 1 deletion

@@ -153,7 +153,8 @@ def __init__(self,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
-                 start_worker: bool = True):
+                 start_worker: bool = True,
+                 max_seq_len: Optional[int] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
         self.global_rank = global_mpi_rank()
@@ -267,6 +268,7 @@ def __init__(self,
         )
         self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
         self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.max_seq_len = max_seq_len

         self.worker_started = False
         self.worker_lock = threading.Lock()
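
Note (illustrative, not part of the diff): storing max_seq_len on the executor lets callers read back the value the engine actually settled on. A minimal sketch of the consumer side, assuming executor is the PyExecutor returned by create_py_executor and llm_args is the originating TorchLlmArgs; this mirrors the worker.py hunk later in this commit:

# Sketch only; `executor` and `llm_args` are assumed to exist.
effective_max_seq_len = llm_args.max_seq_len
if executor.max_seq_len is not None:
    # create_py_executor may have adjusted max_seq_len (e.g. via the model
    # engine), so prefer the value the executor ended up with.
    effective_max_seq_len = executor.max_seq_len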

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 22 additions & 5 deletions

@@ -10,8 +10,13 @@
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
+from tensorrt_llm.bindings.executor import (ContextChunkingPolicy,
+                                            ExecutorConfig,
+                                            LogitsPostProcessorConfig,
+                                            ParallelConfig)
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -203,10 +208,20 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:


 def create_py_executor(
-    executor_config: ExecutorConfig,
-    checkpoint_dir: str = None,
-    lora_config: Optional[LoraConfig] = None,
-    garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
+    llm_args: TorchLlmArgs,
+    checkpoint_dir: str = None,
+    tokenizer: Optional[TokenizerBase] = None,
+    lora_config: Optional[LoraConfig] = None,
+    logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
+    parallel_config: Optional[ParallelConfig] = None,
+) -> PyExecutor:
+
+    executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
+    executor_config.logits_post_processor_config = logits_post_processor_config
+    executor_config.parallel_config = parallel_config
+
+    garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
+
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config

@@ -425,6 +440,7 @@ def create_py_executor(
             guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+            max_seq_len=max_seq_len,
         )

         if estimating_kv_cache:
@@ -468,6 +484,7 @@ def create_py_executor(
             lora_config=lora_config,
             garbage_collection_gen0_threshold=
             garbage_collection_gen0_threshold,
+            max_seq_len=max_seq_len,
         )

     _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
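
Note (illustrative, not part of the diff): create_py_executor now derives the ExecutorConfig internally from the LlmArgs, so callers pass the high-level arguments instead. A minimal sketch of the new call, with llm_args, hf_model_dir, tokenizer, lora_config, batched_logits_processor, comm_ranks and device_ids assumed to be prepared by the caller, as in the worker.py hunk later in this commit:

# Sketch only; all names below are assumed to be set up by the caller.
from tensorrt_llm._torch.pyexecutor.py_executor_creator import create_py_executor
from tensorrt_llm.bindings import executor as tllm

py_executor = create_py_executor(
    llm_args=llm_args,
    checkpoint_dir=hf_model_dir,
    tokenizer=tokenizer,
    lora_config=lora_config,
    logits_post_processor_config=tllm.LogitsPostProcessorConfig(
        processor_batched=batched_logits_processor, replicate=False),
    parallel_config=tllm.ParallelConfig(participant_ids=comm_ranks,
                                        device_ids=device_ids),
)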

tensorrt_llm/executor/executor.py

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
-from ..llmapi.llm_args import TorchLlmArgs
+from ..llmapi.llm_args import BaseLlmArgs
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
@@ -358,7 +358,7 @@ def create(
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy

tensorrt_llm/executor/worker.py

Lines changed: 86 additions & 44 deletions

@@ -18,7 +18,7 @@
                        mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import BaseLlmArgs, PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -63,7 +63,7 @@ def __init__(
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -102,39 +102,54 @@ def _get_comm_ranks_device_id():
            device_ids = mpi_comm().allgather(device_id)
            return comm_ranks, device_ids

-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor

         def _create_engine(executor_config):
             if executor_config is None:
@@ -158,8 +173,7 @@ def _create_engine(executor_config):
                executor_config)

         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -182,8 +196,9 @@ def _create_engine(executor_config):
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -465,32 +480,51 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
        assert request.id is not None

         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: Optional[tllm.ExecutorConfig],
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             if request.sampling_params.max_tokens:
                 return request.sampling_params.max_tokens
             # deduce max_tokens when it's not set by user
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
-                raise RuntimeError(
-                    "max_tokens for sampling is not set and cannot be deduced")
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                if not hasattr(self, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced by llm args"
+                    )
+                max_seq_len = self.max_seq_len
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                if not hasattr(executor_config, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced"
+                    )
+                max_seq_len = executor_config.max_seq_len
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens < 0:
                 raise ValueError(
                     f"Deduced max_tokens {default_max_tokens} is less than 0, because"
                     f"prompt length {splited_prompt_len} plus query length {query_token_len} "
-                    f"is larger than max_seq_len {executor_config.max_seq_len}")
+                    f"is larger than max_seq_len {max_seq_len}")
             return default_max_tokens

         try:
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=_deduce_max_tokens(request, self._executor_config),
+                max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                              self.llm_args),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else
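
Note (illustrative, not part of the diff): the deduction rule itself is unchanged; only the source of max_seq_len and cp_size moves from executor_config to the worker attributes set in _create_py_executor. A small worked example with made-up numbers:

# Hypothetical numbers, for illustration only.
max_seq_len = 4096       # from self.max_seq_len when llm_args is used
prompt_token_len = 1000  # len(prompt_token_ids)
query_token_len = 24     # len(request.query_token_ids) or 0
cp_size = 2              # context-parallel size from the mapping

splited_prompt_len = int(prompt_token_len / cp_size)  # 500
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
print(default_max_tokens)  # 3572; a negative value raises ValueError in the worker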
@@ -616,11 +650,19 @@ def shutdown(self):
            self.engine.shutdown()
            self.engine = None

-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None

         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
@@ -666,7 +708,7 @@ def worker_main(
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
