From f33134f15065e7bfa73a5226e3f28d0506175809 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:36:59 -0700 Subject: [PATCH 01/11] SequenceInfo arguments revisited Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 352 ++++++++++++------ .../_torch/auto_deploy/models/factory.py | 15 +- .../_torch/auto_deploy/shim/ad_executor.py | 18 +- .../_torch/auto_deploy/shim/demollm.py | 26 +- .../transformations/library/kvcache.py | 14 +- .../unit/singlegpu/shim/test_engine.py | 2 - .../transformations/library/test_kv_cache.py | 36 +- 7 files changed, 296 insertions(+), 167 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 13c91652bff..cceeef4d0e9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -10,14 +10,17 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field, fields -from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union import torch from torch._ops import OpOverloadPacket from torch.export import Dim from torch.fx import Node +DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension +DynamicShapeCallback = Callable[[], DynamicShape] + @dataclass class CacheConfig: @@ -88,17 +91,6 @@ class SequenceInfo: # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). 
max_num_tokens: Optional[int] = None - ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP ################# - # input_ids MUST ALWAYS BE THE FIRST FIELD - input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int)) - position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long)) - - seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) - cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int)) - pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - ################################################################################################ - ## PRIVATE FIELDS ############################################################################## _sequence_lengths: List[int] = field(default_factory=list) _num_pages: int = 1 @@ -122,23 +114,39 @@ def __post_init__(self): self.max_batch_size, (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), ) + # sanity check + assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size" + + # keep a list-like object of sequence lengths for simplicity as well + self._sequence_lengths = [0] * self.max_batch_size + + # indicator if extra args are activated that are needed for cached attention backends + self._is_cached_attn = False + + ### TENSOR FIELDS/ARGS FOR UNCACHED AND CACHED ATTENTION ################################### + # UNCACHED TENSOR FIELDS self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) + self._uncached_arg_names = ["input_ids", "position_ids"] + + # CACHED TENSOR FIELDS (for cached attention backends) self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) self.input_pos = torch.empty_like(self.seq_len) self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) self.pages_per_seq = torch.empty_like(self.seq_len) - assert self.num_pages >= self.max_batch_size, ( - "num_pages must be greater than max_batch_size" - ) - # dynamic shape descriptors for tensor args - self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None + self._cached_arg_names = ["seq_len", "input_pos", "cache_loc", "pages_per_seq"] - # keep a list-like object of sequence lengths for simplicity as well - self._sequence_lengths = [0] * self.max_batch_size + # DYNAMIC SHAPES + # --> initialized lazily since Dim is not picklable for multi-processing + self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + self._cached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + ############################################################################################ - # indicator if extra args are activated that are needed for cached attention backends - self._is_cached_attn = False + ### EXTRA ARGS ############################################################################# + self._extra_args: Dict[str, torch.Tensor] = {} + self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} + ############################################################################################ # call reset once to initialize the tensors self.reset() @@ -147,52 +155,78 @@ def __post_init__(self): def device(self) -> torch.device: return 
self.input_pos.device + def _named_args( + self, include_extra_args: bool = True, include_cached_args: bool = True + ) -> Dict[str, torch.Tensor]: + args: Dict[str, torch.Tensor] = {} + for name in self._uncached_arg_names: + args[name] = getattr(self, name) + + if include_extra_args: + args.update(self._extra_args) + + if include_cached_args: + for name in self._cached_arg_names: + args[name] = getattr(self, name) + + return args + @property - def args(self) -> Tuple[torch.Tensor, ...]: - args = [] - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - args.append(val) - if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn: - break - return tuple(args) + def named_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named arguments.""" + return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn) @property - def _num_uncached_attn_args(self) -> int: - """Return the number of original graph arguments expected by the model.""" - return 2 + def named_standard_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named standard arguments.""" + return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn) @property - def _cached_attn_arg_names(self) -> List[str]: - """Return extra arg names for the prepare_metadata op beyond input_ids and position_ids. + def args(self) -> Tuple[torch.Tensor, ...]: + """Return a tuple of arguments.""" + return tuple(self.named_args.values()) - These extra args are needed once we switch from regular attention to inserting cached - attention ops in the model. - """ - return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][ - self._num_uncached_attn_args : - ] + @property + def extra_args_for_prepare_metadata(self) -> Tuple: + """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.""" + return (self.page_size,) @property - def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]: + def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]: """Return dynamic shapes of sequence info tensors. NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing. 
""" - if self._dynamic_shapes is None: - # set up shape for input_ids and position_ids - dynamic_shapes = ({}, {}) + # lazy initialization of dynamic shapes with Dim objects + if self._uncached_dynamic_shapes is None: + # set up shape for uncached args (same for all, i.e., batch_size and seq_len) + bs_seq_len_shape: DynamicShape = {} if self.max_batch_size > 1: - dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len) - # set up shape for position_ids (same as input_ids) - dynamic_shapes[1].update(dynamic_shapes[0]) - # set up shape for extra args - if self._is_cached_attn: - dynamic_shapes += ({},) * len(self._cached_attn_arg_names) - self._dynamic_shapes = dynamic_shapes - return self._dynamic_shapes + bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size) + bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len) + self._uncached_dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names} + + named_dynamic_shapes = self._uncached_dynamic_shapes.copy() + + # add dynamic shapes for extra args + if self._extra_dynamic_shapes is None: + self._extra_dynamic_shapes = { + k: callback() for k, callback in self._extra_dynamic_shapes_callbacks.items() + } + named_dynamic_shapes.update(self._extra_dynamic_shapes) + + # fixed shape for remaining cached attention args + if self._is_cached_attn: + if self._cached_dynamic_shapes is None: + self._cached_dynamic_shapes = {k: {} for k in self._cached_arg_names} + named_dynamic_shapes.update(self._cached_dynamic_shapes) + + return named_dynamic_shapes + + @property + def dynamic_shapes(self) -> Tuple[DynamicShape, ...]: + """Return dynamic shapes of sequence info tensors.""" + return tuple(self.named_dynamic_shapes.values()) @property def num_sequences(self) -> int: @@ -305,26 +339,15 @@ def switch_to_cached_attn_inputs(self) -> List[str]: """ assert not self._is_cached_attn, "Cached+flattened attention already activated" self._is_cached_attn = True - return self._cached_attn_arg_names + return self._cached_arg_names.copy() def to(self, *args, **kwargs) -> None: - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - setattr(self, f.name, val.to(*args, **kwargs)) - - def sync(self, other: "SequenceInfo") -> None: - for f in fields(self): - val = getattr(self, f.name) - val_other = getattr(other, f.name) - if f.name in ["input_ids", "position_ids"]: - setattr(self, f.name, val_other.to(self.device)) - elif f.name == "_sequence_lengths": - self._sequence_lengths = val_other - elif isinstance(val, torch.Tensor): - val[: len(val_other)] = val_other.to(self.device) - else: - assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}." + for k in self._uncached_arg_names + self._cached_arg_names: + setattr(self, k, getattr(self, k).to(*args, **kwargs)) + + for k, v in self._extra_args.items(): + if isinstance(v, torch.Tensor): + self._extra_args[k] = v.to(*args, **kwargs) def reset(self) -> None: """Reset the sequence information. 
@@ -354,12 +377,7 @@ def set_example_sequence(self) -> None:
         )
         self.nest_sequences(input_ids)
 
-        # unflatten if we are not yet using cached+flattened attention
-        if not self._is_cached_attn:
-            self.input_ids = self.input_ids.view(bs, seq_len)
-            self.position_ids = self.position_ids.view(bs, seq_len)
-
-    def _set_max_num_tokens_sample(self) -> None:
+    def set_max_num_tokens_sample(self) -> None:
         """Set an example sequence with max_num_tokens."""
         self.reset()
         seq_len = self.max_num_tokens // self.max_batch_size
@@ -396,67 +414,163 @@ def _update_position_ids(self) -> None:
         else:
             self.position_ids = self.position_ids.view(1, -1)
 
-    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+    def _generate_position_ids(self) -> torch.Tensor:
+        """Generate position ids from the current input_pos and sequence lengths."""
+        position_ids_list = [
+            num
+            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
+            for num in range(in_pos, in_pos + seq_len)
+        ]
+        return torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+
+    def _update_input_pos(self, seq_len: Union[torch.Tensor, List[int], int]) -> None:
+        """Set the starting cache position (``input_pos``) for each sequence.
+
+        The provided values overwrite ``input_pos``; they are not added to the current positions.
+        """
+        if not isinstance(seq_len, torch.Tensor):
+            seq_len = torch.tensor(seq_len, dtype=torch.int)
+        bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
+        self.input_pos[:bs] = seq_len.to(self.device)
+
+    def _assign_pages_per_seq(self, page_assignments: Sequence[Sequence[int]]) -> None:
+        """Set the cache location and pages_per_seq tensors from page assignments."""
+        assert len(page_assignments) == self.num_sequences
+        cache_loc_flat = torch.tensor(
+            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+        )
+        self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
+
+        pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
+        self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
+
+    @staticmethod
+    def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]:
+        return [
+            val
+            for lst in nested_seqs
+            for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
+        ]
+
+    def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
+        """Shape the tensor for the forward pass based on the current attention mode.
+
+        Args:
+            tnsr: The tensor to shape, assumed to be of shape [batch_size*seq_len, ...].
+
+        Returns:
+            The shaped tensor, flattened or unflattened based on the current attention mode.
+        """
+        # if we are still running uncached attention, we also still operate on unflattened
+        # tensors with an explicit [batch_size, seq_len, ...] shape
+        if not self._is_cached_attn:
+            bs = len(self.sequence_lengths)
+            sl = self.sequence_lengths[0]
+            return tnsr.view(bs, sl, *tnsr.shape[2:])
+
+        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
+        if self.is_generate:
+            return tnsr.view(-1, 1, *tnsr.shape[1:])
+        else:
+            return tnsr.view(1, -1, *tnsr.shape[1:])
+
+    def nest_sequences(
+        self,
+        input_ids: Sequence[Sequence[int]],
+        position_ids: Optional[Sequence[Sequence[int]]] = None,
+        input_pos: Optional[Union[torch.Tensor, Sequence[int], int]] = None,
+        page_assignments: Optional[Sequence[Sequence[int]]] = None,
+    ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
- This i/f will also update any relevant sequence information. + Args: + input_ids: List of sequences of input_ids. + position_ids: List of sequences of position_ids for each token. + input_pos: Absolute starting position in the cache for each sequence. + page_assignments: List of sequences of page assignments for each sequence. + + This i/f will ensure that all sequence info args are updated accordingly. """ + # set new sequence lengths seq_lens = [len(ids) for ids in input_ids] self.seq_len.zero_() self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True) + self._sequence_lengths = seq_lens + # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int - # set new input_ids as new tensor from flattened input_ids - ids_list = [ - val - for lst in input_ids - for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) - ] - self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) - # set derivative properties - self._sequence_lengths = seq_lens + # set new input_ids as new tensor from flattened input_ids + self.input_ids = torch.tensor(self._flatten(input_ids), dtype=dtype).to(self.device) + self.input_ids = self._shape_for_forward(self.input_ids) - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:]) + # check for position_ids/input_pos update + assert position_ids is None or input_pos is None, ( + "Cannot provide both position_ids and input_pos" + ) + # check for updated input_pos + if input_pos is not None: + self._update_input_pos(input_pos) + + # check for updated position_ids + if position_ids is None: + # none provided,simple update position_ids based on new sequence lengths and + # current input_pos assuming that input_pos is the starting position id for each + # sequence and position_ids are consecutive. + self.position_ids = self._generate_position_ids() + elif not isinstance(position_ids, torch.Tensor): + # nest position_ids to be consistent with input_ids + seq_lens_p = [len(ids) for ids in position_ids] + assert len(seq_lens_p) == len(seq_lens), f"{seq_lens_p=} != {seq_lens=}" + position_ids_flat = self._flatten(position_ids) + self.position_ids = torch.tensor( + position_ids_flat, dtype=torch.long, device=self.device + ) else: - self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:]) + self.position_ids = position_ids - # update position_ids - self._update_position_ids() + # final shape for position_ids + self.position_ids = self._shape_for_forward(self.position_ids) + + # sanity check on final shape of position_ids and input_ids + assert self.position_ids.shape[:2] == self.input_ids.shape[:2], ( + f"{self.position_ids.shape[:2]=} != {self.input_ids.shape[:2]=}" + ) + + # check for updated page_assignments + if page_assignments is not None: + self._assign_pages_per_seq(page_assignments) def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) - def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: - """Update the starting position for each sequence in the cache. - - If ``reset=True`, ``input_pos`` will be reset to zero before updating. 
+ def add_extra_arg( + self, + name: str, + value: torch.Tensor, + dynamic_shape_callback: Optional[DynamicShapeCallback] = None, + ) -> None: + """Add an extra argument to the sequence info object. + + Args: + name: The name of the extra argument. + value: Example input value of the extra argument. + dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. + + Note that the extra argument is expected to be a tensor. """ - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int) - bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size - - if reset: - self.input_pos[:bs] = seq_len.to(self.device) + self._extra_args[name] = value.to(self.device) + if dynamic_shape_callback is None: + self._extra_dynamic_shapes_callbacks[name] = lambda: {} else: - self.input_pos[:bs] += seq_len.to(self.device) + self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback - # update position_ids - self._update_position_ids() - - def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: - """Set the cache location and pages_per_seq tensors from page assignments.""" - cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int - ) - self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - - pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) - self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) + def set_extra_arg(self, name: str, value: torch.Tensor) -> None: + """Set an extra argument to the sequence info.""" + # TODO (lucaslie): assume fixed shape for now + self._extra_args[name].copy_(value.to(self.device), non_blocking=True) Constant = Union[int, float, str, None] diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 42a30402537..d15a07b2071 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -2,13 +2,13 @@ import copy from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type import torch import torch.nn as nn from torch._prims_common import DeviceLikeType -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback from ..utils.logger import ad_logger @@ -206,6 +206,17 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): device: The device to load the model on. """ + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: + """Return a dictionary of extra inputs for the model. + + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). + The dynamic shape callback is a function that returns the dynamic shape of the extra + input. 
+ """ + return {} + class ModelFactoryRegistry: _registry: Dict[str, Type[ModelFactory]] = {} diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7f759d6796d..353774c755a 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -102,16 +102,24 @@ def build_from_config(cls, ad_config: AutoDeployConfig): max_num_tokens=max_num_tokens, ) + # get factory + factory = ad_config.create_factory() + # update device to contain the current default device if it's in cuda device = torch.device(ad_config.device) if device.type == "cuda" and device.index is None: device = torch.device(f"cuda:{torch.cuda.current_device()}") device = str(device) + # pass in extra arguments defined by the model factory + for name, (example_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): + seq_info.add_extra_arg(name, example_input, dynamic_shape_callback) + + # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, + # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. + # construct inference optimizer - build_and_optimize = InferenceOptimizer( - factory=ad_config.create_factory(), ad_config=ad_config - ) + build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config) # construct engine return cls(build_and_optimize, seq_info, device, max_beam_width) @@ -210,9 +218,7 @@ def _prepare_inputs( # update the sequence info object now si = self.cache_seq_interface.info - si.nest_sequences(input_ids) - si.update_pos(input_pos, reset=True) - si.assign_cache_loc(page_assignments) + si.nest_sequences(input_ids, input_pos=input_pos, page_assignments=page_assignments) return last_logit_only def _compute_logits(self) -> List[torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index c29cb5fbd7e..4f3c35a21eb 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -45,7 +45,7 @@ def stop(self): self.queue.close() self.queue.join_thread() - def _assign_pages(self) -> List[List[int]]: + def _assign_pages(self, total_lens: List[int]) -> List[List[int]]: """A simple heuristic to assign pages based on current sequence info. In a nutshell, we will look at the following information to update the page assignments: @@ -67,7 +67,6 @@ def _assign_pages(self) -> List[List[int]]: unassigned page if needed. 
""" si = self.cache_seq_interface.info - total_lens = [s_l + i_p for s_l, i_p in zip(si.sequence_lengths, si.input_positions)] page_assignments = si.page_assignments free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages} @@ -76,7 +75,7 @@ def _assign_pages(self) -> List[List[int]]: extra_tokens = t_l - len(pages) * si.page_size num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0) updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)]) - si.assign_cache_loc(updated_assignments) + return updated_assignments def generate_tokens_batched( self, requests: List[GenerationRequest] @@ -94,7 +93,11 @@ def generate_tokens_batched( # set up sequence info object sequence_info = self.cache_seq_interface.info sequence_info.reset() - sequence_info.nest_sequences([r.prompt_token_ids for r in requests]) + total_lens = [len(r.prompt_token_ids) for r in requests] + sequence_info.nest_sequences( + input_ids=[r.prompt_token_ids for r in requests], + page_assignments=self._assign_pages(total_lens), + ) # setup objects we want to track for the output batch_size = sequence_info.num_sequences @@ -105,18 +108,21 @@ def generate_tokens_batched( context_logits: Optional[List[torch.Tensor]] = None def _generate_single_step(idx: int): - # assign pages - self._assign_pages() - - # get the logits and then last token logits in each sequence ([b, 1, vocab_size]) logits = self._compute_logits() logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1) token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1] # update sequence info accordingly for next step - sequence_info.update_pos(sequence_info.sequence_lengths) - sequence_info.nest_sequences(token_ids) + input_pos_next = sequence_info.input_positions + seq_lens_current = sequence_info.sequence_lengths + input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)] + total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)] + sequence_info.nest_sequences( + token_ids, + input_pos=input_pos_next, + page_assignments=self._assign_pages(total_lens_next), + ) # nest new tokens and run stop check for b, (new_tokens_b, new_id) in enumerate(zip(new_tokens, token_ids)): diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py index 618c8108f84..6fd1d2b7b7f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py @@ -26,9 +26,6 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None: # loop through nodes to get input, output, and get_attr nodes input_nodes, output_nodes = get_all_input_output_nodes(egm.graph) - # we only expect one input node - assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." - # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. # Later on, we can revisit how to support returning hidden states. assert len(output_nodes) == 1, "Expected exactly one output node!" @@ -73,16 +70,17 @@ def insert_cached_attention( # retrieve input nodes input_nodes, _ = get_all_input_output_nodes(egm.graph) + input_nodes_mapping = {n.target: n for n in input_nodes} + + # filtered and sorted for SequenceInfo arguments (input_ids, position_ids, etc.) 
+ input_nodes_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()] # insert metadata computation and extract each argument as a node get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() with graph.inserting_before(input_nodes[-1].next): ret_node = graph.call_function( get_metadata, - args=( - *input_nodes, - cm.info.page_size, - ), + args=(*input_nodes_from_info, *cm.info.extra_args_for_prepare_metadata), ) metadata_nodes = [ graph.call_function(operator.getitem, args=(ret_node, idx)) @@ -162,7 +160,7 @@ def _get_mem_info_in_mb(): try: # Let's run a forward pass to get the memory usage - cm.info._set_max_num_tokens_sample() + cm.info.set_max_num_tokens_sample() free_mem_pre, _ = _get_mem_info_in_mb() ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py index e9d7acd7dc3..472bc71f1eb 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py @@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i input_ids = [torch.tensor([0, 1, 2], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) assert logits is not None, "Logits are None" @@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int): input_ids = [torch.tensor([1, 2, 3, 4], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index 876eba196cc..4f1015a1268 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -38,17 +38,19 @@ def __init__( self.num_key_value_groups = None @torch.no_grad() - def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Forward pass with input tokens and optional position ids. 
position_ids parameter added to match expected interface in kvcache.py """ - b, s, _ = x.shape + b, s, _ = input_ids.shape # Project input to q, k, v representations - q = self.q_proj(x) # [b, s, n*h_d] - k = self.k_proj(x) # [b, s, n_kv*h_d] - v = self.v_proj(x) # [b, s, n_kv*h_d] + q = self.q_proj(input_ids) # [b, s, n*h_d] + k = self.k_proj(input_ids) # [b, s, n_kv*h_d] + v = self.v_proj(input_ids) # [b, s, n_kv*h_d] # Reshape to [b, s, n, h_d] q = q.view(b, s, self.num_heads, self.head_dim) @@ -185,9 +187,9 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): cm.initialize_caches() # Helper function to call the model with proper sequence nesting - def _call_and_unnest(x): + def _call_and_unnest(x, input_pos): # Use nest_sequences to properly set input_ids and automatically update position_ids - cm.info.nest_sequences(x) + cm.info.nest_sequences(x, input_pos=input_pos) # Use the cm.args as is - it already contains the correct position_ids y = gm(*cm.args) @@ -197,31 +199,25 @@ def _call_and_unnest(x): # Test 1: Regular inference (all tokens at once) cm.info.reset() - y_no_cache = _call_and_unnest(x) + y_no_cache = _call_and_unnest(x, 0) assert all_close(y_model, y_no_cache, atol=atol, rtol=rtol) # Test 2: Autoregressive inference with KV cache cm.info.reset() y_with_cache = torch.empty_like(y_model) - for i in range(x.shape[1]): + for i_p in range(x.shape[1]): # Just pass the current token - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - # Update position for next token - cm.info.update_pos(1) # This automatically updates position_ids too + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 3: Cache continuation after random tokens - cm.info.update_pos(-num_reset_steps) # Rewind position - for i in range(num_random_steps): - _call_and_unnest(torch.rand_like(x[:, :1])) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1] - num_reset_steps + num_random_steps): + _call_and_unnest(torch.rand_like(x[:, :1]), i_p) # Continue inference from previous context cm.info.reset() - cm.info.update_pos(x.shape[1] - num_reset_steps) - for i in range(x.shape[1] - num_reset_steps, x.shape[1]): - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1]): + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 4: Exportability of the transformed model From 92515a7837d497b391e734477dbecb3a6cfe1d22 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 09:31:12 -0700 Subject: [PATCH 02/11] more VLM work Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- examples/auto_deploy/.gitignore | 2 + examples/auto_deploy/build_and_run_ad.py | 48 ++++- .../custom_ops/attention_interface.py | 129 +++++++++--- .../_torch/auto_deploy/models/factory.py | 10 +- tensorrt_llm/_torch/auto_deploy/models/hf.py | 63 +++++- .../auto_deploy/models/patches/llama4.py | 185 ++++++++++++++++++ .../_torch/auto_deploy/shim/ad_executor.py | 28 ++- .../transform/library/export_to_gm.py | 23 ++- tensorrt_llm/executor/worker.py | 2 +- tensorrt_llm/llmapi/llm.py | 8 + .../integration/test_llama4_vlm_export.py | 5 +- 11 files changed, 453 insertions(+), 50 deletions(-) create mode 100644 
tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore index 8e28b4431da..9836a37fc88 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -2,3 +2,5 @@ !.vscode benchmark_results.json *.png +# ignore config files that users might put here for debugging +*.yaml diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 35879834db0..9c1f8d951d6 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -237,10 +237,56 @@ def main(config: Optional[ExperimentConfig] = None): llm = build_llm_from_config(config) + # just run config.prompt.queries with our special token sequence including special image tokens + # fmt: off + input_ids = [[ + 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080, + 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200081, 51212, 1780, 650, 2556, 310, 290, 1472, + 8392, 341, 1357, 13492, 26, 200008, 200005, 140680, 200006, + 368 + ] for _ in range(2)] + # fmt: on + # prompt the model and print its output ad_logger.info("Running example prompts...") + + # now let's try piping 
through multimodal data + outs = llm.generate( - config.prompt.queries, + input_ids, + # config.prompt.queries, sampling_params=SamplingParams(**config.prompt.sp_kwargs), ) results = {"prompts_and_outputs": print_outputs(outs)} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index cceeef4d0e9..0916df03b0e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -37,7 +37,7 @@ class SequenceInfo: between arguments that are originally part of the model/graph and arguments that are needed for the attention operator when we switch to cached+flattened attention. - # ORIGINAL MODEL ARGUMENTS ##################################################################### + ### ORIGINAL MODEL ARGUMENTS ################################################################### - input_ids: [id_0, ..., id_{s_total-1}] flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches. - position_ids: [pos_0, ..., pos_{s_total-1}] @@ -47,7 +47,18 @@ class SequenceInfo: NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len] before we switch to cached+flattened attention. - # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############## + ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ################################################## + Those are extra arguments that can be provided to the interface and they are stored as follows: + - _extra_args: dictionary of extra arguments with currently active values. + - _extra_example_inputs: dictionary of example inputs to the extra arguments. + - _extra_none_inputs: dictionary of none inputs to the extra arguments. + NOTE: we assume that extra arguments are *optional* arguments to the model. However, we + cannot represent them via `None` since fx graphs require a fixed input type. Instead, + we require a special placeholder tensor to represent the `None` input. + - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of + the extra arguments. + + ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i) Describes how long each sequence is. 
For example, input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will @@ -128,6 +139,14 @@ def __post_init__(self): self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) self._uncached_arg_names = ["input_ids", "position_ids"] + self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + + # EXTRA TENSOR FIELDS + self._extra_args: Dict[str, torch.Tensor] = {} + self._extra_example_inputs: Dict[str, torch.Tensor] = {} + self._extra_none_inputs: Dict[str, torch.Tensor] = {} + self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} # CACHED TENSOR FIELDS (for cached attention backends) self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) @@ -135,19 +154,9 @@ def __post_init__(self): self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) self.pages_per_seq = torch.empty_like(self.seq_len) self._cached_arg_names = ["seq_len", "input_pos", "cache_loc", "pages_per_seq"] - - # DYNAMIC SHAPES - # --> initialized lazily since Dim is not picklable for multi-processing - self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None self._cached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None ############################################################################################ - ### EXTRA ARGS ############################################################################# - self._extra_args: Dict[str, torch.Tensor] = {} - self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None - self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} - ############################################################################################ - # call reset once to initialize the tensors self.reset() @@ -345,9 +354,13 @@ def to(self, *args, **kwargs) -> None: for k in self._uncached_arg_names + self._cached_arg_names: setattr(self, k, getattr(self, k).to(*args, **kwargs)) - for k, v in self._extra_args.items(): - if isinstance(v, torch.Tensor): - self._extra_args[k] = v.to(*args, **kwargs) + def _move_dict(d: Dict[str, torch.Tensor]) -> None: + for k, v in d.items(): + d[k] = v.to(*args, **kwargs) + + _move_dict(self._extra_args) + _move_dict(self._extra_example_inputs) + _move_dict(self._extra_none_inputs) def reset(self) -> None: """Reset the sequence information. @@ -369,16 +382,63 @@ def set_example_sequence(self) -> None: """Set an example sequence useful for testing and export purposes.""" self.reset() bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) - input_ids = torch.ones( + input_ids = torch.ones( # noqa bs, seq_len, dtype=torch.int, device=self.device, ) - self.nest_sequences(input_ids) + + # TODO (lucaslie): seems we have hit a road block using generic example inputs for export + # with VLMs. We need to probably switch to having the factory provide an example input that + # is then being tokenized inside the factory. + # WHY: for VLMs we need to hit these special tokens representing images. No way we can do + # that with a generic example input. 
+ # fmt: off + input_ids2 = [[ + 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080, + 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680, + 200006, 368 + ] for _ in range(2)] + # fmt: on + + self.nest_sequences(input_ids2, **self._extra_example_inputs) def set_max_num_tokens_sample(self) -> None: """Set an example sequence with max_num_tokens.""" + # TODO: understand what this implies for extra arguments self.reset() seq_len = self.max_num_tokens // self.max_batch_size input_ids = torch.ones( @@ -480,6 +540,7 @@ def nest_sequences( position_ids: Optional[Sequence[Sequence[int]]] = None, input_pos: Optional[Union[torch.Tensor, Sequence[int], int]] = None, page_assignments: Optional[Sequence[Sequence[int]]] = None, + **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]], ) -> None: """Create and store a flattened list of input_ids from the provided list of sequences. @@ -488,6 +549,7 @@ def nest_sequences( position_ids: List of sequences of position_ids for each token. input_pos: Absolute starting position in the cache for each sequence. page_assignments: List of sequences of page assignments for each sequence. 
+ extra_args: Extra arguments to be stored in the interface. This i/f will ensure that all sequence info args are updated accordingly. """ @@ -542,6 +604,21 @@ def nest_sequences( if page_assignments is not None: self._assign_pages_per_seq(page_assignments) + # go through all extra arguments and update them + for name, none_input in self._extra_none_inputs.items(): + if name in extra_args: + arg = extra_args.pop(name) + if not isinstance(arg, torch.Tensor): + if len(arg) > 1: + arg = torch.cat(arg) + else: + arg = arg[0] + self._extra_args[name] = arg.to(self.device) + else: + self._extra_args[name] = none_input + + assert not extra_args, f"Extra arguments {extra_args.keys()} not found" + def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) @@ -549,29 +626,31 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: def add_extra_arg( self, name: str, - value: torch.Tensor, + example_input: torch.Tensor, + none_input: torch.Tensor, dynamic_shape_callback: Optional[DynamicShapeCallback] = None, ) -> None: """Add an extra argument to the sequence info object. Args: name: The name of the extra argument. - value: Example input value of the extra argument. + example_input: Example input value of the extra argument. + none_input: None input value of the extra argument. dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. Note that the extra argument is expected to be a tensor. """ - self._extra_args[name] = value.to(self.device) + assert name not in self._named_args().keys(), f"Extra argument {name} already exists" + + self._extra_args[name] = example_input.to(self.device) + self._extra_example_inputs[name] = example_input.to(self.device) + self._extra_none_inputs[name] = none_input.to(self.device) + if dynamic_shape_callback is None: self._extra_dynamic_shapes_callbacks[name] = lambda: {} else: self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback - def set_extra_arg(self, name: str, value: torch.Tensor) -> None: - """Set an extra argument to the sequence info.""" - # TODO (lucaslie): assume fixed shape for now - self._extra_args[name].copy_(value.to(self.device), non_blocking=True) - Constant = Union[int, float, str, None] diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index d15a07b2071..ea36f3c9d49 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -206,14 +206,16 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): device: The device to load the model on. """ - def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: + def get_extra_inputs( + self, + ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, DynamicShapeCallback]]: """Return a dictionary of extra inputs for the model. Returns: A dictionary of extra inputs for the model where the key corresponds to the argument - name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). - The dynamic shape callback is a function that returns the dynamic shape of the extra - input. + name and the value corresponds to a tuple of (example_input, none_input, + dynamic_shape_callback). The dynamic shape callback is a function that returns the + dynamic shape of the extra input. 
""" return {} diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index fc37c1e557a..6d2904142eb 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -4,7 +4,7 @@ import os import types from contextlib import contextmanager, nullcontext -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -17,6 +17,7 @@ AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText, + AutoProcessor, AutoTokenizer, PretrainedConfig, ) @@ -27,7 +28,7 @@ WEIGHTS_NAME, ) -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .factory import ModelFactory, ModelFactoryRegistry @@ -366,3 +367,61 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]: @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config + + @property + def autotokenizer_from_pretrained(self): + return AutoTokenizer.from_pretrained + return AutoProcessor.from_pretrained + + @staticmethod + def _simple_forward( + model: nn.Module, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + pixel_values: torch.Tensor, + ): + """A simple forward pass for the model to functionalize the args. + + This follows the standard function signature as expected by factory.py. + """ + return type(model).forward( + model, + input_ids=input_ids, + position_ids=position_ids, + pixel_values=pixel_values, + ) + + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: + """Return a dictionary of extra inputs for the model. + + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). + The dynamic shape callback is a function that returns the dynamic shape of the extra + input. + """ + + def _get_dynamic_shape(): + # return {} + return { + # TODO (lucaslie): how to set default values for dynamic shapes? + 0: Dim("img_batch_size", max=10), + 2: Dim("img_height", min=32, max=512), + 3: Dim("img_width", min=32, max=512), + } + + # TODO (lucaslie): try with both zero tensor and random tensor to activate/deactivate the + # vision branch + pixel_values = torch.ones(4, 3, 336, 336).to(torch.bfloat16) + pixel_values[1] = torch.zeros(3, 336, 336).to(torch.bfloat16) + pixel_values[3] = torch.zeros(3, 336, 336).to(torch.bfloat16) + none_pixel_values = torch.zeros(0, 3, 336, 336).to(torch.bfloat16) + return { + "pixel_values": ( + # TODO: figure out how to automatically setdtype?? --> maybe just comes from the + # InputProcessor as well?? 
+ pixel_values, + none_pixel_values, + _get_dynamic_shape, + ), + } diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py new file mode 100644 index 00000000000..eeb931bec7c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py @@ -0,0 +1,185 @@ +"""A patch to handle vision branch in Llama4ForConditionalGeneration.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import Llama4ForConditionalGeneration +from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast + +from ...export.interface import BaseExportPatch, ExportPatchRegistry + + +# Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651 +# With some modifications that won't affect current execution logic: +# 1. Vison branch managed by torch.cond to enable both text-only and text+image input during runtime. +# 2. Input arg `image_sizes` are set to none +# as the input to torch.cond true/false branch needs fixed argument type during export +# 3. Do not return `image_hidden_states` as it is calculated inside the vision branch +# and invisible to the function outside. +def _forward_with_cond( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, # image_sizes set as None + **lm_kwargs, +) -> Union[Tuple, Llama4CausalLMOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + def _vision_branch(inputs_embeds, pixel_values, input_ids): + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=None, + ) + 
original_inputs_embeds_shape = inputs_embeds.shape + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + final_mask = special_image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) + + final_mask_1d = final_mask[..., 0].reshape(-1) + # num_tokens_to_fill = final_mask_1d.sum() + + # This condition statement breaks torch.export: + # TODO: sanity check on the inputs for this + # if num_tokens_to_fill != projected_vision_flat.size(0): + # raise ValueError( + # f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + # f"but multi_modal_projector returned {projected_vision_flat.size(0)}" + # ) + + expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) + inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat) + + return inputs_embeds.view(original_inputs_embeds_shape) + + def _no_vision_branch(inputs_embeds, pixel_values, input_ids): + return inputs_embeds + + # decide by whether there is any non-zero pixel_values + has_image: torch.Tensor = torch.any(pixel_values != 0) + + inputs_embeds = torch.cond( + has_image, + _vision_branch, + _no_vision_branch, + (inputs_embeds, pixel_values, input_ids), + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + if attention_mask is not None: + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Llama4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=None, # skip outputting this for simplicity + ) + + +@ExportPatchRegistry.register("hf_llama4_vision") +class Llama4VisionPatch(BaseExportPatch): + """Patch for Llama4ForConditionalGeneration to make it compatible with torch.export. + + This patch replaces the forward method of Llama4ForConditionalGeneration with + a version that uses the torch.cond to handle the optional vision branch. 
+ """ + + def _apply_patch(self): + """Apply the Llama4 vision patch.""" + # Store original forward method + self.original_values["Llama4ForConditionalGeneration.forward"] = ( + Llama4ForConditionalGeneration.forward + ) + + # Apply patch by replacing the forward method + Llama4ForConditionalGeneration.forward = _forward_with_cond + + def _revert_patch(self): + """Revert the Llama4 vision patch.""" + # Restore original forward method + Llama4ForConditionalGeneration.forward = self.original_values[ + "Llama4ForConditionalGeneration.forward" + ] diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 353774c755a..156ab5f9c96 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -1,6 +1,6 @@ -from itertools import chain +from collections import defaultdict from types import SimpleNamespace -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch._prims_common import DeviceLikeType @@ -112,8 +112,8 @@ def build_from_config(cls, ad_config: AutoDeployConfig): device = str(device) # pass in extra arguments defined by the model factory - for name, (example_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): - seq_info.add_extra_arg(name, example_input, dynamic_shape_callback) + for name, args in factory.get_extra_inputs().items(): + seq_info.add_extra_arg(name, *args) # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. @@ -184,6 +184,7 @@ def _prepare_inputs( input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) # look at context requests first for request in context_requests: @@ -194,6 +195,15 @@ def _prepare_inputs( request.py_batch_idx = request.seq_slot last_logit_only.append(True) + # get cache indices + cache_indices = kv_cache_manager.get_cache_indices(request) + page_assignments.append(cache_indices) + + # store extra arguments + if request.py_multimodal_data is not None: + for k, v in request.py_multimodal_data.items(): + extra_args[k].append(v) + # look at generate requests next # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: @@ -210,15 +220,17 @@ def _prepare_inputs( # return all logits last_logit_only.append(False) - # extract cache information for all requests - for request in chain(context_requests, gen_requests): # get cache indices cache_indices = kv_cache_manager.get_cache_indices(request) page_assignments.append(cache_indices) # update the sequence info object now - si = self.cache_seq_interface.info - si.nest_sequences(input_ids, input_pos=input_pos, page_assignments=page_assignments) + self.cache_seq_interface.info.nest_sequences( + input_ids, + input_pos=input_pos, + page_assignments=page_assignments, + **extra_args, + ) return last_logit_only def _compute_logits(self) -> List[torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index bbe72650b4e..8249ea04849 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -56,14 +56,21 @@ def _apply( cm.info.set_example_sequence() # export the model 
to a graph module - gm = torch_export_to_gm( - model, - args=cm.args, - dynamic_shapes=cm.dynamic_shapes, - clone=self.config.clone_state_dict, - strict=self.config.strict, - patch_list=self.config.patch_list, - ) + + # TODO: revert. this is just a hack to run export with debugger + # torch.export don't always work together nicely. But I can set a breakpoin here + # and then manually create the gm in the debugger console and then continue + if False: + gm = None + else: + gm = torch_export_to_gm( + model, + args=cm.args, + dynamic_shapes=cm.dynamic_shapes, + clone=self.config.clone_state_dict, + strict=self.config.strict, + patch_list=self.config.patch_list, + ) # this is a clean graph by definition since it was just exported info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index aa793d30ea6..31882b46491 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -80,7 +80,7 @@ def __init__( self) # TODO: make it weakref self._executor_config = executor_config self._is_pytorch_backend = getattr(self._executor_config, "backend", - None) == "pytorch" + None) in ["pytorch", "_autodeploy"] if global_mpi_size() > 1: logger.set_rank(self.global_rank) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 5b440e8b90e..2751b656979 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -398,6 +398,14 @@ def generate_async( f"The inputs must be type str or list of int, but got {type(inputs)}" ) + import torch + pixel_values = torch.zeros(2, 3, 336, 336).to(torch.bfloat16) + pixel_values[0] += 0.0078 + pixel_values[1] -= 0.4961 + multimodal_params = MultimodalParams(multimodal_data={ + "pixel_values": pixel_values, + }) + self._check_arguments( len(prompt_token_ids), len(query_token_ids) if query_token_ids is not None else 0, diff --git a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py b/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py index 596b7ff50dc..aa290491c41 100644 --- a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py +++ b/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py @@ -190,7 +190,10 @@ def test_build_run_llama4_vlm(): "content": [ {"type": "image", "image": img1}, {"type": "image", "image": img2}, - {"type": "text", "text": "What's the difference?"}, + { + "type": "text", + "text": "Describe what you see in the two images and their differences.", + }, ], }, ] From e6e0ca2397a0b0d993cfbd2934058282b11da23c Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 11:42:17 -0700 Subject: [PATCH 03/11] example inputs optionally via factory Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 77 ++++--------------- .../_torch/auto_deploy/models/factory.py | 26 +++++-- tensorrt_llm/_torch/auto_deploy/models/hf.py | 63 +++++++++++++-- .../_torch/auto_deploy/shim/ad_executor.py | 4 +- .../transform/library/export_to_gm.py | 2 +- 5 files changed, 96 insertions(+), 76 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 0916df03b0e..4ac926d233e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -50,7 +50,6 @@ class SequenceInfo: ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ################################################## Those are extra arguments that can be provided to the interface and they are stored as follows: - _extra_args: dictionary of extra arguments with currently active values. - - _extra_example_inputs: dictionary of example inputs to the extra arguments. - _extra_none_inputs: dictionary of none inputs to the extra arguments. NOTE: we assume that extra arguments are *optional* arguments to the model. However, we cannot represent them via `None` since fx graphs require a fixed input type. Instead, @@ -143,7 +142,6 @@ def __post_init__(self): # EXTRA TENSOR FIELDS self._extra_args: Dict[str, torch.Tensor] = {} - self._extra_example_inputs: Dict[str, torch.Tensor] = {} self._extra_none_inputs: Dict[str, torch.Tensor] = {} self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} @@ -359,7 +357,6 @@ def _move_dict(d: Dict[str, torch.Tensor]) -> None: d[k] = v.to(*args, **kwargs) _move_dict(self._extra_args) - _move_dict(self._extra_example_inputs) _move_dict(self._extra_none_inputs) def reset(self) -> None: @@ -378,63 +375,26 @@ def reset(self) -> None: self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) self.pages_per_seq.fill_(1) - def set_example_sequence(self) -> None: + def set_example_sequence(self, input_ids: Optional[torch.Tensor] = None, **kwargs) -> None: """Set an example sequence useful for testing and export purposes.""" self.reset() - bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) - input_ids = torch.ones( # noqa - bs, - seq_len, - dtype=torch.int, - device=self.device, + + # use a best guess default for input_ids if not provided + if input_ids is None: + bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) + input_ids = torch.ones( + bs, + seq_len, + dtype=torch.int, + device=self.device, + ) + + # make sure that all extra arguments are provided + assert self._extra_args.keys() <= kwargs.keys(), ( + f"Missing extra args: {self._extra_args.keys() - kwargs.keys()}" ) - # TODO (lucaslie): seems we have hit a road block using generic example inputs for export - # with VLMs. We need to probably switch to having the factory provide an example input that - # is then being tokenized inside the factory. - # WHY: for VLMs we need to hit these special tokens representing images. No way we can do - # that with a generic example input. 
- # fmt: off - input_ids2 = [[ - 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080, - 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680, - 200006, 368 - ] for _ in range(2)] - # fmt: on - - self.nest_sequences(input_ids2, **self._extra_example_inputs) + self.nest_sequences(input_ids, **kwargs) def set_max_num_tokens_sample(self) -> None: """Set an example sequence with max_num_tokens.""" @@ -626,7 +586,6 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: def add_extra_arg( self, name: str, - example_input: torch.Tensor, none_input: torch.Tensor, dynamic_shape_callback: Optional[DynamicShapeCallback] = None, ) -> None: @@ -634,7 +593,6 @@ def add_extra_arg( Args: name: The name of the extra argument. - example_input: Example input value of the extra argument. none_input: None input value of the extra argument. dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. 
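
To illustrate the reworked extra-argument handling, here is a minimal runtime sketch (constructor values, shapes, and token ids are made up; it assumes the SequenceInfo fields and methods shown in this file): a registered extra argument falls back to its `none_input` placeholder whenever a batch does not provide it, since fx graphs cannot represent an optional `None` input.

    import torch
    from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo

    seq_info = SequenceInfo(max_seq_len=64, max_batch_size=2)  # hypothetical sizes

    # zero-sized tensor acts as the "None" stand-in for this optional model input
    seq_info.add_extra_arg("pixel_values", torch.zeros(0, 3, 336, 336))

    # text-only batch: "pixel_values" silently resolves to the registered placeholder
    seq_info.nest_sequences([[1, 2, 3], [4, 5]])

    # multimodal batch: the provided tensor overrides the placeholder for this step
    seq_info.nest_sequences([[1, 2, 3], [4, 5]], pixel_values=torch.ones(2, 3, 336, 336))
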
@@ -642,8 +600,7 @@
         """
         assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
 
-        self._extra_args[name] = example_input.to(self.device)
-        self._extra_example_inputs[name] = example_input.to(self.device)
+        self._extra_args[name] = none_input.to(self.device)
         self._extra_none_inputs[name] = none_input.to(self.device)
 
         if dynamic_shape_callback is None:
diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py
index ea36f3c9d49..fd76412f6f5 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/factory.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -206,16 +206,30 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
             device: The device to load the model on.
         """
 
-    def get_extra_inputs(
-        self,
-    ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, DynamicShapeCallback]]:
+    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of example inputs for the model.
+
+        This function can be overridden by a factory when it requires specific example inputs
+        in order to run through export.
+
+        Returns:
+            A dictionary of example inputs for the model where the key corresponds to the argument
+            name and the value corresponds to the example input.
+        """
+        return {}
+
+    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]:
         """Return a dictionary of extra inputs for the model.
 
         Returns:
             A dictionary of extra inputs for the model where the key corresponds to the argument
-            name and the value corresponds to a tuple of (example_input, none_input,
-            dynamic_shape_callback). The dynamic shape callback is a function that returns the
-            dynamic shape of the extra input.
+            name and the value corresponds to a tuple of (none_input, dynamic_shape_callback):
+                - `none_input`: A tensor that stands in for `None` for this extra input. `None`
+                    itself is not supported since fx graphs require a fixed tensor type, so this
+                    placeholder is used whenever the input is not provided.
+                - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra
+                    input. 
""" return {} diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 6d2904142eb..1f544e9c28c 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -391,6 +391,55 @@ def _simple_forward( pixel_values=pixel_values, ) + def get_example_inputs(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model.""" + # fmt: off + input_ids = [[ + 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080, + 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, + 200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680, + 200006, 368 + ] for _ in range(2)] + # fmt: on + pixel_values = torch.ones(4, 3, 336, 336) + pixel_values[1] = torch.zeros(3, 336, 336) + pixel_values[3] = torch.zeros(3, 336, 336) + return { + "input_ids": input_ids, + "pixel_values": pixel_values.to(torch.bfloat16), + } + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: """Return a dictionary of extra inputs for the model. @@ -406,21 +455,21 @@ def _get_dynamic_shape(): return { # TODO (lucaslie): how to set default values for dynamic shapes? 
0: Dim("img_batch_size", max=10), - 2: Dim("img_height", min=32, max=512), - 3: Dim("img_width", min=32, max=512), + 2: Dim("img_height", min=32, max=2048), + 3: Dim("img_width", min=32, max=2048), } # TODO (lucaslie): try with both zero tensor and random tensor to activate/deactivate the # vision branch - pixel_values = torch.ones(4, 3, 336, 336).to(torch.bfloat16) - pixel_values[1] = torch.zeros(3, 336, 336).to(torch.bfloat16) - pixel_values[3] = torch.zeros(3, 336, 336).to(torch.bfloat16) - none_pixel_values = torch.zeros(0, 3, 336, 336).to(torch.bfloat16) + # pixel_values = torch.ones(4, 3, 336, 336).to(torch.bfloat16) + # pixel_values[1] = torch.zeros(3, 336, 336).to(torch.bfloat16) + # pixel_values[3] = torch.zeros(3, 336, 336).to(torch.bfloat16) + none_pixel_values = torch.zeros(0, 3, 336, 336) return { "pixel_values": ( # TODO: figure out how to automatically setdtype?? --> maybe just comes from the # InputProcessor as well?? - pixel_values, + # pixel_values, none_pixel_values, _get_dynamic_shape, ), diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 156ab5f9c96..4598056c7bb 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -112,8 +112,8 @@ def build_from_config(cls, ad_config: AutoDeployConfig): device = str(device) # pass in extra arguments defined by the model factory - for name, args in factory.get_extra_inputs().items(): - seq_info.add_extra_arg(name, *args) + for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): + seq_info.add_extra_arg(name, none_input, dynamic_shape_callback) # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. 
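
To make the new two-method contract concrete, the following is a minimal, hypothetical factory sketch (class name, shapes, and Dim bounds are illustrative only and not part of this patch); the commented lines at the end mirror how ADEngine.build_from_config consumes it in the hunk above.

    import torch
    from torch.export import Dim

    from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory

    class MyVLMFactory(ModelFactory):  # hypothetical subclass, for illustration only
        def get_example_inputs(self):
            # realistic tokens + pixels that are only used to trace the model during export
            return {
                "input_ids": torch.ones(2, 8, dtype=torch.int),
                "pixel_values": torch.ones(2, 3, 336, 336),
            }

        def get_extra_inputs(self):
            # a zero-sized tensor stands in for "no image"; the callback defers Dim creation
            none_pixel_values = torch.zeros(0, 3, 336, 336)
            return {"pixel_values": (none_pixel_values, lambda: {0: Dim("img_batch_size", max=10)})}

    # wiring in ADEngine.build_from_config (matches the hunk above):
    # for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items():
    #     seq_info.add_extra_arg(name, none_input, dynamic_shape_callback)
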
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index 8249ea04849..3412f67bb19 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -53,7 +53,7 @@ def _apply( model = gm.get_submodule("factory_model") # set the example sequence - cm.info.set_example_sequence() + cm.info.set_example_sequence(**factory.get_example_inputs()) # export the model to a graph module From b64ad5221ef6f9cab9619973545834d392639393 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 12:24:19 -0700 Subject: [PATCH 04/11] use processor to set example input Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/models/hf.py | 46 ++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 1f544e9c28c..3a6f30fadb3 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -12,6 +12,7 @@ from accelerate.utils import modeling from huggingface_hub import HfApi, snapshot_download from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id +from PIL import Image from torch._prims_common import DeviceLikeType from transformers import ( AutoConfig, @@ -393,6 +394,51 @@ def _simple_forward( def get_example_inputs(self) -> Dict[str, torch.Tensor]: """Return a dictionary of example inputs for the model.""" + + def _prep_seq(text, img1, img2): + return [ + { + "role": "user", + "content": [ + {"type": "image", "image": img1}, + {"type": "image", "image": img2}, + { + "type": "text", + "text": text, + }, + ], + }, + ] + + # Create a batch of conversations (batch_size = 2) + batch_messages = [ + _prep_seq( + "Describe what you see in the two images and their differences.", + Image.new("RGB", (16, 16), color=(128, 128, 128)), + Image.new("RGB", (16, 16), color=(64, 64, 64)), + ), + _prep_seq( + "What are the main differences between these two images?", + Image.new("RGB", (16, 16), color=(255, 0, 0)), + Image.new("RGB", (16, 16), color=(0, 255, 0)), + ), + ] + + processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + inputs = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + ) + + return { + "input_ids": inputs["input_ids"], + "pixel_values": inputs["pixel_values"], + } + # fmt: off input_ids = [[ 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, From 6eff0dae0bf489a3740df66d795dc64f665117db Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 19:16:02 -0700 Subject: [PATCH 05/11] full scale hf chat template Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- examples/auto_deploy/build_and_run_ad.py | 98 ++++++++----------- .../custom_ops/attention_interface.py | 1 + tensorrt_llm/_torch/auto_deploy/llm.py | 93 +++++++++++++++--- .../_torch/auto_deploy/models/factory.py | 9 ++ tensorrt_llm/_torch/auto_deploy/models/hf.py | 93 ++++-------------- .../_torch/auto_deploy/shim/demollm.py | 4 +- tensorrt_llm/llmapi/llm.py | 8 -- 7 files changed, 154 insertions(+), 152 deletions(-) diff --git 
a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 9c1f8d951d6..6ce5f4e692e 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -26,6 +26,9 @@ # Global torch config, set the torch compile cache to fix up to llama 405B torch._dynamo.config.cache_size_limit = 20 +# simple string, TRT-LLM style text-only prompt or full-scale HF message template +PromptInput = Union[str, Dict, List[Dict]] + class PromptConfig(BaseModel): """Prompt configuration. @@ -35,17 +38,27 @@ class PromptConfig(BaseModel): """ batch_size: int = Field(default=2, description="Number of queries") - queries: Union[str, List[str]] = Field( + queries: Union[PromptInput, List[PromptInput]] = Field( default_factory=lambda: [ + # OPTION 1: simple text prompt "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", - "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", - ] + # OPTION 2: wrapped text prompt for TRT-LLM + {"prompt": "In simple words and a single sentence, explain the concept of gravity: "}, + # OPTION 3: a full-scale HF message template (this one works for text-only models!) + # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating + # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal + [ + { + "role": "user", + "content": "How to fix slicing in golf?", + } + ], + # More prompts... + {"prompt": "Where is the capital of Iceland? "}, + ], + description="Example queries to prompt the model with. We support both TRT-LLM text-only " + "queries via the 'prompt' key and full-scale HF message template called via " + "apply_chat_template.", ) sp_kwargs: Dict[str, Any] = Field( default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0}, @@ -59,10 +72,28 @@ def model_post_init(self, __context: Any): NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field validators are only run if a value is provided. """ - queries = [self.queries] if isinstance(self.queries, str) else self.queries + queries = self.queries if isinstance(self.queries, list) else [self.queries] batch_size = self.batch_size queries = queries * (batch_size // len(queries) + 1) - self.queries = queries[:batch_size] + queries = queries[:batch_size] + + # now let's standardize the queries for the LLM api to understand them + queries_processed = [] + for query in queries: + if isinstance(query, str): + queries_processed.append({"prompt": query}) + elif isinstance(query, dict): + queries_processed.append(query) + elif isinstance(query, list): + queries_processed.append( + { + "prompt": "Fake prompt. 
Check out messages field for the HF chat template.", + "messages": query, # contains the actual HF chat template + } + ) + else: + raise ValueError(f"Invalid query type: {type(query)}") + self.queries = queries_processed @field_validator("sp_kwargs", mode="after") @classmethod @@ -237,56 +268,13 @@ def main(config: Optional[ExperimentConfig] = None): llm = build_llm_from_config(config) - # just run config.prompt.queries with our special token sequence including special image tokens - # fmt: off - input_ids = [[ - 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080, - 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200081, 51212, 1780, 650, 2556, 310, 290, 1472, - 8392, 341, 1357, 13492, 26, 200008, 200005, 140680, 200006, - 368 - ] for _ in range(2)] - # fmt: on - # prompt the model and print its output ad_logger.info("Running example prompts...") # now let's try piping through multimodal data outs = llm.generate( - input_ids, - # config.prompt.queries, + config.prompt.queries, sampling_params=SamplingParams(**config.prompt.sp_kwargs), ) results = {"prompts_and_outputs": print_outputs(outs)} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 4ac926d233e..68dfd9da4ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -577,6 +577,7 @@ def nest_sequences( else: self._extra_args[name] = none_input + # TODO (lucaslie): how strict do we wanna be here? Should we just warn/ignore instead? assert not extra_args, f"Extra arguments {extra_args.keys()} not found" def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 999a024fb38..a1fd8c27135 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -1,19 +1,83 @@ import types -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple from ...executor.result import CompletionOutput -from ...inputs.registry import create_input_processor +from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs from ...llmapi.llm import RequestOutput, _TorchLLM -from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory +from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory +from ...sampling_params import SamplingParams from .distributed import common as dist_ad from .llm_args import LlmArgs +from .models.factory import ModelFactory from .shim.demollm import DemoGenerationExecutor +class ADInputProcessor(DefaultInputProcessor): + """Input processor for AutoDeploy backend. + + This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's + message chat template system to process multimodal inputs. + """ + + def __init__(self, tokenizer: TokenizerBase, processor: Optional[Any] = None): + super().__init__(None, None, tokenizer) + # NOTE: HF's tokenizer/processor that has the apply_chat_template method + self.processor = processor or tokenizer.tokenizer + + def __call__( + self, inputs: Dict[str, Any], sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + # construct kwargs to reflect DefaultInputProcessor + kwargs = { + "add_special_tokens": sampling_params.add_special_tokens, + } + if sampling_params.truncate_prompt_tokens is not None: + kwargs = { + "truncation": True, + "max_length": sampling_params.truncate_prompt_tokens, + } + # check for messages field and if yes, use the apply_chat_template method + if "messages" in inputs: + # TODO: we don't really need this but it makes for a good sanity check. Consider + # removing this in the future if we need to speed things up. + prompt = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=False, + ) + inputs["prompt"] = prompt + + all_args = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=False, # there shouldn't be a need for padding ever... + return_attention_mask=False, + **kwargs, + ) + # TODO: is there a more reliable way to avoid the attention_mask here? + all_args.pop("attention_mask", None) + + # TODO: can we avoid the extra tolist() here eventually? + token_ids = all_args.pop("input_ids") + assert token_ids.shape[0] == 1, "messages should be unbatched at this point." 
+ return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None + else: + token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs) + return token_ids, None + + class LLM(_TorchLLM): """LLM class is the main class for running an LLM model using AutoDeploy backend.""" args: LlmArgs + _factory: ModelFactory + + @property + def factory(self) -> ModelFactory: + return self._factory def __init__(self, *args, **kwargs): kwargs["backend"] = "_autodeploy" @@ -23,16 +87,14 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]: if self.args.skip_tokenizer_init: return None - factory = self.args.create_factory() - return tokenizer_factory(factory.init_tokenizer()) + return tokenizer_factory(self._factory.init_tokenizer()) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: """We don't need to validate args for AutoDeploy backend for now.""" pass - def _prefetch_model(self): - """Prefetch the model for the LLM.""" - self.args.create_factory().prefetch_checkpoint() + def _create_input_processor(self) -> ADInputProcessor: + return ADInputProcessor(self.tokenizer, self._factory.init_processor()) def _build_model(self): """Build the model for the LLM. @@ -40,13 +102,21 @@ def _build_model(self): This is a wrapper around the regular build model method that prefetches the model with the factory. """ + # create and store a factory + self._factory = self.args.create_factory() + # prefetch model with factory - self._prefetch_model() + self._factory.prefetch_checkpoint() # NOTE (lucaslie): do regular build model, we bypass the regular LLM CachedModelLoader in # _autodeploy backend. super()._build_model() + # now correct input processor + assert isinstance(self.input_processor, DefaultInputProcessor) + assert isinstance(self.tokenizer, TransformersTokenizer) + self.input_processor = self._create_input_processor() + class DemoLLM(LLM): """A simple LLM class to demo the LLM interface while debugging the e2e workflow. @@ -61,9 +131,10 @@ def __init__(self, **kwargs): self.runtime_context = None # prefetch model and load tokenizer - self._prefetch_model() + self._factory = self.args.create_factory() + self._factory.prefetch_checkpoint() self._tokenizer = self._try_load_tokenizer() - self.input_processor = create_input_processor(None, self.tokenizer) + self.input_processor = self._create_input_processor() # construct demo executor + engine self._executor = DemoGenerationExecutor( diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index fd76412f6f5..11c676bd935 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -113,6 +113,15 @@ def init_tokenizer(self) -> Optional[Any]: """ return None + def init_processor(self) -> Optional[Any]: + """Initialize the (multi-modal) processor for the model. + + Returns: + The initialized processor for the model. If the processor is not available, then this + method should return None. + """ + return None + def prefetch_checkpoint(self, force: bool = False): """Try or skip prefetching the checkpoint for the model and tokenizer. 
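
For reference, a usage sketch of the new processor-driven input path (the model path, prompt text, and image are invented for illustration; it assumes the LLM class and ADInputProcessor defined above): a query that carries a "messages" field is tokenized via apply_chat_template, and any resulting pixel_values travel as multimodal data down to the engine's extra arguments.

    from PIL import Image

    from tensorrt_llm._torch.auto_deploy.llm import LLM
    from tensorrt_llm.sampling_params import SamplingParams

    llm = LLM(model="<path-or-hub-id-of-a-vlm-checkpoint>")  # hypothetical checkpoint

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": Image.new("RGB", (16, 16), color=(128, 128, 128))},
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ]

    # ADInputProcessor tokenizes the messages with apply_chat_template and returns any
    # pixel_values as multimodal_data, which the executor forwards as extra args.
    outs = llm.generate(
        [{"prompt": "see messages", "messages": messages}],
        sampling_params=SamplingParams(max_tokens=32),
    )
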
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 3a6f30fadb3..b0f8b645105 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -110,10 +110,6 @@ def __init__(self, *args, **kwargs): def autoconfig_from_pretrained(self): return AutoConfig.from_pretrained - @property - def autotokenizer_from_pretrained(self): - return AutoTokenizer.from_pretrained - # TODO (@lucaslie): Do we ever want to switch to from_pretrained? @property def automodel_from_config(self): @@ -202,7 +198,7 @@ def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer—either a custom name or the model's default.""" if self.tokenizer is None: return None - return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) @staticmethod def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]: @@ -369,10 +365,18 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]: def automodel_from_config(self): return AutoModelForImageTextToText.from_config - @property - def autotokenizer_from_pretrained(self): - return AutoTokenizer.from_pretrained - return AutoProcessor.from_pretrained + def init_tokenizer(self) -> Optional[Any]: + """Initialize the tokenizer—either a custom name or the model's default.""" + processor = self.init_processor() + if processor is None: + return None + return processor.tokenizer + + def init_processor(self) -> Optional[Any]: + """Initialize the processor for the model.""" + if self.tokenizer is None: + return None + return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) @staticmethod def _simple_forward( @@ -402,12 +406,9 @@ def _prep_seq(text, img1, img2): "content": [ {"type": "image", "image": img1}, {"type": "image", "image": img2}, - { - "type": "text", - "text": text, - }, + {"type": "text", "text": text}, ], - }, + } ] # Create a batch of conversations (batch_size = 2) @@ -432,6 +433,7 @@ def _prep_seq(text, img1, img2): return_dict=True, return_tensors="pt", padding=True, + return_attention_mask=False, ) return { @@ -439,53 +441,6 @@ def _prep_seq(text, img1, img2): "pixel_values": inputs["pixel_values"], } - # fmt: off - input_ids = [[ - 200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 
200092, 200092, 200092, 200092, 200081, 200080, - 200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, - 200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680, - 200006, 368 - ] for _ in range(2)] - # fmt: on - pixel_values = torch.ones(4, 3, 336, 336) - pixel_values[1] = torch.zeros(3, 336, 336) - pixel_values[3] = torch.zeros(3, 336, 336) - return { - "input_ids": input_ids, - "pixel_values": pixel_values.to(torch.bfloat16), - } - def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: """Return a dictionary of extra inputs for the model. @@ -497,7 +452,6 @@ def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback """ def _get_dynamic_shape(): - # return {} return { # TODO (lucaslie): how to set default values for dynamic shapes? 0: Dim("img_batch_size", max=10), @@ -505,18 +459,5 @@ def _get_dynamic_shape(): 3: Dim("img_width", min=32, max=2048), } - # TODO (lucaslie): try with both zero tensor and random tensor to activate/deactivate the - # vision branch - # pixel_values = torch.ones(4, 3, 336, 336).to(torch.bfloat16) - # pixel_values[1] = torch.zeros(3, 336, 336).to(torch.bfloat16) - # pixel_values[3] = torch.zeros(3, 336, 336).to(torch.bfloat16) none_pixel_values = torch.zeros(0, 3, 336, 336) - return { - "pixel_values": ( - # TODO: figure out how to automatically setdtype?? --> maybe just comes from the - # InputProcessor as well?? - # pixel_values, - none_pixel_values, - _get_dynamic_shape, - ), - } + return {"pixel_values": (none_pixel_values, _get_dynamic_shape)} diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index 4f3c35a21eb..7b2dc2c8606 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -90,7 +90,7 @@ def generate_tokens_batched( ) assert sampling_params.best_of == 1, "Best-of is not supported." 
- # set up sequence info object + # set up sequence info object for decode phase sequence_info = self.cache_seq_interface.info sequence_info.reset() total_lens = [len(r.prompt_token_ids) for r in requests] @@ -113,7 +113,7 @@ def _generate_single_step(idx: int): token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1] - # update sequence info accordingly for next step + # update sequence info accordingly for next step (generate phase) input_pos_next = sequence_info.input_positions seq_lens_current = sequence_info.sequence_lengths input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)] diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 2751b656979..5b440e8b90e 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -398,14 +398,6 @@ def generate_async( f"The inputs must be type str or list of int, but got {type(inputs)}" ) - import torch - pixel_values = torch.zeros(2, 3, 336, 336).to(torch.bfloat16) - pixel_values[0] += 0.0078 - pixel_values[1] -= 0.4961 - multimodal_params = MultimodalParams(multimodal_data={ - "pixel_values": pixel_values, - }) - self._check_arguments( len(prompt_token_ids), len(query_token_ids) if query_token_ids is not None else 0, From a325fa1e857250039d19e4ebda6bebc629d1c6eb Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:10:45 -0700 Subject: [PATCH 06/11] demollm support for multi-modal input Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../_torch/auto_deploy/shim/demollm.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index 7b2dc2c8606..4e0e1911390 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -1,8 +1,9 @@ """A demo LLM api to for debugging and testing purposes of e2e workflows.""" import gc +from collections import defaultdict from queue import Empty -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.multiprocessing as mp @@ -93,10 +94,22 @@ def generate_tokens_batched( # set up sequence info object for decode phase sequence_info = self.cache_seq_interface.info sequence_info.reset() - total_lens = [len(r.prompt_token_ids) for r in requests] + + input_ids = [] + total_lens = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) + + for request in requests: + total_lens.append(len(request.prompt_token_ids)) + input_ids.append(request.prompt_token_ids) + if request.multimodal_params is not None: + for k, v in request.multimodal_params.multimodal_data.items(): + extra_args[k].append(v) + sequence_info.nest_sequences( - input_ids=[r.prompt_token_ids for r in requests], + input_ids=input_ids, page_assignments=self._assign_pages(total_lens), + **extra_args, ) # setup objects we want to track for the output From e5cbfcc6e9f0ad140863659ef67afaacdbabb5f2 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:35:04 -0700 Subject: [PATCH 07/11] fix demollm for world_size >=1 Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/shim/demollm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git 
a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index 4e0e1911390..028d0ac5429 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -11,6 +11,7 @@ from ....executor import GenerationExecutor from ....executor.request import GenerationRequest from ....executor.result import CompletionOutput, GenerationResult +from ....inputs.multimodal import MultimodalParams from ....sampling_params import SamplingParams from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch from ..distributed import common as dist_ad @@ -35,8 +36,11 @@ def __init__(self, *args, **kwargs) -> None: self.queue = mp.Queue() @torch.inference_mode() - def __call__(self, requests: GenerationRequest) -> mp.Queue: + def __call__( + self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams] + ) -> mp.Queue: """Generate tokens and put the results in a queue and return the queue.""" + requests.multimodal_params = multimodal_params output = self.generate_tokens_batched([requests])[0] self.queue.put(output) return self.queue @@ -274,6 +278,7 @@ def _run_engine( def _unpack(inputs) -> GenerationRequest: args, kwargs = inputs # unpack the inputs request: GenerationRequest = args[0] + request.multimodal_params: Optional[MultimodalParams] = args[1] return request engine = DemoEngine.build_from_config(**engine_kwargs) @@ -328,8 +333,11 @@ def submit(self, request: GenerationRequest) -> GenerationResult: request.set_id(client_id) # submit request to our demo engine and store results + # NOTE: when returning from this function, the reference request.multimodal_params will + # be cleared immediately. So we pass it in explicitly to maintain a reference even when + # requests get submitted asynchronously. 
result = GenerationResult(request) - result.queue = self.engine_executor(request) + result.queue = self.engine_executor(request, request.multimodal_params) return result From 97d661644662a6cabcf9ffb9bdc87a30d9873bd2 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang Date: Fri, 8 Aug 2025 13:41:10 -0700 Subject: [PATCH 08/11] Qwen2.5 VL Bringup Signed-off-by: Chenghao Zhang --- .../_torch/auto_deploy/custom_ops/__init__.py | 1 + .../_torch/auto_deploy/custom_ops/qwen_ops.py | 273 ++++++++++++++++++ tensorrt_llm/_torch/auto_deploy/models/hf.py | 44 ++- .../auto_deploy/models/patches/qwen2_5_vl.py | 153 ++++++++++ .../auto_deploy/transformations/transform.py | 4 + 5 files changed, 467 insertions(+), 8 deletions(-) create mode 100644 tensorrt_llm/_torch/auto_deploy/custom_ops/qwen_ops.py create mode 100644 tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py index 23a80b94d74..a23265a42f8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py @@ -7,6 +7,7 @@ from .linear import * from .mla import * from .quant import * +from .qwen_ops import * from .rms_norm import * from .torch_attention import * from .torch_backend_attention import * diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/qwen_ops.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/qwen_ops.py new file mode 100644 index 00000000000..2afc11d2768 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/qwen_ops.py @@ -0,0 +1,273 @@ +"""Custom ops required for Qwen2.5-VL vision model export.""" + +from typing import Tuple + +import torch + + +@torch.library.custom_op( + "auto_deploy::qwen_vision_data_dependent_ops", mutates_args=(), device_types=["cuda", "cpu"] +) +def qwen_vision_data_dependent_ops( + grid_thw: torch.Tensor, + hidden_states: torch.Tensor, + spatial_merge_size: int, + window_size: int, + patch_size: int, + spatial_merge_unit: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Custom op that encapsulates all data-dependent operations for Qwen2.5-VL vision model. 
+ + Args: + grid_thw: Grid dimensions [num_images, 3] where each row is [t, h, w] + hidden_states: Hidden states after patch embedding + spatial_merge_size: Spatial merge size for the vision model + window_size: Window size for windowed attention + patch_size: Vision transformer patch size + spatial_merge_unit: Spatial merge unit + + Returns: + processed_hidden_states: Hidden states after window indexing + pos_emb_cos: Cosine part of position embeddings + pos_emb_sin: Sine part of position embeddings + cu_window_seqlens: Cumulative window sequence lengths + cu_seqlens: Cumulative sequence lengths for full attention + reverse_indices: Indices to reverse the window ordering (for final step) + """ + device = grid_thw.device + + # === ROT_POS_EMB CALCULATION === + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // spatial_merge_size, + spatial_merge_size, + w // spatial_merge_size, + spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // spatial_merge_size, + spatial_merge_size, + w // spatial_merge_size, + spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + + # Rotary embedding calculation (matching original implementation) + dim = 40 # head_dim // 2 for Qwen2.5-VL + theta = 10000.0 + # Match original implementation exactly - no explicit device specification + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + # Move to device after calculation to match the flow + inv_freq = inv_freq.to(device) + torch.save(inv_freq, "inv_freq_patched.pt") + seq = torch.arange(max_grid_size, device=device, dtype=torch.float) + freqs = torch.outer(seq, inv_freq) + rotary_pos_emb_full = freqs + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + # === GET_WINDOW_INDEX CALCULATION === + window_index = [] + cu_window_seqlens = [0] + window_index_id = 0 + vit_merger_window_size = window_size // spatial_merge_size // patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h.item() // spatial_merge_size, + grid_w.item() // spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = torch.nn.functional.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + 
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + + window_index = torch.cat(window_index, dim=0) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + # === CU_SEQLENS CALCULATION === + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + # === REVERSE_INDICES CALCULATION === + reverse_indices = torch.argsort(window_index) + + # === ADVANCED INDEXING OPERATIONS === + # Process hidden_states with window indexing + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] # ADVANCED INDEXING + hidden_states = hidden_states.reshape(seq_len, -1) + + # Process rotary_pos_emb with window indexing + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] # ADVANCED INDEXING + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + + pos_emb_cos = emb.cos() + pos_emb_sin = emb.sin() + + return hidden_states, pos_emb_cos, pos_emb_sin, cu_window_seqlens, cu_seqlens, reverse_indices + + +@qwen_vision_data_dependent_ops.register_fake +def qwen_vision_data_dependent_ops_fake( + grid_thw: torch.Tensor, + hidden_states: torch.Tensor, + spatial_merge_size: int, + window_size: int, + patch_size: int, + spatial_merge_unit: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Fake implementation for symbolic tracing. + Returns tensors with correct symbolic shapes but dummy values. 
+ """ + device = grid_thw.device + dtype = grid_thw.dtype + + # Calculate total sequence length from hidden_states + seq_len = hidden_states.shape[0] + + # Fake processed hidden_states: same shape as input + processed_hidden_states = torch.zeros_like(hidden_states) + + # Fake position embeddings: separate cos and sin tensors with doubled embedding dim + emb_dim = 80 # head_dim for Qwen2.5-VL + pos_emb_cos = torch.zeros(seq_len, emb_dim, device=device, dtype=hidden_states.dtype) + pos_emb_sin = torch.zeros(seq_len, emb_dim, device=device, dtype=hidden_states.dtype) + + # Fake cu_window_seqlens: varies based on windowing, but approximate + num_windows = (seq_len // spatial_merge_unit // 16) + 2 # rough estimate + cu_window_seqlens = torch.arange(num_windows + 1, device=device, dtype=dtype) * 16 + + # Fake cu_seqlens: [num_images + 1] + cu_seqlens = torch.cumsum(grid_thw[:, 1] * grid_thw[:, 2] * grid_thw[:, 0], dim=0) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = cu_seqlens.to(dtype=dtype) + + # Fake reverse_indices: identity mapping for fake implementation + reverse_indices = torch.arange(seq_len, device=device, dtype=torch.long) + + return ( + processed_hidden_states, + pos_emb_cos, + pos_emb_sin, + cu_window_seqlens, + cu_seqlens, + reverse_indices, + ) + + +@torch.library.custom_op( + "auto_deploy::qwen_prepare_attention_mask", mutates_args=(), device_types=["cuda", "cpu"] +) +def qwen_prepare_attention_mask( + hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, attn_implementation: str +) -> torch.Tensor: + """ + Custom op for _prepare_attention_mask to handle data-dependent operations. + + Based on Qwen2_5_VisionTransformerPretrainedModel._prepare_attention_mask + + Returns a special marker tensor (empty tensor) when attention_mask should be None. + """ + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + if attn_implementation == "flash_attention_2": + # Return empty tensor as marker for None + return torch.empty(0, device=hidden_states.device, dtype=hidden_states.dtype) + + seq_length = hidden_states.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(hidden_states.dtype).min, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i] + ] = 0 + return attention_mask + + +@qwen_prepare_attention_mask.register_fake +def qwen_prepare_attention_mask_fake( + hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, attn_implementation: str +) -> torch.Tensor: + """Fake implementation for symbolic tracing.""" + if attn_implementation == "flash_attention_2": + # Return empty tensor as marker for None + return torch.empty(0, device=hidden_states.device, dtype=hidden_states.dtype) + + seq_length = hidden_states.shape[0] + attention_mask = torch.zeros( + [1, 1, seq_length, seq_length], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + return attention_mask + + +@torch.library.custom_op( + "auto_deploy::qwen_reverse_indexing", mutates_args=(), device_types=["cuda", "cpu"] +) +def qwen_reverse_indexing( + hidden_states: torch.Tensor, reverse_indices: torch.Tensor +) -> torch.Tensor: + """ + Custom op for reverse indexing operation to handle advanced indexing with symbolic indices. 
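+
+    Functionally this is just hidden_states[reverse_indices, :]; wrapping the gather in a custom
+    op keeps the data-dependent indices opaque to torch.export's symbolic tracing.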
+ """ + return hidden_states[reverse_indices, :] + + +@qwen_reverse_indexing.register_fake +def qwen_reverse_indexing_fake( + hidden_states: torch.Tensor, reverse_indices: torch.Tensor +) -> torch.Tensor: + """Fake implementation for symbolic tracing.""" + return torch.zeros_like(hidden_states) diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index b0f8b645105..96737d46f6e 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -384,16 +384,21 @@ def _simple_forward( input_ids: torch.Tensor, position_ids: torch.Tensor, pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor = None, + attention_mask: torch.Tensor = None, ): """A simple forward pass for the model to functionalize the args. This follows the standard function signature as expected by factory.py. """ + attention_mask = torch.ones_like(input_ids) return type(model).forward( model, input_ids=input_ids, position_ids=position_ids, pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + attention_mask=attention_mask, ) def get_example_inputs(self) -> Dict[str, torch.Tensor]: @@ -438,7 +443,12 @@ def _prep_seq(text, img1, img2): return { "input_ids": inputs["input_ids"], - "pixel_values": inputs["pixel_values"], + "pixel_values": torch.zeros( + 14308, 1176 + ), # Example shape for export (will be dynamic at runtime) + "image_grid_thw": torch.tensor( + [[1, 98, 146]], dtype=torch.long + ), # Example grid for export (will be dynamic at runtime) } def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: @@ -451,13 +461,31 @@ def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback input. """ - def _get_dynamic_shape(): + # Use dynamic shapes to handle variable-sized multimodal inputs + # This allows different image sizes and batch sizes to work properly + + # Example tensors with reasonable default sizes + pixel_values_tensor = torch.zeros(14308, 1176) # Example size for export + image_grid_thw_tensor = torch.tensor([[1, 98, 146]], dtype=torch.long) # Example grid size + + # Define dynamic shapes based on PyTorch's constraint analysis + # PyTorch has determined the exact mathematical constraints from the model + + def pixel_values_dynamic_shape(): + # PyTorch constraint analysis shows num_patches must be divisible by 4 + # Use the suggested approach: num_patches = 4 * _num_patches + _num_patches = Dim("_num_patches", min=1, max=25000) + num_patches = 4 * _num_patches return { - # TODO (lucaslie): how to set default values for dynamic shapes? 
- 0: Dim("img_batch_size", max=10), - 2: Dim("img_height", min=32, max=2048), - 3: Dim("img_width", min=32, max=2048), + 0: num_patches, # Number of image patches (must be multiple of 4) } - none_pixel_values = torch.zeros(0, 3, 336, 336) - return {"pixel_values": (none_pixel_values, _get_dynamic_shape)} + def image_grid_thw_dynamic_shape(): + # TODO: add dynamic shape for image_grid_thw, + # the pytorch returned error when I change the first batch dim to dynamic + return {} + + return { + "pixel_values": (pixel_values_tensor, pixel_values_dynamic_shape), + "image_grid_thw": (image_grid_thw_tensor, image_grid_thw_dynamic_shape), + } diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py new file mode 100644 index 00000000000..597d913ccca --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py @@ -0,0 +1,153 @@ +"""Patches for Qwen2.5-VL model to make it compatible with torch.export.""" + +import torch + +from ...export.interface import BaseExportPatch, ExportPatchRegistry + + +def _patched_vision_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs): + """ + Patched forward method for Qwen2.5-VL vision transformer. + + This patch moves ALL data-dependent operations into custom ops to make + the method fully compatible with torch.export's symbolic tracing. + """ + # Original patch_embed processing (no data dependencies) + hidden_states = self.patch_embed(hidden_states) + + # Use custom op to handle ALL data-dependent operations including advanced indexing + hidden_states, pos_emb_cos, pos_emb_sin, cu_window_seqlens, cu_seqlens, reverse_indices = ( + torch.ops.auto_deploy.qwen_vision_data_dependent_ops( + grid_thw, + hidden_states, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + ) + + # Create position_embeddings tuple from separate tensors + position_embeddings = (pos_emb_cos, pos_emb_sin) + + # Process through attention blocks (using custom op for attention mask) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + + # Use custom op for _prepare_attention_mask (handles data-dependent operations) + attention_mask_tensor = torch.ops.auto_deploy.qwen_prepare_attention_mask( + hidden_states, cu_seqlens_now, self.config._attn_implementation + ) + + # Convert empty tensor marker back to None + attention_mask = None if attention_mask_tensor.numel() == 0 else attention_mask_tensor + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, + ) + # Final merger (no data dependencies) + hidden_states = self.merger(hidden_states) + + # Use custom op for reverse indexing (handles data-dependent operations) + hidden_states = torch.ops.auto_deploy.qwen_reverse_indexing(hidden_states, reverse_indices) + + return hidden_states + + +def _patched_rope_forward(self, x, position_ids): + """ + Patched forward method for Qwen2_5_VLRotaryEmbedding to handle 'meta' device during torch.export. + + This patch fixes the device_type issue when using torch.export where device becomes 'meta'. + """ + # In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
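+    # Shape sketch (illustrative): inv_freq is (head_dim/2,), so inv_freq_expanded below becomes
+    # (3, bs, head_dim/2, 1) and position_ids_expanded is (3, bs, 1, seq); their matmul, after the
+    # transpose, yields freqs of shape (3, bs, seq, head_dim/2), i.e. one rotary table per t/h/w axis.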
+ inv_freq_expanded = ( + self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + # Fix device type handling for torch.export (where device can be 'meta') + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type not in ["mps", "meta"] + else "cpu" + ) + + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def _patched_create_causal_mask(**kwargs): + """ + Patched create_causal_mask function that returns None to avoid issues during model export/execution. + + This is a temporary workaround for compatibility issues with create_causal_mask during + TensorRT-LLM model execution. + """ + return None + + +@ExportPatchRegistry.register("qwen2_5_vl_vision") +class Qwen2_5_VLVisionPatch(BaseExportPatch): + """ + Patch for Qwen2.5-VL model to make it compatible with torch.export. + + This patch applies fixes for: + 1. Vision transformer forward method (using custom ops for data-dependent operations) + 2. Rotary embedding forward method (handling 'meta' device during export) + 3. create_causal_mask function (returns None to avoid execution issues) + """ + + def _apply_patch(self): + import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl as qwen_modeling + from transformers.masking_utils import create_causal_mask + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLRotaryEmbedding, + ) + + # Store original methods + self.original_values["vision_forward"] = Qwen2_5_VisionTransformerPretrainedModel.forward + self.original_values["rope_forward"] = Qwen2_5_VLRotaryEmbedding.forward + self.original_values["create_causal_mask"] = create_causal_mask + self.original_values["qwen_create_causal_mask"] = qwen_modeling.create_causal_mask + + # Apply patches + Qwen2_5_VisionTransformerPretrainedModel.forward = _patched_vision_forward + Qwen2_5_VLRotaryEmbedding.forward = _patched_rope_forward + + # Patch the create_causal_mask function in both the masking_utils module + # and the locally imported reference in the qwen2_5_vl modeling module + import transformers.masking_utils + + transformers.masking_utils.create_causal_mask = _patched_create_causal_mask + qwen_modeling.create_causal_mask = _patched_create_causal_mask + + def _revert_patch(self): + import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl as qwen_modeling + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLRotaryEmbedding, + ) + + # Restore original methods + Qwen2_5_VisionTransformerPretrainedModel.forward = self.original_values["vision_forward"] + Qwen2_5_VLRotaryEmbedding.forward = self.original_values["rope_forward"] + + # Restore original create_causal_mask function in both locations + import transformers.masking_utils + + transformers.masking_utils.create_causal_mask = self.original_values["create_causal_mask"] + qwen_modeling.create_causal_mask = self.original_values["qwen_create_causal_mask"] diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index 
ed247753f83..18b8c4679ac 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -68,6 +68,8 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module: ############################################################################################ # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION ############################################################################################ + # Create both forward and backward visualizations of the initial graph + # visualize_model(egm, filename="initial_graph.svg", max_nodes=1000) # Match MoE pattern match_moe_pattern(egm) @@ -205,6 +207,8 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module: dynamic_shapes=cm.dynamic_shapes, compiler_kwargs=compiler_kwargs, ) + # visualize_model(egm_compiled, max_nodes=1000) + cm.info.reset() torch.cuda.empty_cache() From 42597a3d67d5f5e16ece195fcb27fb0407c3dc4d Mon Sep 17 00:00:00 2001 From: Chenghao Zhang Date: Fri, 8 Aug 2025 16:24:03 -0700 Subject: [PATCH 09/11] Hack the hf.py to use the new keys for weights Signed-off-by: Chenghao Zhang --- tensorrt_llm/_torch/auto_deploy/models/hf.py | 62 ++++++++++++++++---- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 96737d46f6e..b1791ac1621 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from accelerate import init_empty_weights, load_checkpoint_in_model +from accelerate import init_empty_weights from accelerate.utils import modeling from huggingface_hub import HfApi, snapshot_download from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id @@ -310,16 +310,58 @@ def _prefetch_checkpoint(self, model_name_or_path: str, skip_prefetch_weights: b def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): """Load the checkpoint into the model.""" # identify the most relevant checkpoint file - ckpt_file = self._get_checkpoint_file(self.model) + index_json_path = self._get_checkpoint_file(self.model) + checkpoint_dir = os.path.dirname(index_json_path) + + # 2. Manually load and process the index file to map shards. + with open(index_json_path, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + from collections import defaultdict + + from safetensors import safe_open + + # Invert the map to group tensor names by the file they are in. + shards = defaultdict(list) + for tensor_name, shard_file in weight_map.items(): + shards[shard_file].append(tensor_name) + + # 3. Load all tensors from their respective shards into one dictionary. + state_dict = {} + for shard_file, tensor_names in shards.items(): + shard_path = os.path.join(checkpoint_dir, shard_file) + # Use safe_open to efficiently load only the needed tensors from each shard. + with safe_open(shard_path, framework="pt", device="cpu") as f: + for tensor_name in tensor_names: + state_dict[tensor_name] = f.get_tensor(tensor_name) + + # 4. Perform the key remapping on the now fully assembled state_dict. 
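+        # Illustrative example (hypothetical keys): "visual.blocks.0.attn.qkv.weight" becomes
+        # "model.visual.blocks.0.attn.qkv.weight", and "model.layers.0.self_attn.q_proj.weight"
+        # becomes "model.language_model.layers.0.self_attn.q_proj.weight"; keys already under
+        # "model.visual" or "model.language_model" are left untouched by the negative lookahead.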
+ conversion_mapping = { + "^visual": "model.visual", + r"^model(?!\.(language_model|visual))": "model.language_model", + } + import re + + keys_to_process = list(state_dict.keys()) + for key in keys_to_process: + new_key = key + for pattern, replacement in conversion_mapping.items(): + new_key = re.sub(pattern, replacement, new_key) + + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + # reuse the load checkpoint utility from accelerate - with hf_load_state_dict_with_device(device): - # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic. - # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict, - # which collects local model params, syncs weights from checkpoint, and applies them via - # model.load_state_dict. - # This sync step can interfere with load_hooks by mixing raw checkpoint weights and - # model-transformed weights,leading to unexpected key mismatches or format issues. - load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False) + # with hf_load_state_dict_with_device(device): + # # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic. + # # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict, + # # which collects local model params, syncs weights from checkpoint, and applies them via + # # model.load_state_dict. + # # This sync step can interfere with load_hooks by mixing raw checkpoint weights and + # # model-transformed weights,leading to unexpected key mismatches or format issues. + # load_checkpoint_in_model(model, checkpoint=state_dict, full_state_dict=False) + model.load_state_dict(state_dict) def _load_quantization_config(self, fetched_dir: str): """Load the quantization config from the model directory if not done already.""" From 964981db8659a9cb3da8478dbb70cf816efd5fe1 Mon Sep 17 00:00:00 2001 From: Chenghao Zhang Date: Mon, 11 Aug 2025 14:10:08 -0700 Subject: [PATCH 10/11] Support dynamic shapes for grid_thw Signed-off-by: Chenghao Zhang --- tensorrt_llm/_torch/auto_deploy/models/hf.py | 8 +- .../auto_deploy/models/patches/qwen2_5_vl.py | 124 ++++++++++++++++++ 2 files changed, 128 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index b1791ac1621..8f4a958f734 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -489,7 +489,7 @@ def _prep_seq(text, img1, img2): 14308, 1176 ), # Example shape for export (will be dynamic at runtime) "image_grid_thw": torch.tensor( - [[1, 98, 146]], dtype=torch.long + [[1, 98, 146], [1, 98, 146]], dtype=torch.long ), # Example grid for export (will be dynamic at runtime) } @@ -523,9 +523,9 @@ def pixel_values_dynamic_shape(): } def image_grid_thw_dynamic_shape(): - # TODO: add dynamic shape for image_grid_thw, - # the pytorch returned error when I change the first batch dim to dynamic - return {} + return { + 0: Dim("_num_images", min=1, max=10), + } return { "pixel_values": (pixel_values_tensor, pixel_values_dynamic_shape), diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py index 597d913ccca..c09abeeabbd 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen2_5_vl.py @@ -1,5 +1,7 @@ """Patches for Qwen2.5-VL model to make it compatible with torch.export.""" +from typing import Optional + import torch 
from ...export.interface import BaseExportPatch, ExportPatchRegistry @@ -99,6 +101,118 @@ def _patched_create_causal_mask(**kwargs): return None +def _patched_get_image_features_flat( + self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None +): + """ + WAR: Return flat image features directly; avoid Python list/split that specializes num_images. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + +def _patched_model_forward_export_war( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, +): + """ + WAR forward: (a) do not call get_rope_index; synthesize tensor position_ids if None, + (b) consume flat image features directly (no split/cat). + """ + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModelOutputWithPast + from transformers.utils import is_torchdynamo_compiling + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match tokens:{n_image_tokens}, features:{n_image_features}" + ) + mask = input_ids == self.config.image_token_id + image_mask = mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + # Keep videos flat as well + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum() + n_video_features = video_embeds.shape[0] + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match tokens:{n_video_tokens}, features:{n_video_features}" + ) + mask = input_ids == self.config.video_token_id + video_mask = mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + 
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Bypass get_rope_index: synthesize simple tensor position_ids if None + if position_ids is None: + bsz, seqlen = inputs_embeds.shape[0], inputs_embeds.shape[1] + base = torch.arange(seqlen, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + position_ids = base.view(1, 1, -1).expand(3, bsz, -1) + # keep rope_deltas as zeros per batch + self.rope_deltas = torch.zeros( + (bsz, 1), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + @ExportPatchRegistry.register("qwen2_5_vl_vision") class Qwen2_5_VLVisionPatch(BaseExportPatch): """ @@ -108,6 +222,7 @@ class Qwen2_5_VLVisionPatch(BaseExportPatch): 1. Vision transformer forward method (using custom ops for data-dependent operations) 2. Rotary embedding forward method (handling 'meta' device during export) 3. create_causal_mask function (returns None to avoid execution issues) + 4. WAR for multimodal export: flat image features and synthetic tensor position_ids """ def _apply_patch(self): @@ -115,6 +230,7 @@ def _apply_patch(self): from transformers.masking_utils import create_causal_mask from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, Qwen2_5_VLRotaryEmbedding, ) @@ -123,10 +239,15 @@ def _apply_patch(self): self.original_values["rope_forward"] = Qwen2_5_VLRotaryEmbedding.forward self.original_values["create_causal_mask"] = create_causal_mask self.original_values["qwen_create_causal_mask"] = qwen_modeling.create_causal_mask + # Store originals for model-level methods + self.original_values["model_get_image_features"] = Qwen2_5_VLModel.get_image_features + self.original_values["model_forward"] = Qwen2_5_VLModel.forward # Apply patches Qwen2_5_VisionTransformerPretrainedModel.forward = _patched_vision_forward Qwen2_5_VLRotaryEmbedding.forward = _patched_rope_forward + Qwen2_5_VLModel.get_image_features = _patched_get_image_features_flat + Qwen2_5_VLModel.forward = _patched_model_forward_export_war # Patch the create_causal_mask function in both the masking_utils module # and the locally imported reference in the qwen2_5_vl modeling module @@ -139,12 +260,15 @@ def _revert_patch(self): import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl as qwen_modeling from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, Qwen2_5_VLRotaryEmbedding, ) # Restore original methods Qwen2_5_VisionTransformerPretrainedModel.forward = self.original_values["vision_forward"] Qwen2_5_VLRotaryEmbedding.forward = self.original_values["rope_forward"] + Qwen2_5_VLModel.get_image_features = self.original_values["model_get_image_features"] + Qwen2_5_VLModel.forward = self.original_values["model_forward"] # Restore original create_causal_mask function in both locations import 
transformers.masking_utils From d76647177cc4fb8628db7640cc7a43b54bab7e44 Mon Sep 17 00:00:00 2001 From: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com> Date: Thu, 7 Aug 2025 21:34:28 -0700 Subject: [PATCH 11/11] cherry-pick skip pattern Signed-off-by: Chenghao Zhang --- .../transform/library/export_to_gm.py | 1 - .../transformations/library/attention.py | 16 ++++++++----- .../transformations/library/kvcache.py | 7 +++--- .../transformations/library/rope.py | 9 +++++--- .../auto_deploy/transformations/transform.py | 23 +++++++++++++++---- .../_torch/auto_deploy/utils/node_utils.py | 9 ++++++++ .../auto_deploy/utils/pattern_matcher.py | 1 + 7 files changed, 49 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index 3412f67bb19..2868fbabe84 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -71,7 +71,6 @@ def _apply( strict=self.config.strict, patch_list=self.config.patch_list, ) - # this is a clean graph by definition since it was just exported info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py index 6d024aaadd1..cc998d2fbd0 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py @@ -10,7 +10,7 @@ from ...custom_ops.attention_interface import AttentionDescriptor from ...utils.logger import ad_logger from ...utils.node_utils import is_op -from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern +from ...utils.pattern_matcher import ADPatternMatcherPass, Match, register_ad_pattern from .._graph import canonicalize_graph, lift_to_meta @@ -24,14 +24,13 @@ def _apply_pattern( patterns = ADPatternMatcherPass() register_fn(patterns) num_matches = patterns.apply(gm.graph) - if num_matches > 0: with lift_to_meta(gm) if shape_prop else nullcontext(): canonicalize_graph(gm, shape_prop=shape_prop) ad_logger.info(f"Found and matched {num_matches} {pattern_name} pattern(s)") -def match_repeat_kv(gm: GraphModule) -> None: +def match_repeat_kv(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None: """ Match and replace the repeat_kv pattern with torch.ops.auto_deploy.torch_attention_repeat_kv. """ @@ -51,24 +50,25 @@ def register_repeat_kv(patterns: ADPatternMatcherPass): torch.ops.aten.expand.default: (int,), }, scalar_workaround={"n_rep": dummy_args[1]}, + extra_check=accept_match_fn, ) _apply_pattern(gm, "Repeat KV", register_repeat_kv, shape_prop=True) -def match_eager_attention(gm: GraphModule) -> None: +def match_eager_attention(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None: """ Match and replace the eager attention pattern with torch.ops.auto_deploy.torch_attention_sdpa. 
""" def register_eager_attention(patterns: ADPatternMatcherPass): for pattern_config in _get_sfdp_patterns(): - register_ad_pattern(**pattern_config, patterns=patterns) + register_ad_pattern(**pattern_config, patterns=patterns, extra_check=accept_match_fn) _apply_pattern(gm, "Eager Attention", register_eager_attention) -def match_grouped_attention(gm: GraphModule) -> None: +def match_grouped_attention(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None: """ Match and replace the grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa. @@ -92,6 +92,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): patterns=patterns, dummy_args=dummy_args_1, scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_2, @@ -102,6 +103,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): "scale": scale, "dropout_p": dropout, }, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_3, @@ -109,6 +111,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): patterns=patterns, dummy_args=dummy_args_1, scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_4, @@ -119,6 +122,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): "scale": scale, "dropout_p": dropout, }, + extra_check=accept_match_fn, ) _apply_pattern(gm, "Grouped Attention", register_grouped_attention) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py index 6fd1d2b7b7f..bf7adfe94e3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py @@ -28,11 +28,12 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None: # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. # Later on, we can revisit how to support returning hidden states. - assert len(output_nodes) == 1, "Expected exactly one output node!" - assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" - ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes") + assert len(output_nodes) == 1, "Expected exactly one output node!" + # the following assert feels too restrictive.. + # assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" 
+ # Activate and add extra argument nodes new_args = cm.info.switch_to_cached_attn_inputs() for name in new_args: diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py index 65e7f7f614c..ecaf84aae2e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py @@ -47,7 +47,7 @@ def apply_rotary_emb( import operator from collections import defaultdict -from typing import Any, DefaultDict, Dict, Optional, Sequence +from typing import Any, Callable, DefaultDict, Dict, Optional, Sequence import torch from torch.fx import GraphModule, Node @@ -119,7 +119,7 @@ def _explicit_not_interleaved(match: Match) -> bool: return not any(isinstance(n, Node) and _match_input_interleave_pattern(n) for n in (q, k)) -def match_rope_pattern(gm: GraphModule) -> int: +def match_rope_pattern(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> int: graph = gm.graph patterns = ADPatternMatcherPass() @@ -154,7 +154,7 @@ def match_rope_pattern(gm: GraphModule) -> int: dummy_args=dummy_explicit, op_ignore_types={torch.ops.aten.slice.Tensor: (int,)}, scalar_workaround={"unsqueeze_dim": 1}, - extra_check=_explicit_not_interleaved, + extra_check=lambda match: _explicit_not_interleaved(match) and accept_match_fn(match), ) register_ad_pattern( search_fn=_interleaved_rope_pattern, @@ -167,6 +167,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.view.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_complex_rope_pattern, @@ -177,6 +178,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.reshape.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_complex_rope_pattern, @@ -187,6 +189,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.reshape.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) num_matches = patterns.apply(graph) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index 18b8c4679ac..72f12dfd39d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -1,10 +1,14 @@ """High-level entrypoint to transform a model into an efficient inference model.""" import gc +from functools import partial import torch import torch.nn as nn +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_node_in_module +from tensorrt_llm._torch.auto_deploy.utils.pattern_matcher import Match + from ..compile import compile_and_capture from ..custom_ops.attention_interface import AttentionRegistry from ..distributed import common as dist_ad @@ -38,6 +42,17 @@ ) +def accept_match_fn(match: Match, name_to_skip: str) -> bool: + """ + Accept a match if it does not contain a node with the given name scope. 
+ """ + for node in match.nodes: + if is_node_in_module(node, name_to_skip): + ad_logger.info(f"REJECTING MATCH: {match}") + return False + return True + + class InferenceOptimizer: def __init__(self, factory: ModelFactory, ad_config: AutoDeployConfig): self.factory = factory @@ -75,19 +90,19 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module: match_moe_pattern(egm) # Match repeat_kv pattern - match_repeat_kv(egm) + match_repeat_kv(egm, partial(accept_match_fn, name_to_skip="visual")) # Match eager attention pattern - match_eager_attention(egm) + match_eager_attention(egm, partial(accept_match_fn, name_to_skip="visual")) # Match grouped attention pattern - match_grouped_attention(egm) + match_grouped_attention(egm, partial(accept_match_fn, name_to_skip="visual")) # Match attention layout expected by our backend match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend)) # Match rope - match_rope_pattern(egm) + match_rope_pattern(egm, partial(accept_match_fn, name_to_skip="visual")) # Match RoPE layout expected by our backend match_rope_layout( diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 48f06c70e60..49d6e15cf32 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -422,3 +422,12 @@ def _get(name): raise RuntimeError(f"Could not find a value for '{name}' on op {op}") return [_get(n) for n in arg_names] + + +def is_node_in_module(node: Node, module_name: str) -> bool: + """Check if the node is in the given module.""" + try: + nn_module_list = list(node.meta["nn_module_stack"].keys())[-1] + return module_name in nn_module_list + except Exception: + return False diff --git a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py index 00b535dec61..c4ad93b3d87 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py @@ -83,6 +83,7 @@ def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule: Exports a function or Module into a GraphModule via torch_export_to_gm. """ module = fn if isinstance(fn, torch.nn.Module) else _WrapperModule(fn) + return torch_export_to_gm(module, tuple(args))
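
For reference, a minimal self-contained sketch of the module-scope match filtering introduced by the skip-pattern commit above. FakeNode and FakeMatch are hypothetical stand-ins for torch.fx.Node and the pattern-matcher Match object (not the library's own classes); only the filtering logic mirrors the helpers added in node_utils.py and transform.py.

from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class FakeNode:
    """Stand-in for torch.fx.Node, carrying only the nn_module_stack metadata."""

    meta: Dict[str, Dict[str, Tuple[str, str]]] = field(default_factory=dict)


@dataclass
class FakeMatch:
    """Stand-in for a pattern-matcher Match, exposing only the matched nodes."""

    nodes: List[FakeNode] = field(default_factory=list)


def is_node_in_module(node: FakeNode, module_name: str) -> bool:
    """Same idea as the helper added to node_utils.py: inspect the innermost module scope."""
    try:
        innermost = list(node.meta["nn_module_stack"].keys())[-1]
        return module_name in innermost
    except Exception:
        return False


def accept_match(match: FakeMatch, name_to_skip: str) -> bool:
    """Reject a match if any of its nodes originates from the skipped module scope."""
    return not any(is_node_in_module(n, name_to_skip) for n in match.nodes)


# A node traced from the vision tower is rejected; a language-model node is kept.
vision_node = FakeNode(meta={"nn_module_stack": {"model.visual.blocks.0.attn": ("", "")}})
text_node = FakeNode(meta={"nn_module_stack": {"model.language_model.layers.0.self_attn": ("", "")}})
assert not accept_match(FakeMatch(nodes=[vision_node]), name_to_skip="visual")
assert accept_match(FakeMatch(nodes=[text_node]), name_to_skip="visual")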