                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
+                               PybindMirror, TorchLlmArgs)
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -64,7 +65,7 @@ def __init__(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
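
Note on the widened annotation above: the worker now accepts any `BaseLlmArgs` subclass rather than `TorchLlmArgs` only. A minimal sketch, using stand-in classes rather than the real tensorrt_llm hierarchy, of why annotating against the base type covers both backends as well as the legacy path where no args object is passed:

# Stand-in classes for illustration only; the real hierarchy lives in
# tensorrt_llm.llmapi.llm_args and tensorrt_llm._torch.auto_deploy.llm_args.
from typing import Optional


class FakeBaseLlmArgs:
    backend: str = ""


class FakeTorchLlmArgs(FakeBaseLlmArgs):
    backend = "pytorch"


class FakeAutoDeployLlmArgs(FakeBaseLlmArgs):
    backend = "_autodeploy"


def describe(llm_args: Optional[FakeBaseLlmArgs]) -> str:
    # One base-type annotation covers every backend-specific subclass,
    # plus the legacy path where no llm_args object is passed at all.
    if llm_args is None:
        return "legacy executor_config path"
    return f"llm_args path, backend={llm_args.backend}"


print(describe(FakeTorchLlmArgs()))       # llm_args path, backend=pytorch
print(describe(FakeAutoDeployLlmArgs()))  # llm_args path, backend=_autodeploy
print(describe(None))                     # legacy executor_config path
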
@@ -107,40 +108,55 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids
 
-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
                 args["kv_connector_config"] = kv_connector_config
-            elif executor_config.backend == "_autodeploy":
+                args[
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor
 
         def _create_engine(executor_config):
             if executor_config is None:
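
For orientation, the new `_create_py_executor` follows one pattern: collect backend-specific keyword arguments into a single `args` dict, then finish with one `create_executor(**args)` call. A condensed sketch of that shape, with hypothetical factory functions standing in for `create_py_executor` and `create_autodeploy_executor`:

# Hypothetical factories; the real ones are create_py_executor / create_autodeploy_executor.
from typing import Any, Callable, Dict


def fake_pytorch_executor(**kwargs: Any) -> str:
    return f"pytorch executor built with {sorted(kwargs)}"


def fake_autodeploy_executor(**kwargs: Any) -> str:
    return f"autodeploy executor built with {sorted(kwargs)}"


def build_executor(backend: str, llm_args: Any, checkpoint_dir: str) -> str:
    args: Dict[str, Any] = {}
    create_executor: Callable[..., str]
    if backend == "pytorch":
        create_executor = fake_pytorch_executor
        # The PyTorch path forwards the full llm_args plus per-call extras.
        args["llm_args"] = llm_args
        args["checkpoint_dir"] = checkpoint_dir
    elif backend == "_autodeploy":
        create_executor = fake_autodeploy_executor
        # The auto-deploy path collapses everything into one config object.
        args["ad_config"] = llm_args
    else:
        raise ValueError(f"Unsupported backend config: {backend}")
    # Single call site regardless of which branch filled the dict.
    return create_executor(**args)


print(build_executor("pytorch", llm_args={"max_seq_len": 4096},
                     checkpoint_dir="/path/to/ckpt"))
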
@@ -164,8 +180,7 @@ def _create_engine(executor_config):
                 executor_config)
 
         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -188,8 +203,9 @@ def _create_engine(executor_config):
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()
 
-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -471,26 +487,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None
 
         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: tllm.ExecutorConfig,
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             # deduce max_tokens when it's not set by user
             max_tokens = request.sampling_params.max_tokens
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
+
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                max_seq_len = getattr(self, "max_seq_len", None)
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                max_seq_len = getattr(executor_config, "max_seq_len", None)
+            if max_seq_len is None:
                 logger.warning("`default_max_tokens` cannot be deduced")
                 if max_tokens is None:
                     raise ValueError(
                         "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                     )
+                else:
+                    # use max_tokens if can't deduce default_max_tokens
+                    return max_tokens
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
                 logger.warning(
                     f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
-                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
+                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                     f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                 )
             if max_tokens is None:
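
The arithmetic at the heart of the deduction above, pulled into a standalone sketch (simplified names; the final reconciliation with a user-supplied `max_tokens` sits outside this hunk and is omitted):

from typing import Optional


def deduce_default_max_tokens(prompt_len: int,
                              query_token_len: int,
                              max_seq_len: Optional[int],
                              cp_size: int = 1) -> Optional[int]:
    # Mirrors the hunk above: without a max_seq_len there is nothing to deduce.
    if max_seq_len is None:
        return None
    # The prompt is split evenly across cp_size context-parallel ranks.
    splited_prompt_len = int(prompt_len / cp_size)
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens <= 0:
        print(f"warning: default_max_tokens ({default_max_tokens}) should be "
              f"greater than 0")
    return default_max_tokens


# e.g. max_seq_len=4096, a 1000-token prompt split over cp_size=2, no query tokens:
# 4096 - 500 - 0 = 3596
print(deduce_default_max_tokens(prompt_len=1000, query_token_len=0,
                                max_seq_len=4096, cp_size=2))
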
@@ -512,7 +545,8 @@ def _deduce_max_tokens(request: GenerationRequest,
         executor_request = tllm.Request(
             client_id=request.id,
             input_token_ids=prompt_token_ids,
-            max_tokens=_deduce_max_tokens(request, self._executor_config),
+            max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                          self.llm_args),
             streaming=request.streaming,
             sampling_config=request.sampling_params._get_sampling_config(),
             end_id=-1 if request.sampling_params.ignore_eos else
@@ -638,11 +672,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None
 
-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None
 
         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
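
A compressed view of the two cleanup paths above, using stand-in objects rather than the real worker class: with `llm_args` set, the checkpoint loader lives on the worker itself; otherwise the legacy path still reaches through `_executor_config` (the pytorch-backend check is omitted here for brevity):

from types import SimpleNamespace


class FakeCheckpointLoader:
    def cleanup(self) -> None:
        print("checkpoint loader cleaned up")


def cleanup_checkpoint_loader(worker) -> None:
    if worker.llm_args is not None:
        # New path: the loader was stored on the worker in _create_py_executor.
        if getattr(worker, "checkpoint_loader", None) is not None:
            worker.checkpoint_loader.cleanup()
            worker.checkpoint_loader = None
    else:
        # Legacy path: the loader still hangs off the executor config.
        cfg = worker._executor_config
        if getattr(cfg, "checkpoint_loader", None) is not None:
            cfg.checkpoint_loader.cleanup()
            cfg.checkpoint_loader = None


worker = SimpleNamespace(llm_args=object(),
                         checkpoint_loader=FakeCheckpointLoader(),
                         _executor_config=None)
cleanup_checkpoint_loader(worker)   # prints "checkpoint loader cleaned up"
print(worker.checkpoint_loader)     # None
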
@@ -689,7 +731,7 @@ def worker_main(
         kv_connector_config: Optional[KvCacheConnectorConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",