Commit 34ffebd

wip for vlm
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent ae89163 commit 34ffebd

24 files changed: +1014 −560 lines changed

examples/auto_deploy/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -2,3 +2,5 @@
 !.vscode
 benchmark_results.json
 *.png
+# ignore config files that users might put here for debugging
+*.yaml

examples/auto_deploy/build_and_run_ad.py

Lines changed: 42 additions & 7 deletions
@@ -26,6 +26,9 @@
 # Global torch config, set the torch compile cache to fix up to llama 405B
 torch._dynamo.config.cache_size_limit = 20
 
+# simple string, TRT-LLM style text-only prompt or full-scale HF message template
+PromptInput = Union[str, Dict, List[Dict]]
+
 
 class PromptConfig(BaseModel):
     """Prompt configuration.
@@ -35,13 +38,27 @@ class PromptConfig(BaseModel):
     """
 
     batch_size: int = Field(default=2, description="Number of queries")
-    queries: Union[str, List[str]] = Field(
+    queries: Union[PromptInput, List[PromptInput]] = Field(
         default_factory=lambda: [
+            # OPTION 1: simple text prompt
            "How big is the universe? ",
-            "In simple words and in a single sentence, explain the concept of gravity: ",
-            "How to fix slicing in golf? ",
-            "Where is the capital of Iceland? ",
-        ]
+            # OPTION 2: wrapped text prompt for TRT-LLM
+            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
+            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
+            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
+            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
+            [
+                {
+                    "role": "user",
+                    "content": "How to fix slicing in golf?",
+                }
+            ],
+            # More prompts...
+            {"prompt": "Where is the capital of Iceland? "},
+        ],
+        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
+        "queries via the 'prompt' key and full-scale HF message template called via "
+        "apply_chat_template.",
     )
     sp_kwargs: Dict[str, Any] = Field(
         default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@@ -55,10 +72,28 @@ def model_post_init(self, __context: Any):
         NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
         validators are only run if a value is provided.
         """
-        queries = [self.queries] if isinstance(self.queries, str) else self.queries
+        queries = self.queries if isinstance(self.queries, list) else [self.queries]
         batch_size = self.batch_size
         queries = queries * (batch_size // len(queries) + 1)
-        self.queries = queries[:batch_size]
+        queries = queries[:batch_size]
+
+        # now let's standardize the queries for the LLM api to understand them
+        queries_processed = []
+        for query in queries:
+            if isinstance(query, str):
+                queries_processed.append({"prompt": query})
+            elif isinstance(query, dict):
+                queries_processed.append(query)
+            elif isinstance(query, list):
+                queries_processed.append(
+                    {
+                        "prompt": "Fake prompt. Check out messages field for the HF chat template.",
+                        "messages": query,  # contains the actual HF chat template
+                    }
+                )
+            else:
+                raise ValueError(f"Invalid query type: {type(query)}")
+        self.queries = queries_processed
 
     @field_validator("sp_kwargs", mode="after")
     @classmethod
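The three prompt styles accepted by queries above all normalize to the same dict shape in model_post_init. A minimal sketch of that equivalence, using only the default values shown in the diff:

queries = [
    "How big is the universe? ",  # OPTION 1: plain string
    {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},  # OPTION 2
    [{"role": "user", "content": "How to fix slicing in golf?"}],  # OPTION 3: HF chat messages
]
# After model_post_init, every entry is a dict the LLM API understands:
#   OPTION 1 -> {"prompt": "How big is the universe? "}
#   OPTION 2 -> passed through unchanged
#   OPTION 3 -> {"prompt": "Fake prompt. Check out messages field for the HF chat template.",
#                "messages": [{"role": "user", "content": "How to fix slicing in golf?"}]}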

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 401 additions & 317 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py

Lines changed: 3 additions & 1 deletion
@@ -63,6 +63,7 @@ def scaled_dot_product_attention(
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    enable_gqa: bool = False,
 ) -> torch.Tensor:
     """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.
 
@@ -78,12 +79,13 @@ def scaled_dot_product_attention(
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
+        enable_gqa=enable_gqa,
     )
 
 
 @scaled_dot_product_attention.register_fake
 def scaled_dot_product_attention_fake(
-    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
 ):
     """Fake implementation of scaled_dot_product_attention."""
     return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
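For context, enable_gqa is the flag torch.nn.functional.scaled_dot_product_attention uses to accept grouped-query attention inputs, where key/value carry fewer heads than query. A minimal shape sketch, assuming a PyTorch build that exposes enable_gqa (2.5 or newer):

import torch
import torch.nn.functional as F

# 8 query heads share 2 key/value heads; n_q_heads must be divisible by n_kv_heads.
q = torch.randn(1, 8, 16, 64)  # (batch, n_q_heads, seq_len, head_dim)
k = torch.randn(1, 2, 16, 64)  # (batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(1, 2, 16, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])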

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 2 additions & 13 deletions
@@ -18,7 +18,7 @@
 )
 from ..utils.logger import ad_logger
 from ..utils.node_utils import is_op
-from .interface import ExportPatchRegistry, apply_export_patches
+from .interface import apply_export_patches
 
 try:
     from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
@@ -229,20 +229,9 @@ def torch_export_to_gm(
         patch_list: Optional list of patch names to apply with default settings.
             Cannot be used together with patch_configs.
     """
-    # Validate that both patch_configs and patch_list are not provided simultaneously
-    if patch_configs is not None and patch_list is not None:
-        raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
-
-    # Handle patch configuration
-    if patch_list is not None:
-        # Convert patch_list to patch_configs format
-        patch_configs = {patch_name: {} for patch_name in patch_list}
-    elif patch_configs is None:
-        # Default patch configurations - apply all registered patches with default settings
-        patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
 
     # run export with patches and lifted to meta
-    with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict:
+    with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
         # clean up args, kwargs and move to correct device
         args, kwargs = tree_to((args, kwargs or {}), device="meta")
 

tensorrt_llm/_torch/auto_deploy/export/interface.py

Lines changed: 19 additions & 8 deletions
@@ -5,7 +5,7 @@
 
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, List, Type, Union, final
+from typing import Any, Callable, Dict, List, Optional, Type, Union, final
 
 from pydantic import BaseModel, Field
 
@@ -183,6 +183,8 @@ def inner(patch_cls: Type[BaseExportPatch]) -> Type[BaseExportPatch]:
     @classmethod
     def get(cls, name: str) -> Type[BaseExportPatch]:
         """Get a patch class by name."""
+        if not cls.has(name):
+            raise ValueError(f"Unknown patch: {name}")
         return cls._registry[name]
 
     @classmethod
@@ -212,20 +214,29 @@ def list_patches(cls) -> List[str]:
 
 
 @contextmanager
-def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]):
+def apply_export_patches(
+    patch_configs: Optional[Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]] = None,
+    patch_list: Optional[List[str]] = None,
+):
     """Context manager to apply multiple patches.
 
     Args:
         patch_configs: Dict mapping patch names to their configurations.
     """
-    patches = []
+    # Validate that both patch_configs and patch_list are not provided simultaneously
+    if patch_configs is not None and patch_list is not None:
+        raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
+
+    # Handle patch configuration
+    if patch_list is not None:
+        # Convert patch_list to patch_configs format
+        patch_configs = {patch_name: {} for patch_name in patch_list}
+    elif patch_configs is None:
+        # Default patch configurations - apply all registered patches with default settings
+        patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
 
     # Create patch instances
-    for name, config in patch_configs.items():
-        if not ExportPatchRegistry.has(name):
-            raise ValueError(f"Unknown patch: {name}")
-        patch = ExportPatchRegistry.create_patch(name, config)
-        patches.append(patch)
+    patches = [ExportPatchRegistry.create_patch(k, conf) for k, conf in patch_configs.items()]
 
     # Apply patches using nested context managers
     if not patches:
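As a usage sketch of the refactored entry point, callers can now hand apply_export_patches either a config dict or just a list of patch names (the names below are placeholders, not actual registered patches):

# Apply selected patches by name with default settings:
with apply_export_patches(patch_list=["patch_a", "patch_b"]):
    ...  # run export/tracing under the patches

# Or pass explicit configurations:
with apply_export_patches(patch_configs={"patch_a": {}, "patch_b": {}}):
    ...

# Passing both arguments raises ValueError; passing neither applies every patch
# registered in ExportPatchRegistry with default settings.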

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 87 additions & 7 deletions
@@ -1,19 +1,92 @@
 import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from ...executor.result import CompletionOutput
-from ...inputs.registry import create_input_processor
+from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
 from ...llmapi.llm import RequestOutput, _TorchLLM
-from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
+from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
+from ...sampling_params import SamplingParams
 from .distributed import common as dist_ad
 from .llm_args import LlmArgs
+from .models.factory import ModelFactory
 from .shim.demollm import DemoGenerationExecutor
 
 
+class ADInputProcessor(DefaultInputProcessor):
+    """Input processor for AutoDeploy backend.
+
+    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
+    message chat template system to process multimodal inputs.
+    """
+
+    def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
+        super().__init__(None, None, tokenizer)
+        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
+        self.processor = processor or getattr(tokenizer, "tokenizer", None)
+
+    def __call__(
+        self, inputs: Dict[str, Any], sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        if self.processor is None:
+            raise ValueError("processor is required to tokenize inputs")
+
+        # construct kwargs to reflect DefaultInputProcessor
+        kwargs = {
+            "add_special_tokens": sampling_params.add_special_tokens,
+        }
+        if sampling_params.truncate_prompt_tokens is not None:
+            kwargs = {
+                "truncation": True,
+                "max_length": sampling_params.truncate_prompt_tokens,
+            }
+        # check for messages field and if yes, use the apply_chat_template method
+        if "messages" in inputs:
+            # TODO: we don't really need this but it makes for a good sanity check. Consider
+            # removing this in the future if we need to speed things up.
+            prompt = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            inputs["prompt"] = prompt
+
+            all_args = self.processor.apply_chat_template(
+                inputs["messages"],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding=False,  # there shouldn't be a need for padding ever...
+                return_attention_mask=False,
+                **kwargs,
+            )
+            # TODO: is there a more reliable way to avoid the attention_mask here?
+            all_args.pop("attention_mask", None)
+
+            # TODO: can we avoid the extra tolist() here eventually?
+            token_ids = all_args.pop("input_ids")
+            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
+            if all_args:
+                extra_processed_inputs = {"multimodal_data": all_args}
+            else:
+                extra_processed_inputs = None
+            return token_ids[0].tolist(), extra_processed_inputs
+        else:
+            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
+            return token_ids, None
+
+
 class LLM(_TorchLLM):
     """LLM class is the main class for running an LLM model using AutoDeploy backend."""
 
     args: LlmArgs
+    _factory: ModelFactory
+
+    @property
+    def factory(self) -> ModelFactory:
+        if not getattr(self, "_factory", None):
+            self._factory = self.args.create_factory()
+        return self._factory
 
     def __init__(self, *args, **kwargs):
         kwargs["backend"] = "_autodeploy"
@@ -23,16 +96,18 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
         if self.args.skip_tokenizer_init:
             return None
 
-        factory = self.args.create_factory()
-        return tokenizer_factory(factory.init_tokenizer())
+        return tokenizer_factory(self.factory.init_tokenizer())
 
     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """We don't need to validate args for AutoDeploy backend for now."""
         pass
 
+    def _create_input_processor(self) -> ADInputProcessor:
+        return ADInputProcessor(self.tokenizer, self.factory.init_processor())
+
     def _prefetch_model(self):
         """Prefetch the model for the LLM."""
-        self.args.create_factory().prefetch_checkpoint()
+        self.factory.prefetch_checkpoint()
 
     def _build_model(self):
         """Build the model for the LLM.
@@ -47,6 +122,11 @@ def _build_model(self):
         # _autodeploy backend.
         super()._build_model()
 
+        # now correct input processor
+        assert isinstance(self.input_processor, DefaultInputProcessor)
+        assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
+        self.input_processor = self._create_input_processor()
+
 
 class DemoLLM(LLM):
     """A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -63,7 +143,7 @@ def __init__(self, **kwargs):
         # prefetch model and load tokenizer
         self._prefetch_model()
         self._tokenizer = self._try_load_tokenizer()
-        self.input_processor = create_input_processor(None, self.tokenizer)
+        self.input_processor = self._create_input_processor()
 
         # construct demo executor + engine
         self._executor = DemoGenerationExecutor(
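For background on the messages path in ADInputProcessor.__call__: tokenization is delegated to HF's apply_chat_template, and whatever the processor returns besides input_ids (e.g. pixel_values for images) is forwarded as multimodal_data. A rough standalone sketch of that call, assuming a recent transformers release and an illustrative VLM checkpoint:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # illustrative model
messages = [{"role": "user", "content": [{"type": "text", "text": "How to fix slicing in golf?"}]}]

batch = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
token_ids = batch.pop("input_ids")[0].tolist()  # prompt token ids for the request
extra = {"multimodal_data": batch} if len(batch) > 0 else None  # e.g. pixel_values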

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         description="The path to the model checkpoint or the model name from the Hugging Face Hub."
     )
 
-    model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field(
+    model_factory: str = Field(
         default="AutoModelForCausalLM",
         description="The model factory to use for loading the model.",
     )
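Relaxing model_factory from a fixed Literal to a plain str lets additional factories (such as a VLM-oriented one) be selected through user config without touching this type annotation. A hedged sketch, assuming the checkpoint field preceding it is named model and with illustrative values:

config = AutoDeployConfig(
    model="Qwen/Qwen2.5-VL-3B-Instruct",          # illustrative HF checkpoint
    model_factory="AutoModelForImageTextToText",  # any factory name known to AutoDeploy
)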
