Commit 57e20fe

TODO: WIP: split a i/f changes

Signed-off-by: Lucas Liebenwein <[email protected]>

1 parent ce0b13e

File tree: 20 files changed, +1012 −447 lines

examples/auto_deploy/.gitignore
Lines changed: 2 additions & 0 deletions

@@ -2,3 +2,5 @@
 !.vscode
 benchmark_results.json
 *.png
+# ignore config files that users might put here for debugging
+*.yaml

examples/auto_deploy/build_and_run_ad.py
Lines changed: 42 additions & 7 deletions

@@ -26,6 +26,9 @@
 # Global torch config, set the torch compile cache to fix up to llama 405B
 torch._dynamo.config.cache_size_limit = 20
 
+# simple string, TRT-LLM style text-only prompt, or full-scale HF message template
+PromptInput = Union[str, Dict, List[Dict]]
+
 
 class PromptConfig(BaseModel):
     """Prompt configuration.
@@ -35,13 +38,27 @@ class PromptConfig(BaseModel):
     """
 
     batch_size: int = Field(default=2, description="Number of queries")
-    queries: Union[str, List[str]] = Field(
+    queries: Union[PromptInput, List[PromptInput]] = Field(
         default_factory=lambda: [
+            # OPTION 1: simple text prompt
            "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-        ]
+            # OPTION 2: wrapped text prompt for TRT-LLM
+            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
+            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
+            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
+            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
+            [
+                {
+                    "role": "user",
+                    "content": "How to fix slicing in golf?",
+                }
+            ],
+            # More prompts...
+            {"prompt": "Where is the capital of Iceland? "},
+        ],
+        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
+        "queries via the 'prompt' key and full-scale HF message templates invoked via "
+        "apply_chat_template.",
     )
     sp_kwargs: Dict[str, Any] = Field(
         default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@@ -55,10 +72,28 @@ def model_post_init(self, __context: Any):
         NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
         validators are only run if a value is provided.
         """
-        queries = [self.queries] if isinstance(self.queries, str) else self.queries
+        queries = self.queries if isinstance(self.queries, list) else [self.queries]
         batch_size = self.batch_size
         queries = queries * (batch_size // len(queries) + 1)
-        self.queries = queries[:batch_size]
+        queries = queries[:batch_size]
+
+        # now let's standardize the queries for the LLM API to understand them
+        queries_processed = []
+        for query in queries:
+            if isinstance(query, str):
+                queries_processed.append({"prompt": query})
+            elif isinstance(query, dict):
+                queries_processed.append(query)
+            elif isinstance(query, list):
+                queries_processed.append(
+                    {
+                        "prompt": "Fake prompt. Check out messages field for the HF chat template.",
+                        "messages": query,  # contains the actual HF chat template
+                    }
+                )
+            else:
+                raise ValueError(f"Invalid query type: {type(query)}")
+        self.queries = queries_processed
 
     @field_validator("sp_kwargs", mode="after")
     @classmethod
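For reference, a minimal usage sketch (not part of this commit) of how the three prompt forms above are normalized; it assumes build_and_run_ad.py is on the import path and that all other PromptConfig fields keep their defaults:

# Usage sketch: PromptConfig normalizes every query form to a {"prompt": ...} dict.
from build_and_run_ad import PromptConfig

config = PromptConfig(
    batch_size=3,
    queries=[
        "How big is the universe? ",                                    # plain string
        {"prompt": "Explain gravity in one sentence: "},                # TRT-LLM style dict
        [{"role": "user", "content": "How to fix slicing in golf?"}],  # HF messages
    ],
)

# model_post_init converts each entry; message-list entries keep the original
# chat template under the "messages" key alongside a placeholder "prompt".
for query in config.queries:
    print(sorted(query.keys()))
# -> ['prompt'], ['prompt'], ['messages', 'prompt']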

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Lines changed: 401 additions & 320 deletions

Large diffs are not rendered by default.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py
Lines changed: 3 additions & 1 deletion

@@ -63,6 +63,7 @@ def scaled_dot_product_attention(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> torch.Tensor:
     """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.
 
@@ -78,12 +79,13 @@
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
+        enable_gqa=False,
     )
 
 
 @scaled_dot_product_attention.register_fake
 def scaled_dot_product_attention_fake(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
 ):
     """Fake implementation of scaled_dot_product_attention."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
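The new enable_gqa flag mirrors the same-named argument of torch.nn.functional.scaled_dot_product_attention (available since PyTorch 2.5); note that in this WIP commit the custom op accepts the flag but still forwards a hard-coded enable_gqa=False. A minimal sketch of the upstream flag's behavior, assuming PyTorch >= 2.5:

# With enable_gqa=True, the query may carry more heads than key/value and the
# KV heads are broadcast per group, so the caller no longer has to
# repeat_interleave the KV tensors.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # [batch, num_q_heads, seq_len, head_dim]
k = torch.randn(1, 2, 16, 64)  # num_kv_heads=2 must divide num_q_heads=8
v = torch.randn(1, 2, 16, 64)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])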

tensorrt_llm/_torch/auto_deploy/llm.py
Lines changed: 83 additions & 7 deletions

@@ -1,19 +1,88 @@
 import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from ...executor.result import CompletionOutput
-from ...inputs.registry import create_input_processor
+from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
 from ...llmapi.llm import RequestOutput, _TorchLLM
-from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
+from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
+from ...sampling_params import SamplingParams
 from .distributed import common as dist_ad
 from .llm_args import LlmArgs
+from .models.factory import ModelFactory
 from .shim.demollm import DemoGenerationExecutor
 
 
+class ADInputProcessor(DefaultInputProcessor):
+    """Input processor for AutoDeploy backend.
+
+    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
+    message chat template system to process multimodal inputs.
+    """
+
+    def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
+        super().__init__(None, None, tokenizer)
+        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
+        self.processor = processor or getattr(tokenizer, "tokenizer", None)
+
+    def __call__(
+        self, inputs: Dict[str, Any], sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        if self.processor is None:
+            raise ValueError("processor is required to tokenize inputs")
+
+        # construct kwargs to reflect DefaultInputProcessor
+        kwargs = {
+            "add_special_tokens": sampling_params.add_special_tokens,
+        }
+        if sampling_params.truncate_prompt_tokens is not None:
+            kwargs = {
+                "truncation": True,
+                "max_length": sampling_params.truncate_prompt_tokens,
+            }
+        # check for messages field and if yes, use the apply_chat_template method
+        if "messages" in inputs:
+            # TODO: we don't really need this but it makes for a good sanity check. Consider
+            # removing this in the future if we need to speed things up.
+            prompt = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            inputs["prompt"] = prompt
+
+            all_args = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=False,  # there shouldn't be a need for padding ever...
+                return_attention_mask=False,
+                **kwargs,
+            )
+            # TODO: is there a more reliable way to avoid the attention_mask here?
+            all_args.pop("attention_mask", None)
+
+            # TODO: can we avoid the extra tolist() here eventually?
+            token_ids = all_args.pop("input_ids")
+            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
+            return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None
+        else:
+            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
+            return token_ids, None
+
+
 class LLM(_TorchLLM):
     """LLM class is the main class for running an LLM model using AutoDeploy backend."""
 
     args: LlmArgs
+    _factory: ModelFactory
+
+    @property
+    def factory(self) -> ModelFactory:
+        if not getattr(self, "_factory", None):
+            self._factory = self.args.create_factory()
+        return self._factory
 
     def __init__(self, *args, **kwargs):
         kwargs["backend"] = "_autodeploy"
@@ -23,16 +92,18 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
         if self.args.skip_tokenizer_init:
             return None
 
-        factory = self.args.create_factory()
-        return tokenizer_factory(factory.init_tokenizer())
+        return tokenizer_factory(self.factory.init_tokenizer())
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """We don't need to validate args for AutoDeploy backend for now."""
         pass
 
+    def _create_input_processor(self) -> ADInputProcessor:
+        return ADInputProcessor(self.tokenizer, self.factory.init_processor())
+
     def _prefetch_model(self):
         """Prefetch the model for the LLM."""
-        self.args.create_factory().prefetch_checkpoint()
+        self.factory.prefetch_checkpoint()
 
     def _build_model(self):
         """Build the model for the LLM.
@@ -47,6 +118,11 @@ def _build_model(self):
         # _autodeploy backend.
         super()._build_model()
 
+        # now correct input processor
+        assert isinstance(self.input_processor, DefaultInputProcessor)
+        assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
+        self.input_processor = self._create_input_processor()
+
 
 class DemoLLM(LLM):
     """A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -63,7 +139,7 @@ def __init__(self, **kwargs):
         # prefetch model and load tokenizer
         self._prefetch_model()
         self._tokenizer = self._try_load_tokenizer()
-        self.input_processor = create_input_processor(None, self.tokenizer)
+        self.input_processor = self._create_input_processor()
 
         # construct demo executor + engine
         self._executor = DemoGenerationExecutor(
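A minimal usage sketch (not part of this commit) of the new processor's chat-template path; the HF checkpoint name is an illustrative assumption, and any chat-template-enabled tokenizer would do:

# The tokenizer argument may be None here because the "messages" path only
# needs the HF processor with apply_chat_template.
from transformers import AutoTokenizer

from tensorrt_llm._torch.auto_deploy.llm import ADInputProcessor
from tensorrt_llm.sampling_params import SamplingParams

hf_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # assumed model
processor = ADInputProcessor(tokenizer=None, processor=hf_tokenizer)
params = SamplingParams(max_tokens=32)

# "messages" present -> tokenized via apply_chat_template; any leftover tensor
# args would come back as {"multimodal_data": ...} (None for text-only models).
token_ids, extra = processor(
    {"messages": [{"role": "user", "content": "How big is the universe?"}]},
    params,
)
print(len(token_ids), extra)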

tensorrt_llm/_torch/auto_deploy/models/factory.py
Lines changed: 38 additions & 2 deletions

@@ -2,13 +2,13 @@
 
 import copy
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
 from torch._prims_common import DeviceLikeType
 
-from ..custom_ops.attention_interface import CacheConfig
+from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback
 from ..utils.logger import ad_logger
 
 
@@ -113,6 +113,15 @@ def init_tokenizer(self) -> Optional[Any]:
         """
         return None
 
+    def init_processor(self) -> Optional[Any]:
+        """Initialize the (multi-modal) processor for the model.
+
+        Returns:
+            The initialized processor for the model. If the processor is not available, then this
+            method should return None.
+        """
+        return None
+
     def prefetch_checkpoint(self, force: bool = False):
         """Try or skip prefetching the checkpoint for the model and tokenizer.
 
@@ -206,6 +215,33 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
             device: The device to load the model on.
         """
 
+    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of example inputs for the model.
+
+        This function can be overwritten by a factory when it requires a specific example input
+        in order to run through export.
+
+        Returns:
+            A dictionary of example inputs for the model where the key corresponds to the argument
+            name and the value corresponds to the example input.
+        """
+        return {}
+
+    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
+        """Return a dictionary of extra inputs for the model.
+
+        Returns:
+            A dictionary of extra inputs for the model where the key corresponds to the argument
+            name and the value corresponds to a tuple of (none_input, dynamic_shape_callback):
+            - `none_input`: The none input value of the extra input, indicating the tensor
+              value corresponding to the equivalent of the None input. `None` is not supported
+              as we require the input to be a tensor. Hence, this none_input acts as a
+              placeholder for the None input.
+            - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra
+              input.
+        """
+        return {}
+
 
 class ModelFactoryRegistry:
     _registry: Dict[str, Type[ModelFactory]] = {}
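A hypothetical subclass sketch (not part of this commit) of the new extra-inputs contract; the class name, the pixel_values argument, and the torch.export.Dim-based shape spec are assumptions, and the other abstract ModelFactory methods are omitted for brevity:

# Sketch of a factory override under the contract documented above: each extra
# input maps to (none_input placeholder, dynamic-shape callback).
from typing import Dict, Tuple

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import DynamicShapeCallback
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory


class MyMultimodalFactory(ModelFactory):  # hypothetical subclass
    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
        # A zero-sized tensor stands in for "no images": plain None is not
        # supported, so this acts as the none_input placeholder.
        none_input = torch.zeros(0, 3, 336, 336)

        def dynamic_shape():
            # Assumed shape spec: the first (batch-of-images) dim is dynamic.
            return {0: torch.export.Dim("num_images", max=8)}

        return {"pixel_values": (none_input, dynamic_shape)}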
