diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
index 0b309ae2bf8..c2081e00df8 100644
--- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
+++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
@@ -162,7 +162,7 @@ def forward(self, *args, **kwargs) -> Any:
 
         # copy inputs to input buffers
         for i, input_tensor in enumerate(args_batched):
-            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+            self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)
 
         # run forward pass via graph
         self.graphs[combined_shape].replay()
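The hunk above refreshes the persistent CUDA-graph input buffers with an in-place, non-blocking copy instead of slice assignment. As a standalone sketch (not part of the patch; the toy model and sizes are illustrative), this is the capture/replay pattern it relies on: a replayed graph only reads the storages captured during recording, so fresh inputs must be copied into those buffers before `replay()`:

```python
# Minimal CUDA-graph capture/replay sketch with a persistent input buffer.
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, 16, device="cuda")
    static_out = torch.zeros(8, 16, device="cuda")

    def step(x):
        return torch.relu(x) * 2.0

    # warm up on a side stream, then capture one iteration into a graph
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(step(static_in))
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(step(static_in))

    # per step: refresh the captured buffer in place, then replay
    new_batch = torch.randn(4, 16, device="cuda")
    static_in[: new_batch.shape[0]].copy_(new_batch, non_blocking=True)
    graph.replay()
    print(static_out[:4].sum().item())
```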
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 13c91652bff..9387dee31cc 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -10,6 +10,7 @@
 """
 
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
 
@@ -17,7 +18,7 @@
 from torch._ops import OpOverloadPacket
 from torch.export import Dim
 from torch.fx import Node
-
+from tensorrt_llm._utils import nvtx_range
 
 @dataclass
 class CacheConfig:
@@ -87,11 +88,13 @@ class SequenceInfo:
     # Similarly, if a batch is composed of generate-only requests,
     # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
     max_num_tokens: Optional[int] = None
+    # device is the device on which the sequence info is stored.
+    device: str = "cuda"
 
     ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
     # input_ids MUST ALWAYS BE THE FIRST FIELD
-    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
-    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
+    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
 
     seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
     input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
@@ -104,30 +107,43 @@ class SequenceInfo:
     _num_pages: int = 1
 
     def __post_init__(self):
+        print("in __post_init__ device: ", self.device)
         if self.page_size < 1:
             self.page_size = self.max_seq_len
 
         # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
         # (max_batch_size, max_seq_len) input in trtllm runtime.
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
-        max_seq_len_adjusted = self.max_seq_len + 1
+        self.max_seq_len_adjusted = self.max_seq_len + 1
 
         if self.max_num_tokens is None or self.max_num_tokens < 1:
-            self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+            self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
-        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
+        total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
         # Num pages can not be less than max_batch_size.
         self._num_pages = max(
             self.max_batch_size,
             (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
         )
-        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
-        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-        self.input_pos = torch.empty_like(self.seq_len)
-        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
-        self.pages_per_seq = torch.empty_like(self.seq_len)
+        # Ensure that the device is set before initializing the tensors.
+        # Allocate input_ids and position_ids on the GPU to avoid the overhead of tensor creation in every forward pass.
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+        # Consumers of the sequence info args require input_ids and position_ids to be truncated.
+        # We maintain a full version of input_ids and position_ids to avoid the overhead of tensor creation in every forward pass.
+        self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids_full = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
+        self.input_pos = torch.empty_like(self.seq_len, device=self.device)
+        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
+        self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
+
+        self.previous_batch_indices_cuda = torch.empty(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
         assert self.num_pages >= self.max_batch_size, (
             "num_pages must be greater than max_batch_size"
         )
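The allocations above keep the per-step tensors resident on the GPU so each forward pass only stages small host lists. A standalone sketch of the staging pattern this relies on (buffer sizes and names are illustrative, not taken from the patch): build the host tensor in pinned memory, then issue a non-blocking host-to-device copy into a preallocated device buffer.

```python
# Pinned-host -> preallocated-device staging, sketched independently of the patch.
import torch

MAX_NUM_TOKENS = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# allocated once, reused every step (mirrors the persistent buffers above)
input_ids_buf = torch.empty(MAX_NUM_TOKENS, dtype=torch.int, device=device)

def stage_tokens(token_lists):
    """Flatten Python lists into the persistent device buffer without a sync."""
    flat = [tok for seq in token_lists for tok in seq]
    # pinned host staging tensor; non_blocking=True only overlaps with other
    # work when the source is in pinned (page-locked) memory
    host = torch.tensor(flat, dtype=torch.int, pin_memory=(device == "cuda"))
    input_ids_buf[: len(flat)].copy_(host, non_blocking=True)
    return input_ids_buf[: len(flat)]

tokens = stage_tokens([[1, 2, 3], [4, 5]])
```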
+ """ return 2 @property @@ -185,7 +208,7 @@ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]: dynamic_shapes = ({}, {}) if self.max_batch_size > 1: dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len) + dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted) # set up shape for position_ids (same as input_ids) dynamic_shapes[1].update(dynamic_shapes[0]) # set up shape for extra args @@ -336,11 +359,15 @@ def reset(self) -> None: self.input_pos.zero_() # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids) - self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int)) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) # reset cache information self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) self.pages_per_seq.fill_(1) + + # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens) + self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) def set_example_sequence(self) -> None: """Set an example sequence useful for testing and export purposes.""" @@ -352,7 +379,7 @@ def set_example_sequence(self) -> None: dtype=torch.int, device=self.device, ) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) # unflatten if we are not yet using cached+flattened attention if not self._is_cached_attn: @@ -370,7 +397,7 @@ def _set_max_num_tokens_sample(self) -> None: device=self.device, ) self.pages_per_seq.fill_(seq_len // self.page_size) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) def set_generate_only_batch(self) -> None: """Set an example sequence for generate-only batch. @@ -379,32 +406,78 @@ def set_generate_only_batch(self) -> None: mode. So we don't need to do anything mode-specific here. 
""" self.reset() - self.nest_sequences([[1]] * self.max_batch_size) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) - def _update_position_ids(self) -> None: - # set new position_ids as new tensor from input_pos and seq_len via torch.arange + def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor: + # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] + if self.is_generate: + return tensor.view(-1, 1, *tensor.shape[1:]) + else: + return tensor.view(1, -1, *tensor.shape[1:]) + + @nvtx_range("ad_update_position_ids") + def _update_position_ids(self, allow_realloc: bool = False) -> None: + # set new position_ids from input_pos and seq_len position_ids_list = [ num for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths) for num in range(in_pos, in_pos + seq_len) ] - self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device) - - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.position_ids = self.position_ids.view(-1, 1) + position_ids_host = torch.tensor(position_ids_list, dtype=torch.long, pin_memory=True) + if allow_realloc: + # Create a new position_ids tensor on the device + self.position_ids = position_ids_host.to(self.device).clone() else: - self.position_ids = self.position_ids.view(1, -1) + self.position_ids_full = self.position_ids_full.flatten() + self.position_ids_full[:len(position_ids_list)].copy_(position_ids_host, non_blocking=True) + + self.position_ids = self.maybe_reshape_for_generate(self.position_ids if allow_realloc else self.position_ids_full[:self.num_tokens]) + + @nvtx_range("ad_update_sequence_lengths") + def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None: + self._sequence_lengths = sequence_lengths + self.num_tokens = sum(self._sequence_lengths) + self.seq_len.zero_() + self.seq_len[: len(self._sequence_lengths)].copy_(torch.tensor(self._sequence_lengths), non_blocking=True) + + def update_input_ids_with_new_tokens(self, + new_tokens: torch.Tensor, + previous_batch_indices: List[int]) -> None: + """Update the input_ids with new tokens. + + This function will update the input_ids with new tokens and previous batch indices. + """ + # 1) flatten once + original_shape = self.input_ids.shape + flat = self.input_ids.flatten() - def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: - """Create and store a flattened list of input_ids from the provided list of sequences. + # copy indices to the GPU + host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True) + idx = self.previous_batch_indices_cuda[:len(previous_batch_indices)] + idx.copy_(host_idx, non_blocking=True) + + # sort them so that masked_scatter_ lines up correctly + idx, _ = idx.sort() + # gather the exact values you want to write + src = new_tokens[0, idx, 0] + + # in‐place fill every slot where flat == -1 with src, in order + flat.masked_scatter_(flat == -1, src) + + # 4) reshape back + self.input_ids = flat.view(original_shape) + + @nvtx_range("ad_nest_sequences") + def nest_sequences(self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False) -> None: + """Create and store a flattened list of input_ids from the provided list of sequences. + + When allow_realloc is True, the input_ids will be reallocated on the device. This i/f will also update any relevant sequence information. 
""" # set new sequence lengths - seq_lens = [len(ids) for ids in input_ids] - self.seq_len.zero_() - self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True) + self._update_sequence_lengths([len(ids) for ids in input_ids]) + # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int # set new input_ids as new tensor from flattened input_ids @@ -413,49 +486,50 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: for lst in input_ids for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) ] - self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) - - # set derivative properties - self._sequence_lengths = seq_lens + input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True) - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:]) + if allow_realloc: + self.input_ids = input_ids_host.to(self.device).clone() else: - self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:]) - + self.input_ids_full = self.input_ids_full.flatten() + self.input_ids_full[:self.num_tokens].copy_(input_ids_host, non_blocking=True) + + self.input_ids = self.maybe_reshape_for_generate(self.input_ids if allow_realloc else self.input_ids_full[:self.num_tokens]) # update position_ids - self._update_position_ids() + self._update_position_ids(allow_realloc=allow_realloc) + @nvtx_range("ad_unnest_sequences") def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) + @nvtx_range("ad_update_pos") def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: """Update the starting position for each sequence in the cache. If ``reset=True`, ``input_pos`` will be reset to zero before updating. 
""" if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int) + seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True) bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size if reset: - self.input_pos[:bs] = seq_len.to(self.device) + self.input_pos[:bs].copy_(seq_len, non_blocking=True) else: self.input_pos[:bs] += seq_len.to(self.device) # update position_ids self._update_position_ids() + @nvtx_range("ad_assign_cache_loc") def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: """Set the cache location and pages_per_seq tensors from page assignments.""" cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int + [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int, pin_memory=True ) self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) + pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int, pin_memory=True) self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7f759d6796d..1a60c8ad4fe 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -94,19 +94,22 @@ def build_from_config(cls, ad_config: AutoDeployConfig): f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}" ) + # update device to contain the current default device if it's in cuda + device = torch.device(ad_config.device) + if device.type == "cuda" and device.index is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = str(device) + # initialize seq info object seq_info = SequenceInfo( max_seq_len=max_seq_len, max_batch_size=max_batch_size, page_size=attn_page_size, max_num_tokens=max_num_tokens, + device=device, ) + print(" in seq_info for device: ", torch.cuda.current_device()) - # update device to contain the current default device if it's in cuda - device = torch.device(ad_config.device) - if device.type == "cuda" and device.index is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - device = str(device) # construct inference optimizer build_and_optimize = InferenceOptimizer( @@ -167,16 +170,12 @@ def _prepare_inputs( context_requests = scheduled_requests.context_requests gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens] - # new_tokens is a tensor on the device, we need to convert it to a list of lists. - # can we avoid this additional gpu->cpu transfer? - new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None - # info to be extracted input_ids: List[List[int]] = [] input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] - + previous_batch_indices: List[int] = [] # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence @@ -190,11 +189,13 @@ def _prepare_inputs( # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: # new_tokens are provided when the overlap scheduler is enabled. 
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
index bbe72650b4e..fec32a6706e 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
@@ -52,7 +52,7 @@ def _apply(
         # retrieve the actual model from the dummy graph module
         model = gm.get_submodule("factory_model")
 
-        # set the example sequence
+        # set an example sequence context
         cm.info.set_example_sequence()
 
         # export the model to a graph module
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 181f2b0bdc0..b1e7b87f471 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -46,7 +46,7 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
-    attn_backend: str = 'TRTLLM'
+    attn_backend: str = 'FLASHINFER'
     moe_backend: str = 'CUTLASS'
 
     enable_mixed_sampler: bool = False
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e5b302310fc..5ae759da34e 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1683,7 +1683,7 @@ def _forward_step(self,
                       new_tensors_device: Optional[SampleStateTensors] = None):
 
         @nvtx_range(
-            f"[Executor] _forward_step {self.model_engine.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
f"[Executor PP] _forward_step {self.model_engine.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer):