
Commit 92515a7

more VLM work
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent f33134f commit 92515a7

File tree: 11 files changed (+453, −50 lines)


examples/auto_deploy/.gitignore (2 additions, 0 deletions)

@@ -2,3 +2,5 @@
 !.vscode
 benchmark_results.json
 *.png
+# ignore config files that users might put here for debugging
+*.yaml

examples/auto_deploy/build_and_run_ad.py (47 additions, 1 deletion)

@@ -237,10 +237,56 @@ def main(config: Optional[ExperimentConfig] = None):
 
     llm = build_llm_from_config(config)
 
+    # just run config.prompt.queries with our special token sequence including special image tokens
+    # fmt: off
+    input_ids = [[
+        200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080,
+        200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+        200092, 200081, 51212, 1780, 650, 2556, 310, 290, 1472,
+        8392, 341, 1357, 13492, 26, 200008, 200005, 140680, 200006,
+        368
+    ] for _ in range(2)]
+    # fmt: on
+
     # prompt the model and print its output
     ad_logger.info("Running example prompts...")
+
+    # now let's try piping through multimodal data
+
     outs = llm.generate(
-        config.prompt.queries,
+        input_ids,
+        # config.prompt.queries,
         sampling_params=SamplingParams(**config.prompt.sp_kwargs),
     )
     results = {"prompts_and_outputs": print_outputs(outs)}
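
The hard-coded input_ids above spell out a chat prompt whose image placeholders are long runs of a dedicated image token. As an illustrative sketch only (the model id, image URL, question text, and the exact apply_chat_template behavior are assumptions that depend on the checkpoint and the installed transformers version, not code from this commit), a comparable pre-tokenized sequence could be produced with the model's Hugging Face processor instead of hard-coding IDs:

    # hypothetical sketch, not part of the commit
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("<hf-vlm-checkpoint>")  # placeholder model id
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},  # placeholder image
            {"type": "text", "text": "How many cats do you see in the image?"},
        ],
    }]
    # with tokenize=True this should yield token ids comparable to the hard-coded list above
    input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
    outs = llm.generate([input_ids], sampling_params=SamplingParams(**config.prompt.sp_kwargs))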

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (104 additions, 25 deletions)

@@ -37,7 +37,7 @@ class SequenceInfo:
     between arguments that are originally part of the model/graph and arguments that are needed for
     the attention operator when we switch to cached+flattened attention.
 
-    # ORIGINAL MODEL ARGUMENTS #####################################################################
+    ### ORIGINAL MODEL ARGUMENTS ###################################################################
     - input_ids: [id_0, ..., id_{s_total-1}]
       flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
     - position_ids: [pos_0, ..., pos_{s_total-1}]
@@ -47,7 +47,18 @@ class SequenceInfo:
     NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len]
     before we switch to cached+flattened attention.
 
-    # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ##############
+    ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
+    Those are extra arguments that can be provided to the interface and they are stored as follows:
+    - _extra_args: dictionary of extra arguments with currently active values.
+    - _extra_example_inputs: dictionary of example inputs to the extra arguments.
+    - _extra_none_inputs: dictionary of none inputs to the extra arguments.
+      NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
+      cannot represent them via `None` since fx graphs require a fixed input type. Instead,
+      we require a special placeholder tensor to represent the `None` input.
+    - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
+      the extra arguments.
+
+    ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
     - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
       Describes how long each sequence is. For example,
       input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will
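
A hedged usage sketch of the storage scheme described in this docstring (the argument name and tensor shapes are illustrative, not taken from the commit): an optional image input is registered with both an example tensor and a zero-sized placeholder that stands in for `None`, since the exported fx graph requires a fixed tensor type for every input.

    import torch

    # seq_info is assumed to be an existing SequenceInfo instance
    example_pixel_values = torch.zeros(1, 3, 336, 336)  # assumed shape for one example image
    none_pixel_values = torch.zeros(0, 3, 336, 336)     # placeholder representing `None`

    seq_info.add_extra_arg(
        "pixel_values",        # illustrative argument name
        example_pixel_values,  # stored in _extra_example_inputs (and as the active value)
        none_pixel_values,     # stored in _extra_none_inputs
    )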
@@ -128,26 +139,24 @@ def __post_init__(self):
         self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
         self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
         self._uncached_arg_names = ["input_ids", "position_ids"]
+        self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
+
+        # EXTRA TENSOR FIELDS
+        self._extra_args: Dict[str, torch.Tensor] = {}
+        self._extra_example_inputs: Dict[str, torch.Tensor] = {}
+        self._extra_none_inputs: Dict[str, torch.Tensor] = {}
+        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
+        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
 
         # CACHED TENSOR FIELDS (for cached attention backends)
         self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
         self.input_pos = torch.empty_like(self.seq_len)
         self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
         self.pages_per_seq = torch.empty_like(self.seq_len)
         self._cached_arg_names = ["seq_len", "input_pos", "cache_loc", "pages_per_seq"]
-
-        # DYNAMIC SHAPES
-        # --> initialized lazily since Dim is not picklable for multi-processing
-        self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
         self._cached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
         ############################################################################################
 
-        ### EXTRA ARGS #############################################################################
-        self._extra_args: Dict[str, torch.Tensor] = {}
-        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
-        ############################################################################################
-
         # call reset once to initialize the tensors
         self.reset()
 
@@ -345,9 +354,13 @@ def to(self, *args, **kwargs) -> None:
         for k in self._uncached_arg_names + self._cached_arg_names:
             setattr(self, k, getattr(self, k).to(*args, **kwargs))
 
-        for k, v in self._extra_args.items():
-            if isinstance(v, torch.Tensor):
-                self._extra_args[k] = v.to(*args, **kwargs)
+        def _move_dict(d: Dict[str, torch.Tensor]) -> None:
+            for k, v in d.items():
+                d[k] = v.to(*args, **kwargs)
+
+        _move_dict(self._extra_args)
+        _move_dict(self._extra_example_inputs)
+        _move_dict(self._extra_none_inputs)
 
     def reset(self) -> None:
         """Reset the sequence information.
@@ -369,16 +382,63 @@ def set_example_sequence(self) -> None:
         """Set an example sequence useful for testing and export purposes."""
         self.reset()
         bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
-        input_ids = torch.ones(
+        input_ids = torch.ones(  # noqa
             bs,
             seq_len,
             dtype=torch.int,
             device=self.device,
         )
-        self.nest_sequences(input_ids)
+
+        # TODO (lucaslie): seems we have hit a road block using generic example inputs for export
+        # with VLMs. We need to probably switch to having the factory provide an example input that
+        # is then being tokenized inside the factory.
+        # WHY: for VLMs we need to hit these special tokens representing images. No way we can do
+        # that with a generic example input.
+        # fmt: off
+        input_ids2 = [[
+            200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080,
+            200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680,
+            200006, 368
+        ] for _ in range(2)]
+        # fmt: on
+
+        self.nest_sequences(input_ids2, **self._extra_example_inputs)
 
     def set_max_num_tokens_sample(self) -> None:
         """Set an example sequence with max_num_tokens."""
+        # TODO: understand what this implies for extra arguments
         self.reset()
         seq_len = self.max_num_tokens // self.max_batch_size
         input_ids = torch.ones(
@@ -480,6 +540,7 @@ def nest_sequences(
         position_ids: Optional[Sequence[Sequence[int]]] = None,
         input_pos: Optional[Union[torch.Tensor, Sequence[int], int]] = None,
         page_assignments: Optional[Sequence[Sequence[int]]] = None,
+        **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
     ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
 
@@ -488,6 +549,7 @@
             position_ids: List of sequences of position_ids for each token.
             input_pos: Absolute starting position in the cache for each sequence.
             page_assignments: List of sequences of page assignments for each sequence.
+            extra_args: Extra arguments to be stored in the interface.
 
         This i/f will ensure that all sequence info args are updated accordingly.
         """
@@ -542,36 +604,53 @@ def nest_sequences(
         if page_assignments is not None:
             self._assign_pages_per_seq(page_assignments)
 
+        # go through all extra arguments and update them
+        for name, none_input in self._extra_none_inputs.items():
+            if name in extra_args:
+                arg = extra_args.pop(name)
+                if not isinstance(arg, torch.Tensor):
+                    if len(arg) > 1:
+                        arg = torch.cat(arg)
+                    else:
+                        arg = arg[0]
+                self._extra_args[name] = arg.to(self.device)
+            else:
+                self._extra_args[name] = none_input
+
+        assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
+
     def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.sequence_lengths))
 
     def add_extra_arg(
         self,
         name: str,
-        value: torch.Tensor,
+        example_input: torch.Tensor,
+        none_input: torch.Tensor,
         dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
     ) -> None:
         """Add an extra argument to the sequence info object.
 
         Args:
             name: The name of the extra argument.
-            value: Example input value of the extra argument.
+            example_input: Example input value of the extra argument.
+            none_input: None input value of the extra argument.
             dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
 
         Note that the extra argument is expected to be a tensor.
         """
-        self._extra_args[name] = value.to(self.device)
+        assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
+
+        self._extra_args[name] = example_input.to(self.device)
+        self._extra_example_inputs[name] = example_input.to(self.device)
+        self._extra_none_inputs[name] = none_input.to(self.device)
+
         if dynamic_shape_callback is None:
             self._extra_dynamic_shapes_callbacks[name] = lambda: {}
         else:
             self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
 
-    def set_extra_arg(self, name: str, value: torch.Tensor) -> None:
-        """Set an extra argument to the sequence info."""
-        # TODO (lucaslie): assume fixed shape for now
-        self._extra_args[name].copy_(value.to(self.device), non_blocking=True)
-
 
 Constant = Union[int, float, str, None]
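
Taken together, add_extra_arg and the new **extra_args handling in nest_sequences let callers thread per-request multimodal tensors through the sequence interface. A hedged sketch of the call pattern (names and shapes are illustrative; seq_info is assumed to have "pixel_values" registered as in the earlier sketch):

    import torch

    input_ids = [[1, 2, 3, 4], [5, 6, 7]]         # two sequences of token ids
    pixel_values = [torch.zeros(1, 3, 336, 336)]  # a list of tensors is concatenated,
                                                  # or unwrapped if it has a single entry

    # provided extra args are moved to the device and stored as the active values
    seq_info.nest_sequences(input_ids, pixel_values=pixel_values)

    # omitting a registered extra arg falls back to its `none_input` placeholder;
    # passing an unregistered name trips the assert at the end of the loop
    seq_info.nest_sequences(input_ids)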

tensorrt_llm/_torch/auto_deploy/models/factory.py (6 additions, 4 deletions)

@@ -206,14 +206,16 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
             device: The device to load the model on.
         """
 
-    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
+    def get_extra_inputs(
+        self,
+    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, DynamicShapeCallback]]:
         """Return a dictionary of extra inputs for the model.
 
         Returns:
             A dictionary of extra inputs for the model where the key corresponds to the argument
-            name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
-            The dynamic shape callback is a function that returns the dynamic shape of the extra
-            input.
+            name and the value corresponds to a tuple of (example_input, none_input,
+            dynamic_shape_callback). The dynamic shape callback is a function that returns the
+            dynamic shape of the extra input.
         """
         return {}