@@ -18,8 +18,9 @@
                      mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -60,7 +61,9 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
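
Note (reviewer sketch, not part of the patch): `garbage_collection_gen0_threshold` is no longer a standalone parameter; it now travels inside `llm_args` and is read back where the PyTorch creator kwargs are assembled later in this diff. A minimal, self-contained illustration with a stub args class (only the field name is taken from the diff; the class and helper below are hypothetical):

```python
# Illustration only: a stub standing in for TorchLlmArgs to show where the
# GC threshold now lives. Not TensorRT-LLM code.
from dataclasses import dataclass
from typing import Optional


@dataclass
class LlmArgsStub:
    backend: str = "pytorch"
    garbage_collection_gen0_threshold: Optional[int] = None


def build_pytorch_creator_kwargs(llm_args: LlmArgsStub, lora_config=None) -> dict:
    # Mirrors how the worker now pulls the threshold out of llm_args instead of
    # receiving it as its own constructor argument.
    return {
        "lora_config": lora_config,
        "garbage_collection_gen0_threshold":
        llm_args.garbage_collection_gen0_threshold,
    }


print(build_pytorch_creator_kwargs(LlmArgsStub(garbage_collection_gen0_threshold=20000)))
```
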
@@ -81,54 +84,49 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args
 
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)
 
         if isinstance(engine, list):
             engine = engine[self.rank]
 
-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
-
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
-
-        def _create_engine():
+        def _get_comm_ranks_device_id():
             device_id = self.global_rank % torch.cuda.device_count()
             torch.cuda.set_device(device_id)
-
             # Make sure C++ executor would use same devices/ranks as py_executor
             global_rank = global_mpi_rank()
             comm_ranks = mpi_comm().allgather(global_rank)
             device_ids = mpi_comm().allgather(device_id)
+            return comm_ranks, device_ids
+
+        def _create_py_executor(executor_config):
+            assert executor_config is None, "expect an empty executor_config in _create_py_executor"
+            executor_config = llm_args.get_executor_config(
+                hf_model_dir, tokenizer)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
-
-            if isinstance(engine, Engine):
-                return tllm.Executor(engine.engine,
-                                     json.dumps(engine.config.to_dict(),
-                                                cls=ConfigEncoder),
-                                     tllm.ModelType.DECODER_ONLY,
-                                     executor_config=executor_config,
-                                     managed_weights=engine.managed_weights)
-
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should have a backend in _create_py_executor"
             if executor_config.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
             elif executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
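
For reference, the rank/device gathering that `_get_comm_ranks_device_id` factors out can be read as the following standalone pattern (a sketch assuming mpi4py and a CUDA-capable torch build, launched under `mpirun`; the real code goes through TensorRT-LLM's `mpi_comm()`/`global_mpi_rank()` wrappers):

```python
# Standalone sketch of the rank -> device-id gathering; assumes mpi4py and a
# CUDA-capable torch build. Launch with e.g. `mpirun -n 2 python sketch.py`.
import torch
from mpi4py import MPI


def get_comm_ranks_device_ids():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    # Map each process onto a local GPU round-robin by rank.
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)
    # allgather so every rank sees the full participant/device mapping; the
    # C++ executor is then configured with exactly these ranks and devices.
    comm_ranks = comm.allgather(rank)
    device_ids = comm.allgather(device_id)
    return comm_ranks, device_ids


if __name__ == "__main__":
    ranks, devices = get_comm_ranks_device_ids()
    print(f"rank {MPI.COMM_WORLD.Get_rank()}: ranks={ranks}, devices={devices}")
```
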
@@ -138,7 +136,30 @@ def _create_engine():
                 f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)
 
-        self.engine = _create_engine()
+        def _create_engine(executor_config):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+
+            if isinstance(engine, Engine):
+                return tllm.Executor(engine.engine,
+                                     json.dumps(engine.config.to_dict(),
+                                                cls=ConfigEncoder),
+                                     tllm.ModelType.DECODER_ONLY,
+                                     executor_config=executor_config,
+                                     managed_weights=engine.managed_weights)
+
+            assert not hasattr(executor_config, "backend")
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)
+
+        self.engine = _create_py_executor(
+            executor_config) if llm_args is not None else _create_engine(
+                executor_config)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
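
The net effect of the two helpers is a single construction dispatch on `llm_args`. A compact toy restatement (the stub class and return strings below are illustrative, not TensorRT-LLM types):

```python
# Toy restatement of the dispatch: llm_args selects the PyTorch creator and
# implies the executor_config is built internally; otherwise the legacy C++
# executor path runs with a provided (or default) config. Stubs only.
from typing import Optional


class StubLlmArgs:
    backend = "pytorch"


def _create_py_executor(executor_config, llm_args: StubLlmArgs) -> str:
    # PyTorch path: the config must not be passed in; it is derived from
    # llm_args (mirrors the assert in the patch).
    assert executor_config is None
    return f"py_executor(backend={llm_args.backend})"


def _create_engine(executor_config) -> str:
    # Legacy path: fall back to a default config when none was provided.
    if executor_config is None:
        executor_config = "ExecutorConfig(1)"
    return f"cpp_executor({executor_config})"


def build(executor_config, llm_args: Optional[StubLlmArgs]) -> str:
    return (_create_py_executor(executor_config, llm_args)
            if llm_args is not None else _create_engine(executor_config))


print(build(None, StubLlmArgs()))  # py_executor(backend=pytorch)
print(build(None, None))           # cpp_executor(ExecutorConfig(1))
```
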
@@ -161,7 +182,7 @@ def _create_engine():
         if engine_config.build_config.max_prompt_embedding_table_size > 0:
             self._prompt_adapter_manager = PromptAdapterManager()
 
-        if getattr(executor_config, "backend",
+        if getattr(self._executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
@@ -430,14 +451,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )
 
-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )
 
         assert request.id is not None
 
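
Read in isolation, the new guard rejects a request only when all of these hold: PyTorch backend, overlap scheduler enabled (i.e. not disabled in `llm_args`), disaggregated serving, and a context-only request. A self-contained restatement of that predicate (function and parameter names below are illustrative):

```python
# Self-contained restatement of the guard above; not TensorRT-LLM code.
def check_context_only_allowed(is_pytorch_backend: bool,
                               disable_overlap_scheduler: bool,
                               is_disaggregated: bool,
                               is_context_only: bool) -> None:
    if (is_pytorch_backend and not disable_overlap_scheduler
            and is_disaggregated and is_context_only):
        raise ValueError(
            "Context only requests are not supported in pytorch backend when "
            "overlap is enabled.")


# Allowed: overlap disabled, so context-only passes even when disaggregated.
check_context_only_allowed(True, True, True, True)

# Rejected: overlap enabled + disaggregated + context-only.
try:
    check_context_only_allowed(True, False, True, True)
except ValueError as err:
    print(err)
```
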
@@ -641,7 +664,9 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +793,9 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            tokenizer=tokenizer,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())