
Commit 20922b7

[None][chore] Create PyExecutor from TorchLlmArgs Part 1 (#7105)
Signed-off-by: leslie-fang25 <[email protected]>
1 parent b845eb7 · commit 20922b7

8 files changed, +197 -181 lines changed

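The commit threads a single TorchLlmArgs object (plus hf_model_dir and a tokenizer) from the LLM API down to the worker, replacing the standalone garbage_collection_gen0_threshold argument that was previously forwarded through every layer. A minimal sketch of the fields this commit reads from that object, using a hypothetical stand-in dataclass rather than the real TorchLlmArgs (everything below is illustrative, not part of this diff):

# Sketch only: a stand-in listing the TorchLlmArgs fields touched by this commit
# (backend, garbage_collection_gen0_threshold, disable_overlap_scheduler). The
# real class lives in tensorrt_llm/llmapi/llm_args.py and has many more fields.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTorchLlmArgs:  # hypothetical stand-in, not the real class
    backend: str = "pytorch"
    garbage_collection_gen0_threshold: Optional[int] = None
    disable_overlap_scheduler: bool = False


llm_args = FakeTorchLlmArgs(garbage_collection_gen0_threshold=20000)
# Backend detection as rewritten in worker.py: derived from llm_args instead of
# from executor_config.
is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
assert is_pytorch_backend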

tensorrt_llm/executor/executor.py

Lines changed: 12 additions & 13 deletions
@@ -21,9 +21,11 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
+from ..llmapi.llm_args import TorchLlmArgs
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
                             enable_worker_single_process_for_tp1, print_colored,
                             print_colored_debug)
@@ -354,7 +356,9 @@ def create(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -381,6 +385,9 @@ def create(
             "engine": engine,
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
+            "hf_model_dir": hf_model_dir,
+            "tokenizer": tokenizer,
+            "llm_args": llm_args,
         }

         if lora_config:
@@ -398,9 +405,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -411,9 +416,7 @@
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             return GenerationExecutorWorker(**worker_kwargs,
-                                            is_llm_executor=is_llm_executor,
-                                            garbage_collection_gen0_threshold=
-                                            garbage_collection_gen0_threshold)
+                                            is_llm_executor=is_llm_executor)

         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -425,9 +428,7 @@
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -438,9 +439,7 @@
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

     def wait_first_completed(
             self, futures: List[GenerationResult]
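After this change, create() no longer forwards garbage_collection_gen0_threshold; instead it packs the three new values into worker_kwargs. A small sketch of that key set, with placeholder values only (the real values come from the LLM API caller):

# Sketch of the worker_kwargs plumbing in executor.py after this commit: the
# three new keys are forwarded to the Proxy/Worker, while the explicit
# garbage_collection_gen0_threshold keyword disappears. Values are placeholders.
worker_kwargs = {
    "engine": None,
    "executor_config": None,
    "batched_logits_processor": None,
    "hf_model_dir": None,  # new: path to the HF checkpoint (Optional[Path])
    "tokenizer": None,     # new: a TokenizerBase instance, or None
    "llm_args": None,      # new: a TorchLlmArgs instance, or None for the TRT engine path
}
print(sorted(worker_kwargs))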

tensorrt_llm/executor/proxy.py

Lines changed: 4 additions & 5 deletions
@@ -45,7 +45,6 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -87,14 +86,14 @@ def __init__(

         self.model_world_size = model_world_size

-        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.garbage_collection_gen0_threshold = worker_kwargs[
+            "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
+                "llm_args", None) is not None else None

         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False,
-                             garbage_collection_gen0_threshold=self.
-                             garbage_collection_gen0_threshold)
+                             is_llm_executor=False)

         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level

tensorrt_llm/executor/worker.py

Lines changed: 64 additions & 37 deletions
@@ -18,8 +18,9 @@
                        mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -60,7 +61,9 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,54 +84,49 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args

         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

         if isinstance(engine, list):
             engine = engine[self.rank]

-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
-
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
-
-        def _create_engine():
+        def _get_comm_ranks_device_id():
             device_id = self.global_rank % torch.cuda.device_count()
             torch.cuda.set_device(device_id)
-
             # Make sure C++ executor would use same devices/ranks as py_executor
             global_rank = global_mpi_rank()
             comm_ranks = mpi_comm().allgather(global_rank)
             device_ids = mpi_comm().allgather(device_id)
+            return comm_ranks, device_ids
+
+        def _create_py_executor(executor_config):
+            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
+            executor_config = llm_args.get_executor_config(
+                hf_model_dir, tokenizer)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
-
-            if isinstance(engine, Engine):
-                return tllm.Executor(engine.engine,
-                                     json.dumps(engine.config.to_dict(),
-                                                cls=ConfigEncoder),
-                                     tllm.ModelType.DECODER_ONLY,
-                                     executor_config=executor_config,
-                                     managed_weights=engine.managed_weights)
-
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should be with backend in _create_py_executor"
             if executor_config.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
             elif executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
@@ -138,7 +136,30 @@ def _create_engine():
                     f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)

-        self.engine = _create_engine()
+        def _create_engine(executor_config):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+
+            if isinstance(engine, Engine):
+                return tllm.Executor(engine.engine,
+                                     json.dumps(engine.config.to_dict(),
+                                                cls=ConfigEncoder),
+                                     tllm.ModelType.DECODER_ONLY,
+                                     executor_config=executor_config,
+                                     managed_weights=engine.managed_weights)
+
+            assert not hasattr(executor_config, "backend")
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)
+
+        self.engine = _create_py_executor(
+            executor_config) if llm_args is not None else _create_engine(
+                executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -161,7 +182,7 @@ def _create_engine():
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(executor_config, "backend",
+        if getattr(self._executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
@@ -430,14 +451,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )

-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )

         assert request.id is not None

@@ -641,7 +664,9 @@ def worker_main(
         is_llm_executor: Optional[
            bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +793,9 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            tokenizer=tokenizer,
+            llm_args=llm_args)

     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
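The core of the worker change is the constructor dispatch: when llm_args is provided, the worker builds a PyExecutor from an executor_config derived via llm_args.get_executor_config(hf_model_dir, tokenizer); otherwise it falls back to the original C++ tllm.Executor path. A runnable sketch of just that control flow, with stand-in builders that only model the branching (not the real executor construction):

# Sketch of the new dispatch in GenerationExecutorWorker.__init__. The two
# builders below are hypothetical stand-ins that return labels instead of
# executors.
from typing import Any, Optional


def _create_py_executor(executor_config: Optional[Any]) -> str:
    # The real version asserts executor_config is None and derives a fresh one
    # from TorchLlmArgs.get_executor_config(hf_model_dir, tokenizer).
    assert executor_config is None
    return "py_executor"


def _create_engine(executor_config: Optional[Any]) -> str:
    # The real version falls back to tllm.ExecutorConfig(1) and returns a
    # tllm.Executor for the TensorRT engine path.
    return "cpp_executor"


def build_engine(llm_args: Optional[Any], executor_config: Optional[Any]) -> str:
    return (_create_py_executor(executor_config)
            if llm_args is not None else _create_engine(executor_config))


assert build_engine(llm_args=object(), executor_config=None) == "py_executor"
assert build_engine(llm_args=None, executor_config=None) == "cpp_executor"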
