 import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from ...executor.result import CompletionOutput
-from ...inputs.registry import create_input_processor
+from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
 from ...llmapi.llm import RequestOutput, _TorchLLM
-from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
+from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
+from ...sampling_params import SamplingParams
 from .distributed import common as dist_ad
 from .llm_args import LlmArgs
+from .models.factory import ModelFactory
 from .shim.demollm import DemoGenerationExecutor
 
 
+class ADInputProcessor(DefaultInputProcessor):
+    """Input processor for AutoDeploy backend.
+
+    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
+    message chat template system to process multimodal inputs.
+    """
+
+    def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
+        super().__init__(None, None, tokenizer)
+        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
+        self.processor = processor or getattr(tokenizer, "tokenizer", None)
+
+    def __call__(
+        self, inputs: Dict[str, Any], sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        if self.processor is None:
+            raise ValueError("processor is required to tokenize inputs")
+
+        # construct kwargs to reflect DefaultInputProcessor
+        kwargs = {
+            "add_special_tokens": sampling_params.add_special_tokens,
+        }
+        if sampling_params.truncate_prompt_tokens is not None:
+            kwargs = {
+                "truncation": True,
+                "max_length": sampling_params.truncate_prompt_tokens,
+            }
+        # check for messages field and if yes, use the apply_chat_template method
+        if "messages" in inputs:
+            # TODO: we don't really need this but it makes for a good sanity check. Consider
+            # removing this in the future if we need to speed things up.
+            prompt = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            inputs["prompt"] = prompt
+
+            all_args = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=False,  # there shouldn't be a need for padding ever...
+                return_attention_mask=False,
+                **kwargs,
+            )
+            # TODO: is there a more reliable way to avoid the attention_mask here?
+            all_args.pop("attention_mask", None)
+
+            # TODO: can we avoid the extra tolist() here eventually?
+            token_ids = all_args.pop("input_ids")
+            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
+            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+        else:
+            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
+            return token_ids, None
+
+
 class LLM(_TorchLLM):
     """LLM class is the main class for running an LLM model using AutoDeploy backend."""
 
     args: LlmArgs
+    _factory: ModelFactory
+
+    @property
+    def factory(self) -> ModelFactory:
+        if not getattr(self, "_factory", None):
+            self._factory = self.args.create_factory()
+        return self._factory
 
     def __init__(self, *args, **kwargs):
         kwargs["backend"] = "_autodeploy"
@@ -23,16 +92,18 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
         if self.args.skip_tokenizer_init:
             return None
 
-        factory = self.args.create_factory()
-        return tokenizer_factory(factory.init_tokenizer())
+        return tokenizer_factory(self.factory.init_tokenizer())
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """We don't need to validate args for AutoDeploy backend for now."""
         pass
 
+    def _create_input_processor(self) -> ADInputProcessor:
+        return ADInputProcessor(self.tokenizer, self.factory.init_processor())
+
     def _prefetch_model(self):
         """Prefetch the model for the LLM."""
-        self.args.create_factory().prefetch_checkpoint()
+        self.factory.prefetch_checkpoint()
 
     def _build_model(self):
         """Build the model for the LLM.
@@ -47,6 +118,11 @@ def _build_model(self):
         # _autodeploy backend.
         super()._build_model()
 
+        # now correct input processor
+        assert isinstance(self.input_processor, DefaultInputProcessor)
+        assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
+        self.input_processor = self._create_input_processor()
+
 
 class DemoLLM(LLM):
     """A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -63,7 +139,7 @@ def __init__(self, **kwargs):
         # prefetch model and load tokenizer
         self._prefetch_model()
         self._tokenizer = self._try_load_tokenizer()
-        self.input_processor = create_input_processor(None, self.tokenizer)
+        self.input_processor = self._create_input_processor()
 
         # construct demo executor + engine
         self._executor = DemoGenerationExecutor(
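For context, a minimal usage sketch of the new ADInputProcessor follows (not part of the diff). The model id, the module import path, and the availability of a chat template on the tokenizer are assumptions for illustration only.

# Hypothetical usage sketch: how ADInputProcessor routes text-only "prompt"
# inputs vs. chat-style "messages" inputs.
from transformers import AutoTokenizer  # assumes an HF tokenizer that ships a chat template

from tensorrt_llm import SamplingParams
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
# assumed location of the module modified in this diff
from tensorrt_llm._torch.auto_deploy.llm import ADInputProcessor

hf_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
proc = ADInputProcessor(tokenizer_factory(hf_tok))  # processor falls back to the wrapped HF tokenizer
params = SamplingParams(max_tokens=32)

# Text-only path: falls back to the wrapped tokenizer's encode() on "prompt".
ids, extra = proc({"prompt": "Hello, world!"}, params)
assert extra is None

# Chat path: "messages" is run through apply_chat_template; any leftover tensors
# (e.g. pixel values for multimodal models) are returned under "multimodal_data".
ids, extra = proc({"messages": [{"role": "user", "content": "Hi there!"}]}, params)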