                     mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
+from ..llmapi.llm_args import BaseLlmArgs, PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -63,7 +63,7 @@ def __init__(
             lora_config: Optional[LoraConfig] = None,
             hf_model_dir: Optional[Path] = None,
             tokenizer: Optional[TokenizerBase] = None,
-            llm_args: Optional[TorchLlmArgs] = None,
+            llm_args: Optional[BaseLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -102,39 +102,54 @@ def _get_comm_ranks_device_id():
             device_ids = mpi_comm().allgather(device_id)
             return comm_ranks, device_ids

-        def _create_py_executor(executor_config):
-            assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(
-                hf_model_dir, tokenizer)
-            # Persist so downstream code (e.g., default max_tokens deduction) has access
-            self._executor_config = executor_config
-            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-                processor_batched=batched_logits_processor, replicate=False)
-            comm_ranks, device_ids = _get_comm_ranks_device_id()
-            executor_config.parallel_config = tllm.ParallelConfig(
-                participant_ids=comm_ranks, device_ids=device_ids)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
+        def _create_py_executor():
+            args = {}
             assert hasattr(
-                executor_config, "backend"
-            ), "executor_config should be with backend in _create_py_executor"
-            if executor_config.backend == "pytorch":
+                self.llm_args, "backend"
+            ), "llm_args should be with backend in _create_py_executor"
+            if self.llm_args.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
+                args["llm_args"] = self.llm_args
+                args["checkpoint_dir"] = hf_model_dir
+                args["tokenizer"] = tokenizer
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+                    "logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
+                        processor_batched=batched_logits_processor,
+                        replicate=False)
+                comm_ranks, device_ids = _get_comm_ranks_device_id()
+                args["parallel_config"] = tllm.ParallelConfig(
+                    participant_ids=comm_ranks, device_ids=device_ids)
+            elif self.llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.llm_args import \
+                    LlmArgs as ADLlmArgs
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
+                assert isinstance(self.llm_args, ADLlmArgs)
+                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
             else:
                 raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+                    f"Unsupported backend config: {self.llm_args.backend}")
+
+            # Define additional attributes that can be used later, such as in _deduce_max_tokens
+            self.mapping = self.llm_args.parallel_config.to_mapping()
+            self.checkpoint_loader = None
+            if self.llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.config import \
+                    _construct_checkpoint_loader
+                self.checkpoint_loader = _construct_checkpoint_loader(
+                    self.llm_args.backend, self.llm_args.checkpoint_loader,
+                    self.llm_args.checkpoint_format)
+
+            _executor = create_executor(**args)
+            self.max_seq_len = self.llm_args.max_seq_len
+            if _executor.max_seq_len is not None:
+                # max_seq_len might be updated by model engine as in create_py_executor
+                self.max_seq_len = _executor.max_seq_len
+            return _executor

         def _create_engine(executor_config):
             if executor_config is None:
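
Aside (not part of the diff): the new `_create_py_executor` builds a per-backend keyword-argument dict and dispatches to the matching factory, instead of funneling everything through an intermediate `executor_config`. A minimal standalone sketch of that dispatch pattern follows; the `Args` dataclass and stub factories are hypothetical stand-ins, and only the args-dict-plus-factory shape mirrors the code above.

```python
# Sketch of the backend dispatch used above: assemble a kwargs dict per
# backend, then call the chosen factory once. Names below (Args, the stub
# factories) are illustrative only, not TensorRT-LLM APIs.
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class Args:
    backend: str
    checkpoint_dir: str = "/tmp/model"  # hypothetical path


def _create_pytorch_executor(**kwargs: Any) -> str:
    return f"pytorch executor from {kwargs['checkpoint_dir']}"


def _create_autodeploy_executor(**kwargs: Any) -> str:
    return "autodeploy executor"


def create_executor_for(llm_args: Args) -> str:
    args: Dict[str, Any] = {}
    create_executor: Callable[..., str]
    if llm_args.backend == "pytorch":
        create_executor = _create_pytorch_executor
        args["checkpoint_dir"] = llm_args.checkpoint_dir
    elif llm_args.backend == "_autodeploy":
        create_executor = _create_autodeploy_executor
    else:
        raise ValueError(f"Unsupported backend config: {llm_args.backend}")
    return create_executor(**args)


print(create_executor_for(Args(backend="pytorch")))
```
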
@@ -158,8 +173,7 @@ def _create_engine(executor_config):
                                  executor_config)

         self.engine = _create_py_executor(
-            executor_config) if llm_args is not None else _create_engine(
-                executor_config)
+        ) if self.llm_args is not None else _create_engine(executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -182,8 +196,9 @@ def _create_engine(executor_config):
         if engine_config.build_config.max_prompt_embedding_table_size > 0:
             self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(self._executor_config, "backend",
-                   "") == "pytorch" and lora_config is not None:
+        if self.llm_args and getattr(
+                self.llm_args, "backend",
+                "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -465,32 +480,51 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None

         def _deduce_max_tokens(request: GenerationRequest,
-                               executor_config: tllm.ExecutorConfig) -> int:
+                               executor_config: Optional[tllm.ExecutorConfig],
+                               llm_args: Optional[BaseLlmArgs] = None) -> int:
             if request.sampling_params.max_tokens:
                 return request.sampling_params.max_tokens
             # deduce max_tokens when it's not set by user
             query_token_len = len(
                 request.query_token_ids) if request.query_token_ids else 0
-            cp_size = 1 if (not hasattr(executor_config, "mapping")
-                            or executor_config.mapping.cp_size
-                            is None) else executor_config.mapping.cp_size
-            if not hasattr(executor_config, "max_seq_len"):
-                raise RuntimeError(
-                    "max_tokens for sampling is not set and cannot be deduced")
+            cp_size = 1
+            max_seq_len = None
+            if llm_args is not None:
+                # deduce max_tokens by llm args
+                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
+                if hasattr(self,
+                           "mapping") and self.mapping.cp_size is not None:
+                    cp_size = self.mapping.cp_size
+                if not hasattr(self, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced by llm args"
+                    )
+                max_seq_len = self.max_seq_len
+            else:
+                # deduce max_tokens by executor config
+                if hasattr(executor_config, "mapping"
+                           ) and executor_config.mapping.cp_size is not None:
+                    cp_size = executor_config.mapping.cp_size
+                if not hasattr(executor_config, "max_seq_len"):
+                    raise RuntimeError(
+                        "max_tokens for sampling is not set and cannot be deduced"
+                    )
+                max_seq_len = executor_config.max_seq_len
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
-            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens < 0:
                 raise ValueError(
                     f"Deduced max_tokens {default_max_tokens} is less than 0, because"
                     f"prompt length {splited_prompt_len} plus query length {query_token_len} "
-                    f"is larger than max_seq_len {executor_config.max_seq_len}")
+                    f"is larger than max_seq_len {max_seq_len}")
             return default_max_tokens

         try:
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=_deduce_max_tokens(request, self._executor_config),
+                max_tokens=_deduce_max_tokens(request, self._executor_config,
+                                              self.llm_args),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else
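
Aside (not part of the diff): the deduction in `_deduce_max_tokens` is plain arithmetic — split the prompt across context-parallel ranks, then subtract the split prompt length and any query length from `max_seq_len`. A self-contained sketch with made-up numbers; the helper below is illustrative, not repo code.

```python
# Standalone illustration of the max_tokens deduction above. Numbers are
# invented; the formula mirrors the diff:
#   default_max_tokens = max_seq_len - prompt_len // cp_size - query_token_len
def deduce_max_tokens(prompt_len: int,
                      max_seq_len: int,
                      cp_size: int = 1,
                      query_token_len: int = 0) -> int:
    splited_prompt_len = prompt_len // cp_size
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens < 0:
        raise ValueError(
            f"Deduced max_tokens {default_max_tokens} is less than 0: prompt "
            f"length {splited_prompt_len} plus query length {query_token_len} "
            f"exceeds max_seq_len {max_seq_len}")
    return default_max_tokens


# e.g. a 1000-token prompt, cp_size=2, max_seq_len=4096 -> 4096 - 500 - 0 = 3596
assert deduce_max_tokens(prompt_len=1000, max_seq_len=4096, cp_size=2) == 3596
```
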
@@ -616,11 +650,19 @@ def shutdown(self):
             self.engine.shutdown()
             self.engine = None

-            if hasattr(
-                    self._executor_config, "checkpoint_loader"
-            ) and self._executor_config.checkpoint_loader is not None:
-                self._executor_config.checkpoint_loader.cleanup()
-                self._executor_config.checkpoint_loader = None
+            if self.llm_args is not None:
+                assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
+                if (self.llm_args.backend == "pytorch"
+                        and hasattr(self, "checkpoint_loader")
+                        and self.checkpoint_loader is not None):
+                    self.checkpoint_loader.cleanup()
+                    self.checkpoint_loader = None
+            else:
+                if hasattr(
+                        self._executor_config, "checkpoint_loader"
+                ) and self._executor_config.checkpoint_loader is not None:
+                    self._executor_config.checkpoint_loader.cleanup()
+                    self._executor_config.checkpoint_loader = None

         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()
@@ -666,7 +708,7 @@ def worker_main(
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
         tokenizer: Optional[TokenizerBase] = None,
-        llm_args: Optional[TorchLlmArgs] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",