diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore index 8e28b4431da..9836a37fc88 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -2,3 +2,5 @@ !.vscode benchmark_results.json *.png +# ignore config files that users might put here for debugging +*.yaml diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 35879834db0..6ce5f4e692e 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -26,6 +26,9 @@ # Global torch config, set the torch compile cache to fix up to llama 405B torch._dynamo.config.cache_size_limit = 20 +# simple string, TRT-LLM style text-only prompt or full-scale HF message template +PromptInput = Union[str, Dict, List[Dict]] + class PromptConfig(BaseModel): """Prompt configuration. @@ -35,17 +38,27 @@ class PromptConfig(BaseModel): """ batch_size: int = Field(default=2, description="Number of queries") - queries: Union[str, List[str]] = Field( + queries: Union[PromptInput, List[PromptInput]] = Field( default_factory=lambda: [ + # OPTION 1: simple text prompt "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", - "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", - ] + # OPTION 2: wrapped text prompt for TRT-LLM + {"prompt": "In simple words and a single sentence, explain the concept of gravity: "}, + # OPTION 3: a full-scale HF message template (this one works for text-only models!) + # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating + # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal + [ + { + "role": "user", + "content": "How to fix slicing in golf?", + } + ], + # More prompts... + {"prompt": "Where is the capital of Iceland? "}, + ], + description="Example queries to prompt the model with. We support both TRT-LLM text-only " + "queries via the 'prompt' key and full-scale HF message template called via " + "apply_chat_template.", ) sp_kwargs: Dict[str, Any] = Field( default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0}, @@ -59,10 +72,28 @@ def model_post_init(self, __context: Any): NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field validators are only run if a value is provided. """ - queries = [self.queries] if isinstance(self.queries, str) else self.queries + queries = self.queries if isinstance(self.queries, list) else [self.queries] batch_size = self.batch_size queries = queries * (batch_size // len(queries) + 1) - self.queries = queries[:batch_size] + queries = queries[:batch_size] + + # now let's standardize the queries for the LLM api to understand them + queries_processed = [] + for query in queries: + if isinstance(query, str): + queries_processed.append({"prompt": query}) + elif isinstance(query, dict): + queries_processed.append(query) + elif isinstance(query, list): + queries_processed.append( + { + "prompt": "Fake prompt. 
Check out messages field for the HF chat template.", + "messages": query, # contains the actual HF chat template + } + ) + else: + raise ValueError(f"Invalid query type: {type(query)}") + self.queries = queries_processed @field_validator("sp_kwargs", mode="after") @classmethod @@ -239,6 +270,9 @@ def main(config: Optional[ExperimentConfig] = None): # prompt the model and print its output ad_logger.info("Running example prompts...") + + # now let's try piping through multimodal data + outs = llm.generate( config.prompt.queries, sampling_params=SamplingParams(**config.prompt.sp_kwargs), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 13c91652bff..68dfd9da4ee 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -10,14 +10,17 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field, fields -from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union import torch from torch._ops import OpOverloadPacket from torch.export import Dim from torch.fx import Node +DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension +DynamicShapeCallback = Callable[[], DynamicShape] + @dataclass class CacheConfig: @@ -34,7 +37,7 @@ class SequenceInfo: between arguments that are originally part of the model/graph and arguments that are needed for the attention operator when we switch to cached+flattened attention. - # ORIGINAL MODEL ARGUMENTS ##################################################################### + ### ORIGINAL MODEL ARGUMENTS ################################################################### - input_ids: [id_0, ..., id_{s_total-1}] flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches. - position_ids: [pos_0, ..., pos_{s_total-1}] @@ -44,7 +47,17 @@ class SequenceInfo: NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len] before we switch to cached+flattened attention. - # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############## + ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ################################################## + Those are extra arguments that can be provided to the interface and they are stored as follows: + - _extra_args: dictionary of extra arguments with currently active values. + - _extra_none_inputs: dictionary of none inputs to the extra arguments. + NOTE: we assume that extra arguments are *optional* arguments to the model. However, we + cannot represent them via `None` since fx graphs require a fixed input type. Instead, + we require a special placeholder tensor to represent the `None` input. + - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of + the extra arguments. + + ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i) Describes how long each sequence is. 
For example, input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will @@ -88,17 +101,6 @@ class SequenceInfo: # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). max_num_tokens: Optional[int] = None - ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP ################# - # input_ids MUST ALWAYS BE THE FIRST FIELD - input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int)) - position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long)) - - seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) - cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int)) - pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - ################################################################################################ - ## PRIVATE FIELDS ############################################################################## _sequence_lengths: List[int] = field(default_factory=list) _num_pages: int = 1 @@ -122,17 +124,8 @@ def __post_init__(self): self.max_batch_size, (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), ) - self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) - self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) - self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) - self.input_pos = torch.empty_like(self.seq_len) - self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) - self.pages_per_seq = torch.empty_like(self.seq_len) - assert self.num_pages >= self.max_batch_size, ( - "num_pages must be greater than max_batch_size" - ) - # dynamic shape descriptors for tensor args - self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None + # sanity check + assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size" # keep a list-like object of sequence lengths for simplicity as well self._sequence_lengths = [0] * self.max_batch_size @@ -140,6 +133,28 @@ def __post_init__(self): # indicator if extra args are activated that are needed for cached attention backends self._is_cached_attn = False + ### TENSOR FIELDS/ARGS FOR UNCACHED AND CACHED ATTENTION ################################### + # UNCACHED TENSOR FIELDS + self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) + self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) + self._uncached_arg_names = ["input_ids", "position_ids"] + self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + + # EXTRA TENSOR FIELDS + self._extra_args: Dict[str, torch.Tensor] = {} + self._extra_none_inputs: Dict[str, torch.Tensor] = {} + self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} + + # CACHED TENSOR FIELDS (for cached attention backends) + self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) + self.input_pos = torch.empty_like(self.seq_len) + self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) + self.pages_per_seq = torch.empty_like(self.seq_len) + self._cached_arg_names = ["seq_len", "input_pos", "cache_loc", "pages_per_seq"] + self._cached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + 
############################################################################################ + # call reset once to initialize the tensors self.reset() @@ -147,52 +162,78 @@ def __post_init__(self): def device(self) -> torch.device: return self.input_pos.device + def _named_args( + self, include_extra_args: bool = True, include_cached_args: bool = True + ) -> Dict[str, torch.Tensor]: + args: Dict[str, torch.Tensor] = {} + for name in self._uncached_arg_names: + args[name] = getattr(self, name) + + if include_extra_args: + args.update(self._extra_args) + + if include_cached_args: + for name in self._cached_arg_names: + args[name] = getattr(self, name) + + return args + @property - def args(self) -> Tuple[torch.Tensor, ...]: - args = [] - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - args.append(val) - if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn: - break - return tuple(args) + def named_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named arguments.""" + return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn) @property - def _num_uncached_attn_args(self) -> int: - """Return the number of original graph arguments expected by the model.""" - return 2 + def named_standard_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named standard arguments.""" + return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn) @property - def _cached_attn_arg_names(self) -> List[str]: - """Return extra arg names for the prepare_metadata op beyond input_ids and position_ids. + def args(self) -> Tuple[torch.Tensor, ...]: + """Return a tuple of arguments.""" + return tuple(self.named_args.values()) - These extra args are needed once we switch from regular attention to inserting cached - attention ops in the model. - """ - return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][ - self._num_uncached_attn_args : - ] + @property + def extra_args_for_prepare_metadata(self) -> Tuple: + """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.""" + return (self.page_size,) @property - def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]: + def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]: """Return dynamic shapes of sequence info tensors. NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing. 
""" - if self._dynamic_shapes is None: - # set up shape for input_ids and position_ids - dynamic_shapes = ({}, {}) + # lazy initialization of dynamic shapes with Dim objects + if self._uncached_dynamic_shapes is None: + # set up shape for uncached args (same for all, i.e., batch_size and seq_len) + bs_seq_len_shape: DynamicShape = {} if self.max_batch_size > 1: - dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len) - # set up shape for position_ids (same as input_ids) - dynamic_shapes[1].update(dynamic_shapes[0]) - # set up shape for extra args - if self._is_cached_attn: - dynamic_shapes += ({},) * len(self._cached_attn_arg_names) - self._dynamic_shapes = dynamic_shapes - return self._dynamic_shapes + bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size) + bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len) + self._uncached_dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names} + + named_dynamic_shapes = self._uncached_dynamic_shapes.copy() + + # add dynamic shapes for extra args + if self._extra_dynamic_shapes is None: + self._extra_dynamic_shapes = { + k: callback() for k, callback in self._extra_dynamic_shapes_callbacks.items() + } + named_dynamic_shapes.update(self._extra_dynamic_shapes) + + # fixed shape for remaining cached attention args + if self._is_cached_attn: + if self._cached_dynamic_shapes is None: + self._cached_dynamic_shapes = {k: {} for k in self._cached_arg_names} + named_dynamic_shapes.update(self._cached_dynamic_shapes) + + return named_dynamic_shapes + + @property + def dynamic_shapes(self) -> Tuple[DynamicShape, ...]: + """Return dynamic shapes of sequence info tensors.""" + return tuple(self.named_dynamic_shapes.values()) @property def num_sequences(self) -> int: @@ -305,26 +346,18 @@ def switch_to_cached_attn_inputs(self) -> List[str]: """ assert not self._is_cached_attn, "Cached+flattened attention already activated" self._is_cached_attn = True - return self._cached_attn_arg_names + return self._cached_arg_names.copy() def to(self, *args, **kwargs) -> None: - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - setattr(self, f.name, val.to(*args, **kwargs)) - - def sync(self, other: "SequenceInfo") -> None: - for f in fields(self): - val = getattr(self, f.name) - val_other = getattr(other, f.name) - if f.name in ["input_ids", "position_ids"]: - setattr(self, f.name, val_other.to(self.device)) - elif f.name == "_sequence_lengths": - self._sequence_lengths = val_other - elif isinstance(val, torch.Tensor): - val[: len(val_other)] = val_other.to(self.device) - else: - assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}." + for k in self._uncached_arg_names + self._cached_arg_names: + setattr(self, k, getattr(self, k).to(*args, **kwargs)) + + def _move_dict(d: Dict[str, torch.Tensor]) -> None: + for k, v in d.items(): + d[k] = v.to(*args, **kwargs) + + _move_dict(self._extra_args) + _move_dict(self._extra_none_inputs) def reset(self) -> None: """Reset the sequence information. 
@@ -342,25 +375,30 @@ def reset(self) -> None: self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) self.pages_per_seq.fill_(1) - def set_example_sequence(self) -> None: + def set_example_sequence(self, input_ids: Optional[torch.Tensor] = None, **kwargs) -> None: """Set an example sequence useful for testing and export purposes.""" self.reset() - bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) - input_ids = torch.ones( - bs, - seq_len, - dtype=torch.int, - device=self.device, + + # use a best guess default for input_ids if not provided + if input_ids is None: + bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) + input_ids = torch.ones( + bs, + seq_len, + dtype=torch.int, + device=self.device, + ) + + # make sure that all extra arguments are provided + assert self._extra_args.keys() <= kwargs.keys(), ( + f"Missing extra args: {self._extra_args.keys() - kwargs.keys()}" ) - self.nest_sequences(input_ids) - # unflatten if we are not yet using cached+flattened attention - if not self._is_cached_attn: - self.input_ids = self.input_ids.view(bs, seq_len) - self.position_ids = self.position_ids.view(bs, seq_len) + self.nest_sequences(input_ids, **kwargs) - def _set_max_num_tokens_sample(self) -> None: + def set_max_num_tokens_sample(self) -> None: """Set an example sequence with max_num_tokens.""" + # TODO: understand what this implies for extra arguments self.reset() seq_len = self.max_num_tokens // self.max_batch_size input_ids = torch.ones( @@ -396,67 +434,180 @@ def _update_position_ids(self) -> None: else: self.position_ids = self.position_ids.view(1, -1) - def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: + def _generate_position_ids(self) -> torch.Tensor: + """Generate position ids from current input_pos and sequence lengths.""" + position_ids_list = [ + num + for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths) + for num in range(in_pos, in_pos + seq_len) + ] + return torch.tensor(position_ids_list, dtype=torch.long).to(self.device) + + def _update_input_pos(self, seq_len: Union[torch.Tensor, List[int], int]) -> None: + """Update the starting position for each sequence in the cache. + + If ``reset=True`, ``input_pos`` will be reset to zero before updating. + """ + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.tensor(seq_len, dtype=torch.int) + bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size + self.input_pos[:bs] = seq_len.to(self.device) + + def _assign_pages_per_seq(self, page_assignments: Sequence[Sequence[int]]) -> None: + """Set the cache location and pages_per_seq tensors from page assignments.""" + assert len(page_assignments) == self.num_sequences + cache_loc_flat = torch.tensor( + [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int + ) + self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) + + pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) + self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) + + @staticmethod + def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]: + return [ + val + for lst in nested_seqs + for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) + ] + + def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor: + """Shape the tensor for the forward pass based on the current attention mode. + + Args: + tnsr: The tensor to shape assumed to be in shape [batch_size*seq_len, ...] 
+ + Returns: + The shaped tensor flattened or unflattened based on the current attention mode. + """ + # check if we are still running uncached attention in which case we are also still + # operate on unflattened tensors with explicit [batch_size, seq_len, ...] shape + if not self._is_cached_attn: + bs = len(self.sequence_lengths) + sl = self.sequence_lengths[0] + return tnsr.view(bs, sl, *tnsr.shape[2:]) + + # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] + if self.is_generate: + return tnsr.view(-1, 1, *tnsr.shape[1:]) + else: + return tnsr.view(1, -1, *tnsr.shape[1:]) + + def nest_sequences( + self, + input_ids: Sequence[Sequence[int]], + position_ids: Optional[Sequence[Sequence[int]]] = None, + input_pos: Optional[Union[torch.Tensor, Sequence[int], int]] = None, + page_assignments: Optional[Sequence[Sequence[int]]] = None, + **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]], + ) -> None: """Create and store a flattened list of input_ids from the provided list of sequences. - This i/f will also update any relevant sequence information. + Args: + input_ids: List of sequences of input_ids. + position_ids: List of sequences of position_ids for each token. + input_pos: Absolute starting position in the cache for each sequence. + page_assignments: List of sequences of page assignments for each sequence. + extra_args: Extra arguments to be stored in the interface. + + This i/f will ensure that all sequence info args are updated accordingly. """ + # set new sequence lengths seq_lens = [len(ids) for ids in input_ids] self.seq_len.zero_() self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True) + self._sequence_lengths = seq_lens + # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int - # set new input_ids as new tensor from flattened input_ids - ids_list = [ - val - for lst in input_ids - for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) - ] - self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) - # set derivative properties - self._sequence_lengths = seq_lens + # set new input_ids as new tensor from flattened input_ids + self.input_ids = torch.tensor(self._flatten(input_ids), dtype=dtype).to(self.device) + self.input_ids = self._shape_for_forward(self.input_ids) - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:]) + # check for position_ids/input_pos update + assert position_ids is None or input_pos is None, ( + "Cannot provide both position_ids and input_pos" + ) + # check for updated input_pos + if input_pos is not None: + self._update_input_pos(input_pos) + + # check for updated position_ids + if position_ids is None: + # none provided,simple update position_ids based on new sequence lengths and + # current input_pos assuming that input_pos is the starting position id for each + # sequence and position_ids are consecutive. 
+ self.position_ids = self._generate_position_ids() + elif not isinstance(position_ids, torch.Tensor): + # nest position_ids to be consistent with input_ids + seq_lens_p = [len(ids) for ids in position_ids] + assert len(seq_lens_p) == len(seq_lens), f"{seq_lens_p=} != {seq_lens=}" + position_ids_flat = self._flatten(position_ids) + self.position_ids = torch.tensor( + position_ids_flat, dtype=torch.long, device=self.device + ) else: - self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:]) + self.position_ids = position_ids - # update position_ids - self._update_position_ids() + # final shape for position_ids + self.position_ids = self._shape_for_forward(self.position_ids) + + # sanity check on final shape of position_ids and input_ids + assert self.position_ids.shape[:2] == self.input_ids.shape[:2], ( + f"{self.position_ids.shape[:2]=} != {self.input_ids.shape[:2]=}" + ) + + # check for updated page_assignments + if page_assignments is not None: + self._assign_pages_per_seq(page_assignments) + + # go through all extra arguments and update them + for name, none_input in self._extra_none_inputs.items(): + if name in extra_args: + arg = extra_args.pop(name) + if not isinstance(arg, torch.Tensor): + if len(arg) > 1: + arg = torch.cat(arg) + else: + arg = arg[0] + self._extra_args[name] = arg.to(self.device) + else: + self._extra_args[name] = none_input + + # TODO (lucaslie): how strict do we wanna be here? Should we just warn/ignore instead? + assert not extra_args, f"Extra arguments {extra_args.keys()} not found" def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) - def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: - """Update the starting position for each sequence in the cache. - - If ``reset=True`, ``input_pos`` will be reset to zero before updating. + def add_extra_arg( + self, + name: str, + none_input: torch.Tensor, + dynamic_shape_callback: Optional[DynamicShapeCallback] = None, + ) -> None: + """Add an extra argument to the sequence info object. + + Args: + name: The name of the extra argument. + none_input: None input value of the extra argument. + dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. + + Note that the extra argument is expected to be a tensor. 
""" - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int) - bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size + assert name not in self._named_args().keys(), f"Extra argument {name} already exists" - if reset: - self.input_pos[:bs] = seq_len.to(self.device) - else: - self.input_pos[:bs] += seq_len.to(self.device) - - # update position_ids - self._update_position_ids() + self._extra_args[name] = none_input.to(self.device) + self._extra_none_inputs[name] = none_input.to(self.device) - def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: - """Set the cache location and pages_per_seq tensors from page assignments.""" - cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int - ) - self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - - pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) - self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) + if dynamic_shape_callback is None: + self._extra_dynamic_shapes_callbacks[name] = lambda: {} + else: + self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback Constant = Union[int, float, str, None] diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 999a024fb38..a1fd8c27135 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -1,19 +1,83 @@ import types -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple from ...executor.result import CompletionOutput -from ...inputs.registry import create_input_processor +from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs from ...llmapi.llm import RequestOutput, _TorchLLM -from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory +from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory +from ...sampling_params import SamplingParams from .distributed import common as dist_ad from .llm_args import LlmArgs +from .models.factory import ModelFactory from .shim.demollm import DemoGenerationExecutor +class ADInputProcessor(DefaultInputProcessor): + """Input processor for AutoDeploy backend. + + This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's + message chat template system to process multimodal inputs. + """ + + def __init__(self, tokenizer: TokenizerBase, processor: Optional[Any] = None): + super().__init__(None, None, tokenizer) + # NOTE: HF's tokenizer/processor that has the apply_chat_template method + self.processor = processor or tokenizer.tokenizer + + def __call__( + self, inputs: Dict[str, Any], sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + # construct kwargs to reflect DefaultInputProcessor + kwargs = { + "add_special_tokens": sampling_params.add_special_tokens, + } + if sampling_params.truncate_prompt_tokens is not None: + kwargs = { + "truncation": True, + "max_length": sampling_params.truncate_prompt_tokens, + } + # check for messages field and if yes, use the apply_chat_template method + if "messages" in inputs: + # TODO: we don't really need this but it makes for a good sanity check. Consider + # removing this in the future if we need to speed things up. 
+ prompt = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=False, + ) + inputs["prompt"] = prompt + + all_args = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=False, # there shouldn't be a need for padding ever... + return_attention_mask=False, + **kwargs, + ) + # TODO: is there a more reliable way to avoid the attention_mask here? + all_args.pop("attention_mask", None) + + # TODO: can we avoid the extra tolist() here eventually? + token_ids = all_args.pop("input_ids") + assert token_ids.shape[0] == 1, "messages should be unbatched at this point." + return token_ids[0].tolist(), {"multimodal_data": all_args} if all_args else None + else: + token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs) + return token_ids, None + + class LLM(_TorchLLM): """LLM class is the main class for running an LLM model using AutoDeploy backend.""" args: LlmArgs + _factory: ModelFactory + + @property + def factory(self) -> ModelFactory: + return self._factory def __init__(self, *args, **kwargs): kwargs["backend"] = "_autodeploy" @@ -23,16 +87,14 @@ def _try_load_tokenizer(self) -> Optional[TokenizerBase]: if self.args.skip_tokenizer_init: return None - factory = self.args.create_factory() - return tokenizer_factory(factory.init_tokenizer()) + return tokenizer_factory(self._factory.init_tokenizer()) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: """We don't need to validate args for AutoDeploy backend for now.""" pass - def _prefetch_model(self): - """Prefetch the model for the LLM.""" - self.args.create_factory().prefetch_checkpoint() + def _create_input_processor(self) -> ADInputProcessor: + return ADInputProcessor(self.tokenizer, self._factory.init_processor()) def _build_model(self): """Build the model for the LLM. @@ -40,13 +102,21 @@ def _build_model(self): This is a wrapper around the regular build model method that prefetches the model with the factory. """ + # create and store a factory + self._factory = self.args.create_factory() + # prefetch model with factory - self._prefetch_model() + self._factory.prefetch_checkpoint() # NOTE (lucaslie): do regular build model, we bypass the regular LLM CachedModelLoader in # _autodeploy backend. super()._build_model() + # now correct input processor + assert isinstance(self.input_processor, DefaultInputProcessor) + assert isinstance(self.tokenizer, TransformersTokenizer) + self.input_processor = self._create_input_processor() + class DemoLLM(LLM): """A simple LLM class to demo the LLM interface while debugging the e2e workflow. 
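Note (illustrative, not part of the patch): the three query formats accepted by `PromptConfig` above all end up in `ADInputProcessor`. Plain strings and TRT-LLM-style `{"prompt": ...}` dicts go through `tokenizer.encode`, while a dict carrying a `"messages"` field is tokenized via the processor's `apply_chat_template`, with any extra tensors returned as `"multimodal_data"`. A minimal sketch of the user-facing call, with the checkpoint name left as a placeholder:

from tensorrt_llm._torch.auto_deploy.llm import LLM

llm = LLM(model="<hf-checkpoint-with-chat-template>")  # placeholder checkpoint
outs = llm.generate(
    [
        # 1) plain string (wrapped into {"prompt": ...} by PromptConfig.model_post_init)
        "How big is the universe? ",
        # 2) TRT-LLM style text-only prompt
        {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
        # 3) HF chat template; ADInputProcessor tokenizes it via apply_chat_template
        {
            "prompt": "Fake prompt. Check out messages field for the HF chat template.",
            "messages": [{"role": "user", "content": "How to fix slicing in golf?"}],
        },
    ]
)
for out in outs:
    print(out.outputs[0].text)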
@@ -61,9 +131,10 @@ def __init__(self, **kwargs): self.runtime_context = None # prefetch model and load tokenizer - self._prefetch_model() + self._factory = self.args.create_factory() + self._factory.prefetch_checkpoint() self._tokenizer = self._try_load_tokenizer() - self.input_processor = create_input_processor(None, self.tokenizer) + self.input_processor = self._create_input_processor() # construct demo executor + engine self._executor = DemoGenerationExecutor( diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 42a30402537..11c676bd935 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -2,13 +2,13 @@ import copy from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type import torch import torch.nn as nn from torch._prims_common import DeviceLikeType -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback from ..utils.logger import ad_logger @@ -113,6 +113,15 @@ def init_tokenizer(self) -> Optional[Any]: """ return None + def init_processor(self) -> Optional[Any]: + """Initialize the (multi-modal) processor for the model. + + Returns: + The initialized processor for the model. If the processor is not available, then this + method should return None. + """ + return None + def prefetch_checkpoint(self, force: bool = False): """Try or skip prefetching the checkpoint for the model and tokenizer. @@ -206,6 +215,33 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): device: The device to load the model on. """ + def get_example_inputs(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model. + + This function can be overwritten by a factory when it requires a specific example input to + in order to run through export. + + Returns: + A dictionary of example inputs for the model where the key corresponds to the argument + name and the value corresponds to the example input. + """ + return {} + + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: + """Return a dictionary of extra inputs for the model. + + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (none_input, dynamic_shape_callback): + - `none_input`: The none input value of the extra input indicating the tensor + value corresponding to the equivalent of the None input. `None` is not supported + as we require the input to be a tensor. Hence, this none_input acts as a + placeholder for the None input. + - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra + input. 
+ """ + return {} + class ModelFactoryRegistry: _registry: Dict[str, Type[ModelFactory]] = {} diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index fc37c1e557a..937af7e2a61 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -4,7 +4,7 @@ import os import types from contextlib import contextmanager, nullcontext -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -12,11 +12,13 @@ from accelerate.utils import modeling from huggingface_hub import HfApi, snapshot_download from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id +from PIL import Image from torch._prims_common import DeviceLikeType from transformers import ( AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText, + AutoProcessor, AutoTokenizer, PretrainedConfig, ) @@ -27,7 +29,7 @@ WEIGHTS_NAME, ) -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .factory import ModelFactory, ModelFactoryRegistry @@ -108,10 +110,6 @@ def __init__(self, *args, **kwargs): def autoconfig_from_pretrained(self): return AutoConfig.from_pretrained - @property - def autotokenizer_from_pretrained(self): - return AutoTokenizer.from_pretrained - # TODO (@lucaslie): Do we ever want to switch to from_pretrained? @property def automodel_from_config(self): @@ -200,7 +198,7 @@ def init_tokenizer(self) -> Optional[Any]: """Initialize the tokenizer—either a custom name or the model's default.""" if self.tokenizer is None: return None - return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) @staticmethod def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]: @@ -366,3 +364,109 @@ def _get_max_position_embeddings_config(self) -> Dict[str, Any]: @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config + + def init_tokenizer(self) -> Optional[Any]: + """Initialize the tokenizer—either a custom name or the model's default.""" + processor = self.init_processor() + if processor is None: + return None + return processor.tokenizer + + def init_processor(self) -> Optional[Any]: + """Initialize the processor for the model.""" + if self.tokenizer is None: + return None + return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + + @staticmethod + def _simple_forward( + model: nn.Module, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + pixel_values: torch.Tensor, + # How to get this programmatically? + image_sizes: torch.Tensor, + ): + """A simple forward pass for the model to functionalize the args. + + This follows the standard function signature as expected by factory.py. 
+ """ + return type(model).forward( + model, + input_ids=input_ids, + position_ids=position_ids, + pixel_values=pixel_values, + image_sizes=image_sizes, + ) + + def get_example_inputs(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model.""" + + def _prep_seq(text, img1, img2): + return [ + { + "role": "user", + "content": [ + {"type": "image", "image": img1}, + {"type": "image", "image": img2}, + {"type": "text", "text": text}, + ], + } + ] + + # Create a batch of conversations (batch_size = 2) + batch_messages = [ + _prep_seq( + "Describe what you see in the two images and their differences.", + Image.new("RGB", (64, 64), color=(128, 128, 128)), + Image.new("RGB", (64, 64), color=(64, 64, 64)), + ), + _prep_seq( + "What are the main differences between these two images?", + Image.new("RGB", (64, 64), color=(255, 0, 0)), + Image.new("RGB", (64, 64), color=(0, 255, 0)), + ), + ] + + processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + inputs = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + return_attention_mask=False, + ) + + return { + # "input_ids": inputs["input_ids"], + # "pixel_values": inputs["pixel_values"], + **inputs + } + + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, DynamicShapeCallback]]: + """Return a dictionary of extra inputs for the model. + + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). + The dynamic shape callback is a function that returns the dynamic shape of the extra + input. + """ + + def _get_dynamic_shape(): + return { + # TODO (lucaslie): how to set default values for dynamic shapes? + 0: Dim("img_batch_size", max=10), + 2: Dim("img_height", min=32, max=2048), + 3: Dim("img_width", min=32, max=2048), + } + + none_pixel_values = torch.zeros(0, 3, 336, 336) + return { + "pixel_values": (none_pixel_values, _get_dynamic_shape), + # How to get this from the input processor? It seems there's no good way without + # running a dummy input through the processor. + "image_sizes": (torch.zeros(0, 2), None), + } diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py new file mode 100644 index 00000000000..eeb931bec7c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py @@ -0,0 +1,185 @@ +"""A patch to handle vision branch in Llama4ForConditionalGeneration.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import Llama4ForConditionalGeneration +from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast + +from ...export.interface import BaseExportPatch, ExportPatchRegistry + + +# Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651 +# With some modifications that won't affect current execution logic: +# 1. Vison branch managed by torch.cond to enable both text-only and text+image input during runtime. +# 2. Input arg `image_sizes` are set to none +# as the input to torch.cond true/false branch needs fixed argument type during export +# 3. Do not return `image_hidden_states` as it is calculated inside the vision branch +# and invisible to the function outside. 
+def _forward_with_cond( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: torch.Tensor = None, # image_sizes set as None + **lm_kwargs, +) -> Union[Tuple, Llama4CausalLMOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer + if vision_feature_layer is not None + else self.config.vision_config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + def _vision_branch(inputs_embeds, pixel_values, input_ids): + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=None, + ) + original_inputs_embeds_shape = inputs_embeds.shape + + vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat) + + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + final_mask = special_image_mask.to(inputs_embeds.device) + inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) + + final_mask_1d = final_mask[..., 0].reshape(-1) + # num_tokens_to_fill = final_mask_1d.sum() + + # This condition statement breaks torch.export: + # TODO: sanity check on the inputs for this + # if num_tokens_to_fill != projected_vision_flat.size(0): + # raise ValueError( + # f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + # f"but multi_modal_projector returned {projected_vision_flat.size(0)}" + # ) + + expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) + inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat) + + return inputs_embeds.view(original_inputs_embeds_shape) + + def _no_vision_branch(inputs_embeds, pixel_values, input_ids): + return inputs_embeds + + # decide by whether there is any non-zero pixel_values + has_image: torch.Tensor = torch.any(pixel_values != 0) + + inputs_embeds = torch.cond( + has_image, + _vision_branch, + 
_no_vision_branch, + (inputs_embeds, pixel_values, input_ids), + ) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + if attention_mask is not None: + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Llama4CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=None, # skip outputting this for simplicity + ) + + +@ExportPatchRegistry.register("hf_llama4_vision") +class Llama4VisionPatch(BaseExportPatch): + """Patch for Llama4ForConditionalGeneration to make it compatible with torch.export. + + This patch replaces the forward method of Llama4ForConditionalGeneration with + a version that uses the torch.cond to handle the optional vision branch. + """ + + def _apply_patch(self): + """Apply the Llama4 vision patch.""" + # Store original forward method + self.original_values["Llama4ForConditionalGeneration.forward"] = ( + Llama4ForConditionalGeneration.forward + ) + + # Apply patch by replacing the forward method + Llama4ForConditionalGeneration.forward = _forward_with_cond + + def _revert_patch(self): + """Revert the Llama4 vision patch.""" + # Restore original forward method + Llama4ForConditionalGeneration.forward = self.original_values[ + "Llama4ForConditionalGeneration.forward" + ] diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py b/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py new file mode 100644 index 00000000000..486cb040753 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py @@ -0,0 +1,219 @@ +"""A patch for the PixtralVisionModel to make it compatible with torch.export.""" + +import torch +from transformers.models.mistral3.modeling_mistral3 import Mistral3PatchMerger +from transformers.models.pixtral.modeling_pixtral import ( + PixtralVisionModel, + generate_block_attention_mask, + position_ids_in_meshgrid, +) + +from ...export.interface import BaseExportPatch, ExportPatchRegistry + +# @williamz notes: +# 1. everything decorated by a `custom_op` must be type annotated. +# a. It must be one of the internally supported param types. As such, `self: PixtralVisionModel` +# is a no-go. +# As such, pretty much only free-standing functions with tensor inputs are supported - instance +# methods cannot be decorated. 
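Note (illustrative): the pattern used below is the standard `torch.library.custom_op` + `register_fake` pairing, assuming PyTorch >= 2.4. The eager implementation must be a free-standing, fully annotated function over tensors and scalars, and the fake implementation only describes output metadata so export can trace through data-dependent code without running it. The `ad_example` namespace and op name here are made up for illustration:

import torch


@torch.library.custom_op("ad_example::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
    # eager implementation: free-standing, annotated, tensor/scalar arguments only
    return x * scale


@scale_rows.register_fake
def _scale_rows_fake(x: torch.Tensor, scale: float) -> torch.Tensor:
    # fake/meta implementation: only the output shape and dtype matter here
    return torch.empty_like(x)


y = torch.ops.ad_example.scale_rows(torch.randn(4, 3), 2.0)

This is why the data-dependent reshapes of the Pixtral model are factored into `auto_deploy::process_pixtral_patch_embeds` and `auto_deploy::unfold_to_2d_grid` below instead of being patched in place on the instance methods.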
+ + +@torch.library.custom_op("auto_deploy::process_pixtral_patch_embeds", mutates_args={}) +def _process_patch_embeds( + patch_embeds: torch.Tensor, + image_sizes: torch.Tensor, + patch_size: int, + hidden_size: int, + max_width: int, +) -> tuple[torch.Tensor, torch.Tensor]: + patch_embeds_list = [ + embed[..., : (size[0] // patch_size), : (size[1] // patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) + + position_ids = position_ids_in_meshgrid(patch_embeds_list, max_width=max_width) + + return patch_embeds, position_ids + + +@_process_patch_embeds.register_fake +def _process_patch_embeds_meta( + patch_embeds: torch.Tensor, + image_sizes: torch.Tensor, + patch_size: int, + hidden_size: int, + max_widht: int, +): + B = (image_sizes // patch_size).prod(dim=1).sum() + device = patch_embeds.device + return ( + # Leading 1 = `unsqueeze(0)`. + # The symbolic tracing will actually not complain if `1` is missing - I guess because + # the number of elements in the underlying tensor is the same? + torch.empty(1, B, hidden_size, device=device), + torch.empty(hidden_size, device=device, dtype=torch.int64), + ) + + +def _pixtral_forward( + self: PixtralVisionModel, + pixel_values: torch.Tensor, + image_sizes: torch.Tensor | None, + output_hidden_states: bool | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = None, + *args, + **kwargs, +): + if image_sizes is None: + batch_size, _, height, width = pixel_values.shape + image_sizes = torch.tensor([(height, width)] * batch_size, device=pixel_values.device) + + # pass images through initial convolution independently + patch_embeds = self.patch_conv(pixel_values) + patch_embeds, position_ids = torch.ops.auto_deploy.process_pixtral_patch_embeds( + patch_embeds=patch_embeds, + image_sizes=image_sizes, + patch_size=self.patch_size, + hidden_size=self.config.hidden_size, + max_width=self.config.image_size // self.config.patch_size, + ) + + patch_embeds = self.ln_pre(patch_embeds) + + kwargs["position_ids"] = position_ids + + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) + + if self.config._attn_implementation == "flash_attention_2": + # We only rely on position_ids when using flash_attention_2 + attention_mask = None + else: + attention_mask = generate_block_attention_mask( + (image_sizes // self.config.patch_size).prod(dim=1), + patch_embeds, + ) + + return self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=True, + **kwargs, + ) + + +def generate_block_attention_mask(num_ids_per_image, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + + idx = torch.arange(seq_len, device=device) + block_end_idx = num_ids_per_image.cumsum(-1) + block_start_idx = torch.cat( + [ + num_ids_per_image.new_zeros((1,)), + num_ids_per_image[:-1], + ] + ).cumsum(-1) + + # Build a mask where positions outside each [start, end) block are 1, inside are 0. 
+    mask = torch.ones((seq_len, seq_len), device=device, dtype=dtype)
+    for start, end in zip(block_start_idx, block_end_idx):
+        block = (idx >= start) & (idx < end)
+        mask[block.unsqueeze(0) & block.unsqueeze(1)] = 0
+
+    return mask
+
+
+@torch.library.custom_op("auto_deploy::unfold_to_2d_grid", mutates_args={})
+def _unfold_to_2d_grid(
+    image_features: torch.Tensor,
+    image_sizes: torch.Tensor,
+    patch_size: int,
+    spatial_merge_size: int,
+) -> torch.Tensor:
+    image_sizes = [
+        (image_size[0] // patch_size, image_size[1] // patch_size) for image_size in image_sizes
+    ]
+
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = image_features.shape[-1]
+
+    permuted_tensor = []
+    for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
+        grid = torch.nn.functional.unfold(
+            image_grid, kernel_size=spatial_merge_size, stride=spatial_merge_size
+        )
+        grid = grid.view(d * spatial_merge_size**2, -1).t()
+        permuted_tensor.append(grid)
+
+    image_features = torch.cat(permuted_tensor, dim=0)
+    return image_features
+
+
+@_unfold_to_2d_grid.register_fake
+def _unfold_to_2d_grid_meta(
+    image_features: torch.Tensor,
+    image_sizes: torch.Tensor,
+    patch_size: int,
+    spatial_merge_size: int,
+):
+    embedding_sizes = (image_sizes // patch_size).prod(dim=1)
+    spatial_factor = spatial_merge_size * spatial_merge_size
+    grid_sizes = embedding_sizes // spatial_factor
+    total_size = grid_sizes.sum()
+
+    return image_features.new_empty(total_size, image_features.shape[-1] * spatial_factor)
+
+
+def _patch_merger_forward(
+    self, image_features: torch.Tensor, image_sizes: torch.Tensor
+) -> torch.Tensor:
+    unfolded_features = torch.ops.auto_deploy.unfold_to_2d_grid(
+        image_features=image_features,
+        image_sizes=image_sizes,
+        patch_size=self.patch_size,
+        spatial_merge_size=self.spatial_merge_size,
+    )
+    image_features = self.merging_layer(unfolded_features)
+    return image_features
+
+
+@ExportPatchRegistry.register("hf_pixtral_vit")
+class PixtralVisionModelPatch(BaseExportPatch):
+    """Patch for `PixtralVisionModel`."""
+
+    def _apply_patch(self):
+        """Apply the PixtralVisionModel patch."""
+        # Store original forward methods
+        self.original_values["PixtralVisionModel.forward"] = PixtralVisionModel.forward
+        self.original_values["Mistral3PatchMerger.forward"] = Mistral3PatchMerger.forward
+
+        # Apply patch by replacing the forward methods
+        PixtralVisionModel._original_forward = PixtralVisionModel.forward  # type: ignore
+        PixtralVisionModel.forward = _pixtral_forward  # type: ignore
+
+        Mistral3PatchMerger._original_forward = Mistral3PatchMerger.forward
+        Mistral3PatchMerger.forward = _patch_merger_forward
+
+    def _revert_patch(self):
+        """Revert the PixtralVisionModel patch."""
+        # Restore original forward methods.
+        PixtralVisionModel.forward = self.original_values["PixtralVisionModel.forward"]  # type: ignore
+        Mistral3PatchMerger.forward = self.original_values["Mistral3PatchMerger.forward"]
+
+        # Clean up the temporary attribute.
+ if hasattr(PixtralVisionModel, "_original_forward"): + delattr(PixtralVisionModel, "_original_forward") + + if hasattr(Mistral3PatchMerger, "_original_forward"): + delattr(Mistral3PatchMerger, "_original_forward") diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7f759d6796d..4598056c7bb 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -1,6 +1,6 @@ -from itertools import chain +from collections import defaultdict from types import SimpleNamespace -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch._prims_common import DeviceLikeType @@ -102,16 +102,24 @@ def build_from_config(cls, ad_config: AutoDeployConfig): max_num_tokens=max_num_tokens, ) + # get factory + factory = ad_config.create_factory() + # update device to contain the current default device if it's in cuda device = torch.device(ad_config.device) if device.type == "cuda" and device.index is None: device = torch.device(f"cuda:{torch.cuda.current_device()}") device = str(device) + # pass in extra arguments defined by the model factory + for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): + seq_info.add_extra_arg(name, none_input, dynamic_shape_callback) + + # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, + # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. + # construct inference optimizer - build_and_optimize = InferenceOptimizer( - factory=ad_config.create_factory(), ad_config=ad_config - ) + build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config) # construct engine return cls(build_and_optimize, seq_info, device, max_beam_width) @@ -176,6 +184,7 @@ def _prepare_inputs( input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) # look at context requests first for request in context_requests: @@ -186,6 +195,15 @@ def _prepare_inputs( request.py_batch_idx = request.seq_slot last_logit_only.append(True) + # get cache indices + cache_indices = kv_cache_manager.get_cache_indices(request) + page_assignments.append(cache_indices) + + # store extra arguments + if request.py_multimodal_data is not None: + for k, v in request.py_multimodal_data.items(): + extra_args[k].append(v) + # look at generate requests next # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: @@ -202,17 +220,17 @@ def _prepare_inputs( # return all logits last_logit_only.append(False) - # extract cache information for all requests - for request in chain(context_requests, gen_requests): # get cache indices cache_indices = kv_cache_manager.get_cache_indices(request) page_assignments.append(cache_indices) # update the sequence info object now - si = self.cache_seq_interface.info - si.nest_sequences(input_ids) - si.update_pos(input_pos, reset=True) - si.assign_cache_loc(page_assignments) + self.cache_seq_interface.info.nest_sequences( + input_ids, + input_pos=input_pos, + page_assignments=page_assignments, + **extra_args, + ) return last_logit_only def _compute_logits(self) -> List[torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index c29cb5fbd7e..028d0ac5429 100644 --- 
a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -1,8 +1,9 @@ """A demo LLM api to for debugging and testing purposes of e2e workflows.""" import gc +from collections import defaultdict from queue import Empty -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.multiprocessing as mp @@ -10,6 +11,7 @@ from ....executor import GenerationExecutor from ....executor.request import GenerationRequest from ....executor.result import CompletionOutput, GenerationResult +from ....inputs.multimodal import MultimodalParams from ....sampling_params import SamplingParams from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch from ..distributed import common as dist_ad @@ -34,8 +36,11 @@ def __init__(self, *args, **kwargs) -> None: self.queue = mp.Queue() @torch.inference_mode() - def __call__(self, requests: GenerationRequest) -> mp.Queue: + def __call__( + self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams] + ) -> mp.Queue: """Generate tokens and put the results in a queue and return the queue.""" + requests.multimodal_params = multimodal_params output = self.generate_tokens_batched([requests])[0] self.queue.put(output) return self.queue @@ -45,7 +50,7 @@ def stop(self): self.queue.close() self.queue.join_thread() - def _assign_pages(self) -> List[List[int]]: + def _assign_pages(self, total_lens: List[int]) -> List[List[int]]: """A simple heuristic to assign pages based on current sequence info. In a nutshell, we will look at the following information to update the page assignments: @@ -67,7 +72,6 @@ def _assign_pages(self) -> List[List[int]]: unassigned page if needed. """ si = self.cache_seq_interface.info - total_lens = [s_l + i_p for s_l, i_p in zip(si.sequence_lengths, si.input_positions)] page_assignments = si.page_assignments free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages} @@ -76,7 +80,7 @@ def _assign_pages(self) -> List[List[int]]: extra_tokens = t_l - len(pages) * si.page_size num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0) updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)]) - si.assign_cache_loc(updated_assignments) + return updated_assignments def generate_tokens_batched( self, requests: List[GenerationRequest] @@ -91,10 +95,26 @@ def generate_tokens_batched( ) assert sampling_params.best_of == 1, "Best-of is not supported." 
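# --- Illustrative sketch (not part of the patch): the page-count arithmetic used by
# `_assign_pages` above, pulled into a standalone helper with made-up numbers. The helper name
# and page_size values are illustrative; the real method additionally pops the new pages from
# the free-page set tracked by the sequence info object.
def _extra_pages_needed(total_len: int, num_assigned_pages: int, page_size: int) -> int:
    extra_tokens = total_len - num_assigned_pages * page_size
    # same expression as in _assign_pages: whole extra pages plus one more for any remainder
    return (extra_tokens // page_size) + (extra_tokens > 0)


# with page_size=16: a 20-token sequence that already holds one page needs 1 extra page,
# and a fresh 16-token sequence with no pages gets 2 from this expression
assert _extra_pages_needed(20, 1, 16) == 1
assert _extra_pages_needed(16, 0, 16) == 2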
- # set up sequence info object + # set up sequence info object for decode phase sequence_info = self.cache_seq_interface.info sequence_info.reset() - sequence_info.nest_sequences([r.prompt_token_ids for r in requests]) + + input_ids = [] + total_lens = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) + + for request in requests: + total_lens.append(len(request.prompt_token_ids)) + input_ids.append(request.prompt_token_ids) + if request.multimodal_params is not None: + for k, v in request.multimodal_params.multimodal_data.items(): + extra_args[k].append(v) + + sequence_info.nest_sequences( + input_ids=input_ids, + page_assignments=self._assign_pages(total_lens), + **extra_args, + ) # setup objects we want to track for the output batch_size = sequence_info.num_sequences @@ -105,18 +125,21 @@ def generate_tokens_batched( context_logits: Optional[List[torch.Tensor]] = None def _generate_single_step(idx: int): - # assign pages - self._assign_pages() - - # get the logits and then last token logits in each sequence ([b, 1, vocab_size]) logits = self._compute_logits() logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1) token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1] - # update sequence info accordingly for next step - sequence_info.update_pos(sequence_info.sequence_lengths) - sequence_info.nest_sequences(token_ids) + # update sequence info accordingly for next step (generate phase) + input_pos_next = sequence_info.input_positions + seq_lens_current = sequence_info.sequence_lengths + input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)] + total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)] + sequence_info.nest_sequences( + token_ids, + input_pos=input_pos_next, + page_assignments=self._assign_pages(total_lens_next), + ) # nest new tokens and run stop check for b, (new_tokens_b, new_id) in enumerate(zip(new_tokens, token_ids)): @@ -255,6 +278,7 @@ def _run_engine( def _unpack(inputs) -> GenerationRequest: args, kwargs = inputs # unpack the inputs request: GenerationRequest = args[0] + request.multimodal_params: Optional[MultimodalParams] = args[1] return request engine = DemoEngine.build_from_config(**engine_kwargs) @@ -309,8 +333,11 @@ def submit(self, request: GenerationRequest) -> GenerationResult: request.set_id(client_id) # submit request to our demo engine and store results + # NOTE: when returning from this function, the reference request.multimodal_params will + # be cleared immediately. So we pass it in explicitly to maintain a reference even when + # requests get submitted asynchronously. 
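# --- Illustrative sketch (not part of the patch): the per-step bookkeeping performed in
# `_generate_single_step` above, with made-up numbers and plain lists instead of the real
# SequenceInfo object. After a context step over prompts of lengths [5, 3] that started at
# positions [0, 0], the next generate step begins at input_pos [5, 3]; with one sampled token
# per sequence, the total lengths used for page assignment become [6, 4].
input_positions = [0, 0]
sequence_lengths = [5, 3]
token_ids = [[42], [7]]  # one newly sampled token id per sequence

input_pos_next = [ip + sl for ip, sl in zip(input_positions, sequence_lengths)]
total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)]
assert input_pos_next == [5, 3] and total_lens_next == [6, 4]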
        result = GenerationResult(request)
-        result.queue = self.engine_executor(request)
+        result.queue = self.engine_executor(request, request.multimodal_params)
        return result
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
index bbe72650b4e..2868fbabe84 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
@@ -53,18 +53,18 @@
        model = gm.get_submodule("factory_model")

        # set the example sequence
-        cm.info.set_example_sequence()
+        cm.info.set_example_sequence(**factory.get_example_inputs())

        # export the model to a graph module
        gm = torch_export_to_gm(
            model,
            args=cm.args,
            dynamic_shapes=cm.dynamic_shapes,
            clone=self.config.clone_state_dict,
            strict=self.config.strict,
            patch_list=self.config.patch_list,
        )

        # this is a clean graph by definition since it was just exported
        info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py
index 6d024aaadd1..cc998d2fbd0 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py
@@ -10,7 +10,7 @@
 from ...custom_ops.attention_interface import AttentionDescriptor
 from ...utils.logger import ad_logger
 from ...utils.node_utils import is_op
-from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
+from ...utils.pattern_matcher import ADPatternMatcherPass, Match, register_ad_pattern
 from .._graph import canonicalize_graph, lift_to_meta
@@ -24,14 +24,13 @@ def _apply_pattern(
    patterns = ADPatternMatcherPass()
    register_fn(patterns)
    num_matches = patterns.apply(gm.graph)
-
    if num_matches > 0:
        with lift_to_meta(gm) if shape_prop else nullcontext():
            canonicalize_graph(gm, shape_prop=shape_prop)
    ad_logger.info(f"Found and matched {num_matches} {pattern_name} pattern(s)")


-def match_repeat_kv(gm: GraphModule) -> None:
+def match_repeat_kv(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None:
    """
    Match and replace the repeat_kv pattern with torch.ops.auto_deploy.torch_attention_repeat_kv.
    """
@@ -51,24 +50,25 @@ def register_repeat_kv(patterns: ADPatternMatcherPass):
                torch.ops.aten.expand.default: (int,),
            },
            scalar_workaround={"n_rep": dummy_args[1]},
+            extra_check=accept_match_fn,
        )

    _apply_pattern(gm, "Repeat KV", register_repeat_kv, shape_prop=True)


-def match_eager_attention(gm: GraphModule) -> None:
+def match_eager_attention(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None:
    """
    Match and replace the eager attention pattern with
    torch.ops.auto_deploy.torch_attention_sdpa.
""" def register_eager_attention(patterns: ADPatternMatcherPass): for pattern_config in _get_sfdp_patterns(): - register_ad_pattern(**pattern_config, patterns=patterns) + register_ad_pattern(**pattern_config, patterns=patterns, extra_check=accept_match_fn) _apply_pattern(gm, "Eager Attention", register_eager_attention) -def match_grouped_attention(gm: GraphModule) -> None: +def match_grouped_attention(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> None: """ Match and replace the grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa. @@ -92,6 +92,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): patterns=patterns, dummy_args=dummy_args_1, scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_2, @@ -102,6 +103,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): "scale": scale, "dropout_p": dropout, }, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_3, @@ -109,6 +111,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): patterns=patterns, dummy_args=dummy_args_1, scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_grouped_attn_pattern_4, @@ -119,6 +122,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass): "scale": scale, "dropout_p": dropout, }, + extra_check=accept_match_fn, ) _apply_pattern(gm, "Grouped Attention", register_grouped_attention) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py index 618c8108f84..bf7adfe94e3 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py @@ -26,16 +26,14 @@ def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None: # loop through nodes to get input, output, and get_attr nodes input_nodes, output_nodes = get_all_input_output_nodes(egm.graph) - # we only expect one input node - assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." - # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. # Later on, we can revisit how to support returning hidden states. - assert len(output_nodes) == 1, "Expected exactly one output node!" - assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" - ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes") + assert len(output_nodes) == 1, "Expected exactly one output node!" + # the following assert feels too restrictive.. + # assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!" + # Activate and add extra argument nodes new_args = cm.info.switch_to_cached_attn_inputs() for name in new_args: @@ -73,16 +71,17 @@ def insert_cached_attention( # retrieve input nodes input_nodes, _ = get_all_input_output_nodes(egm.graph) + input_nodes_mapping = {n.target: n for n in input_nodes} + + # filtered and sorted for SequenceInfo arguments (input_ids, position_ids, etc.) 
+ input_nodes_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()] # insert metadata computation and extract each argument as a node get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() with graph.inserting_before(input_nodes[-1].next): ret_node = graph.call_function( get_metadata, - args=( - *input_nodes, - cm.info.page_size, - ), + args=(*input_nodes_from_info, *cm.info.extra_args_for_prepare_metadata), ) metadata_nodes = [ graph.call_function(operator.getitem, args=(ret_node, idx)) @@ -162,7 +161,7 @@ def _get_mem_info_in_mb(): try: # Let's run a forward pass to get the memory usage - cm.info._set_max_num_tokens_sample() + cm.info.set_max_num_tokens_sample() free_mem_pre, _ = _get_mem_info_in_mb() ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py index 65e7f7f614c..ecaf84aae2e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py @@ -47,7 +47,7 @@ def apply_rotary_emb( import operator from collections import defaultdict -from typing import Any, DefaultDict, Dict, Optional, Sequence +from typing import Any, Callable, DefaultDict, Dict, Optional, Sequence import torch from torch.fx import GraphModule, Node @@ -119,7 +119,7 @@ def _explicit_not_interleaved(match: Match) -> bool: return not any(isinstance(n, Node) and _match_input_interleave_pattern(n) for n in (q, k)) -def match_rope_pattern(gm: GraphModule) -> int: +def match_rope_pattern(gm: GraphModule, accept_match_fn: Callable[[Match], bool]) -> int: graph = gm.graph patterns = ADPatternMatcherPass() @@ -154,7 +154,7 @@ def match_rope_pattern(gm: GraphModule) -> int: dummy_args=dummy_explicit, op_ignore_types={torch.ops.aten.slice.Tensor: (int,)}, scalar_workaround={"unsqueeze_dim": 1}, - extra_check=_explicit_not_interleaved, + extra_check=lambda match: _explicit_not_interleaved(match) and accept_match_fn(match), ) register_ad_pattern( search_fn=_interleaved_rope_pattern, @@ -167,6 +167,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.view.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_complex_rope_pattern, @@ -177,6 +178,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.reshape.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) register_ad_pattern( search_fn=_complex_rope_pattern, @@ -187,6 +189,7 @@ def match_rope_pattern(gm: GraphModule) -> int: torch.ops.aten.reshape.default: (int,), }, scalar_workaround={"unsqueeze_dim": 1}, + extra_check=accept_match_fn, ) num_matches = patterns.apply(graph) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index ed247753f83..a5dcdfb518c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -1,10 +1,14 @@ """High-level entrypoint to transform a model into an efficient inference model.""" import gc +from functools import partial import torch import torch.nn as nn +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_node_in_module +from tensorrt_llm._torch.auto_deploy.utils.pattern_matcher import Match + from ..compile import compile_and_capture from 
..custom_ops.attention_interface import AttentionRegistry
 from ..distributed import common as dist_ad
@@ -38,6 +42,17 @@
 )


+def accept_match_fn(match: Match, name_to_skip: str) -> bool:
+    """
+    Accept a match only if none of its nodes come from a module whose name contains `name_to_skip`.
+    """
+    for node in match.nodes:
+        if is_node_in_module(node, name_to_skip):
+            ad_logger.info(f"Rejecting pattern match inside '{name_to_skip}': {match}")
+            return False
+    return True
+
+
 class InferenceOptimizer:
     def __init__(self, factory: ModelFactory, ad_config: AutoDeployConfig):
         self.factory = factory
@@ -73,19 +88,19 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         match_moe_pattern(egm)

         # Match repeat_kv pattern
-        match_repeat_kv(egm)
+        match_repeat_kv(egm, partial(accept_match_fn, name_to_skip="vision_tower"))

         # Match eager attention pattern
-        match_eager_attention(egm)
+        match_eager_attention(egm, partial(accept_match_fn, name_to_skip="vision_tower"))

         # Match grouped attention pattern
-        match_grouped_attention(egm)
+        match_grouped_attention(egm, partial(accept_match_fn, name_to_skip="vision_tower"))

         # Match attention layout expected by our backend
         match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))

         # Match rope
-        match_rope_pattern(egm)
+        match_rope_pattern(egm, partial(accept_match_fn, name_to_skip="vision_tower"))

         # Match RoPE layout expected by our backend
         match_rope_layout(
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
index 48f06c70e60..9586de5dae5 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
@@ -422,3 +422,12 @@ def _get(name):
         raise RuntimeError(f"Could not find a value for '{name}' on op {op}")

     return [_get(n) for n in arg_names]
+
+
+def is_node_in_module(node: Node, module_name: str) -> bool:
+    """Check whether `node` was traced from a (sub)module whose name contains `module_name`."""
+    try:
+        innermost_module = list(node.meta["nn_module_stack"].keys())[-1]
+        return module_name in innermost_module
+    except (KeyError, IndexError):
+        return False
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
index 00b535dec61..35a30d11ebb 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
@@ -32,6 +32,8 @@
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm

+Match = torch._inductor.pattern_matcher.Match
+

 @contextlib.contextmanager
 def _patch_unsupported_input_tensor():
@@ -83,6 +85,7 @@ def _trace_to_gm(fn: Callable, args: Sequence[torch.Tensor]) -> GraphModule:
     Exports a function or Module into a GraphModule via torch_export_to_gm.
""" module = fn if isinstance(fn, torch.nn.Module) else _WrapperModule(fn) + return torch_export_to_gm(module, tuple(args)) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index aa793d30ea6..31882b46491 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -80,7 +80,7 @@ def __init__( self) # TODO: make it weakref self._executor_config = executor_config self._is_pytorch_backend = getattr(self._executor_config, "backend", - None) == "pytorch" + None) in ["pytorch", "_autodeploy"] if global_mpi_size() > 1: logger.set_rank(self.global_rank) diff --git a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py b/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py index 596b7ff50dc..aa290491c41 100644 --- a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py +++ b/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py @@ -190,7 +190,10 @@ def test_build_run_llama4_vlm(): "content": [ {"type": "image", "image": img1}, {"type": "image", "image": img2}, - {"type": "text", "text": "What's the difference?"}, + { + "type": "text", + "text": "Describe what you see in the two images and their differences.", + }, ], }, ] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py index e9d7acd7dc3..472bc71f1eb 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py @@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i input_ids = [torch.tensor([0, 1, 2], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) assert logits is not None, "Logits are None" @@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int): input_ids = [torch.tensor([1, 2, 3, 4], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index 876eba196cc..4f1015a1268 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -38,17 +38,19 @@ def __init__( self.num_key_value_groups = None @torch.no_grad() - def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Forward pass with input tokens and optional position ids. 
position_ids parameter added to match expected interface in kvcache.py """ - b, s, _ = x.shape + b, s, _ = input_ids.shape # Project input to q, k, v representations - q = self.q_proj(x) # [b, s, n*h_d] - k = self.k_proj(x) # [b, s, n_kv*h_d] - v = self.v_proj(x) # [b, s, n_kv*h_d] + q = self.q_proj(input_ids) # [b, s, n*h_d] + k = self.k_proj(input_ids) # [b, s, n_kv*h_d] + v = self.v_proj(input_ids) # [b, s, n_kv*h_d] # Reshape to [b, s, n, h_d] q = q.view(b, s, self.num_heads, self.head_dim) @@ -185,9 +187,9 @@ def test_sdpa_with_kv_cache(dtype, attn_descriptor, gqa_config): cm.initialize_caches() # Helper function to call the model with proper sequence nesting - def _call_and_unnest(x): + def _call_and_unnest(x, input_pos): # Use nest_sequences to properly set input_ids and automatically update position_ids - cm.info.nest_sequences(x) + cm.info.nest_sequences(x, input_pos=input_pos) # Use the cm.args as is - it already contains the correct position_ids y = gm(*cm.args) @@ -197,31 +199,25 @@ def _call_and_unnest(x): # Test 1: Regular inference (all tokens at once) cm.info.reset() - y_no_cache = _call_and_unnest(x) + y_no_cache = _call_and_unnest(x, 0) assert all_close(y_model, y_no_cache, atol=atol, rtol=rtol) # Test 2: Autoregressive inference with KV cache cm.info.reset() y_with_cache = torch.empty_like(y_model) - for i in range(x.shape[1]): + for i_p in range(x.shape[1]): # Just pass the current token - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - # Update position for next token - cm.info.update_pos(1) # This automatically updates position_ids too + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 3: Cache continuation after random tokens - cm.info.update_pos(-num_reset_steps) # Rewind position - for i in range(num_random_steps): - _call_and_unnest(torch.rand_like(x[:, :1])) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1] - num_reset_steps + num_random_steps): + _call_and_unnest(torch.rand_like(x[:, :1]), i_p) # Continue inference from previous context cm.info.reset() - cm.info.update_pos(x.shape[1] - num_reset_steps) - for i in range(x.shape[1] - num_reset_steps, x.shape[1]): - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1]): + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 4: Exportability of the transformed model