@@ -162,7 +162,7 @@ def forward(self, *args, **kwargs) -> Any:

# copy inputs to input buffers
for i, input_tensor in enumerate(args_batched):
-            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+            self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)

# run forward pass via graph
self.graphs[combined_shape].replay()
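The changed line above stages inputs into a preallocated device buffer with an asynchronous `copy_` instead of a blocking slice assignment, which is the standard pattern for CUDA-graph replay: captured work must keep reading the same storage. A minimal, self-contained sketch of that pattern (all names here, `buf`, `out`, `graph`, `host_in`, are illustrative, not the runner's actual code):

```python
import torch

if torch.cuda.is_available():
    buf = torch.zeros(8, device="cuda")  # static input buffer
    out = torch.zeros(8, device="cuda")  # static output buffer

    # warm-up on a side stream before capture, as required by CUDA graphs
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(buf * 2)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out.copy_(buf * 2)  # captured work reads buf in place

    # replay with new inputs: write into the SAME buffer, never reallocate
    host_in = torch.arange(8, dtype=torch.float32).pin_memory()
    buf[: host_in.shape[0]].copy_(host_in, non_blocking=True)  # async H2D copy
    graph.replay()
    torch.cuda.synchronize()
    print(out)  # tensor([ 0., 2., 4., ...], device='cuda:0')
```

The key invariant is that new inputs are copied into the captured buffer rather than assigned to a fresh tensor, so the graph's recorded pointers stay valid.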
184 changes: 129 additions & 55 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -10,14 +10,15 @@
"""

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union

import torch
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node

+from tensorrt_llm._utils import nvtx_range

@dataclass
class CacheConfig:
@@ -87,11 +88,13 @@ class SequenceInfo:
    # Similarly, if a batch is composed of generate-only requests,
    # then the maximum number of sequences possible in the batch is min(max_batch_size, max_num_tokens).
    max_num_tokens: Optional[int] = None
+   # device is the device on which the sequence info is stored.
+   device: str = "cuda"

## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
# input_ids MUST ALWAYS BE THE FIRST FIELD
-   input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
-   position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
+   input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
+   position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))

seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
@@ -104,30 +107,43 @@ class SequenceInfo:
_num_pages: int = 1

    def __post_init__(self):
+       print("in __post_init__ device: ", self.device)
        if self.page_size < 1:
            self.page_size = self.max_seq_len

        # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
        # (max_batch_size, max_seq_len) input in trtllm runtime.
        # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
-       max_seq_len_adjusted = self.max_seq_len + 1
+       self.max_seq_len_adjusted = self.max_seq_len + 1

        if self.max_num_tokens is None or self.max_num_tokens < 1:
-           self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+           self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
        # if the provided max_num_tokens is less than max_batch_size * max_seq_len,
        # we use the provided max_num_tokens to calculate the number of pages
-       total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
+       total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
        # num_pages cannot be less than max_batch_size
        self._num_pages = max(
            self.max_batch_size,
            total_tokens // self.page_size + (total_tokens % self.page_size > 0),
        )
-       self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
-       self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-       self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-       self.input_pos = torch.empty_like(self.seq_len)
-       self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
-       self.pages_per_seq = torch.empty_like(self.seq_len)
+       # Ensure that the device is set before initializing the tensors.
+       # input_ids and position_ids need to be allocated on the GPU to avoid the overhead
+       # of tensor creation in every forward pass.
+       self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+       self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+       # Consumers of the sequence info args require input_ids and position_ids to be
+       # truncated. We maintain a full version of both to avoid the overhead of tensor
+       # creation in every forward pass.
+       self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+       self.position_ids_full = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+       self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
+       self.input_pos = torch.empty_like(self.seq_len, device=self.device)
+       self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
+       self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
+
+       self.previous_batch_indices_cuda = torch.empty(
+           self.max_num_tokens, dtype=torch.long, device=self.device
+       )
        assert self.num_pages >= self.max_batch_size, (
            "num_pages must be greater than or equal to max_batch_size"
        )
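As a quick sanity check on the arithmetic above: `_num_pages` is a ceiling division of the token budget by the page size, clamped below by `max_batch_size` so every sequence can own at least one page. A hypothetical standalone rendering of the same formula:

```python
def num_pages(max_batch_size: int, max_num_tokens: int, page_size: int) -> int:
    # ceil(max_num_tokens / page_size), but never fewer pages than sequences
    return max(
        max_batch_size,
        max_num_tokens // page_size + (max_num_tokens % page_size > 0),
    )

assert num_pages(max_batch_size=4, max_num_tokens=1000, page_size=64) == 16  # ceil(1000/64)
assert num_pages(max_batch_size=4, max_num_tokens=100, page_size=64) == 4    # clamped to batch size
```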
@@ -140,27 +156,34 @@ def __post_init__(self):
# indicator if extra args are activated that are needed for cached attention backends
self._is_cached_attn = False

+       # total number of tokens in the current batch
+       self.num_tokens: int = 0

# call reset once to initialize the tensors
self.reset()

-   @property
-   def device(self) -> torch.device:
-       return self.input_pos.device
-
    @property
    def args(self) -> Tuple[torch.Tensor, ...]:
-       args = []
-       for f in fields(self):
-           val = getattr(self, f.name)
-           if isinstance(val, torch.Tensor):
-               args.append(val)
-           if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
-               break
-       return tuple(args)
+       @nvtx_range("attention_interface_args")
+       def get_args():
+           args = []
+           for f in fields(self):
+               val = getattr(self, f.name)
+               if not isinstance(val, torch.Tensor):
+                   continue
+               args.append(val)
+               if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
+                   break
+           return tuple(args)
+       return get_args()

@property
def _num_uncached_attn_args(self) -> int:
"""Return the number of original graph arguments expected by the model."""
"""Return the number of original graph arguments expected by the model.
This is 2 because we have input_ids and position_ids as the original graph arguments.
"""
return 2

@property
Expand All @@ -185,7 +208,7 @@ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
dynamic_shapes = ({}, {})
if self.max_batch_size > 1:
dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
-       dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
+       dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
# set up shape for position_ids (same as input_ids)
dynamic_shapes[1].update(dynamic_shapes[0])
# set up shape for extra args
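For context, the `Dim` objects built here feed the dynamic-shape machinery of `torch.export`; bounding `seq_len` by `max_seq_len_adjusted` keeps the exported graph consistent with the +1 WAR applied in `__post_init__`. An illustrative export under assumed names (the module `Toy` and the sizes are made up):

```python
import torch
from torch.export import Dim, export

MAX_BS, MAX_LEN = 8, 128

class Toy(torch.nn.Module):
    def forward(self, input_ids, position_ids):
        return input_ids + position_ids

# reuse the same Dim objects so both inputs share the symbolic dimensions
bs = Dim("batch_size", min=1, max=MAX_BS)
sl = Dim("seq_len", min=1, max=MAX_LEN)
dynamic_shapes = ({0: bs, 1: sl}, {0: bs, 1: sl})

ep = export(
    Toy(),
    (torch.zeros(2, 16, dtype=torch.long), torch.zeros(2, 16, dtype=torch.long)),
    dynamic_shapes=dynamic_shapes,
)
```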
@@ -336,11 +359,15 @@ def reset(self) -> None:
self.input_pos.zero_()

# set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-       self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+       self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)

# reset cache information
self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
self.pages_per_seq.fill_(1)

+       # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
+       self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+       self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)

def set_example_sequence(self) -> None:
"""Set an example sequence useful for testing and export purposes."""
@@ -352,7 +379,7 @@ def set_example_sequence(self) -> None:
dtype=torch.int,
device=self.device,
)
-       self.nest_sequences(input_ids)
+       self.nest_sequences(input_ids, allow_realloc=True)

# unflatten if we are not yet using cached+flattened attention
if not self._is_cached_attn:
@@ -370,7 +397,7 @@ def _set_max_num_tokens_sample(self) -> None:
device=self.device,
)
self.pages_per_seq.fill_(seq_len // self.page_size)
-       self.nest_sequences(input_ids)
+       self.nest_sequences(input_ids, allow_realloc=True)

def set_generate_only_batch(self) -> None:
"""Set an example sequence for generate-only batch.
@@ -379,32 +406,78 @@ def set_generate_only_batch(self) -> None:
mode. So we don't need to do anything mode-specific here.
"""
self.reset()
-       self.nest_sequences([[1]] * self.max_batch_size)
+       self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)

-   def _update_position_ids(self) -> None:
-       # set new position_ids as new tensor from input_pos and seq_len via torch.arange
+   def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
+       # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
+       if self.is_generate:
+           return tensor.view(-1, 1, *tensor.shape[1:])
+       else:
+           return tensor.view(1, -1, *tensor.shape[1:])
+
+   @nvtx_range("ad_update_position_ids")
+   def _update_position_ids(self, allow_realloc: bool = False) -> None:
+       # set new position_ids from input_pos and seq_len
        position_ids_list = [
            num
            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
            for num in range(in_pos, in_pos + seq_len)
        ]
-       self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
-
-       # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
-       if self.is_generate:
-           self.position_ids = self.position_ids.view(-1, 1)
+       position_ids_host = torch.tensor(position_ids_list, dtype=torch.long, pin_memory=True)
+       if allow_realloc:
+           # create a new position_ids tensor on the device
+           self.position_ids = position_ids_host.to(self.device).clone()
        else:
-           self.position_ids = self.position_ids.view(1, -1)
+           self.position_ids_full = self.position_ids_full.flatten()
+           self.position_ids_full[: len(position_ids_list)].copy_(position_ids_host, non_blocking=True)
+
+       self.position_ids = self.maybe_reshape_for_generate(
+           self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
+       )
+
+   @nvtx_range("ad_update_sequence_lengths")
+   def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
+       self._sequence_lengths = sequence_lengths
+       self.num_tokens = sum(self._sequence_lengths)
+       self.seq_len.zero_()
+       self.seq_len[: len(self._sequence_lengths)].copy_(
+           torch.tensor(self._sequence_lengths), non_blocking=True
+       )
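The recurring pattern in these rewritten methods is a pinned host staging tensor plus `copy_(..., non_blocking=True)`: only page-locked host memory lets the H2D copy run asynchronously with respect to the CPU. A minimal sketch with hypothetical buffer names:

```python
import torch

if torch.cuda.is_available():
    device_buf = torch.empty(4, dtype=torch.int, device="cuda")
    # pinned (page-locked) host tensor enables a truly async cudaMemcpyAsync
    host_vals = torch.tensor([3, 1, 2, 4], dtype=torch.int, pin_memory=True)
    device_buf.copy_(host_vals, non_blocking=True)  # returns before copy finishes
    torch.cuda.current_stream().synchronize()       # required before reusing host_vals
```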

+   def update_input_ids_with_new_tokens(
+       self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
+   ) -> None:
+       """Update the input_ids with new tokens.
+
+       This function updates the input_ids with the new tokens at the given
+       previous batch indices.
+       """
+       # 1) flatten once
+       original_shape = self.input_ids.shape
+       flat = self.input_ids.flatten()
+
+       # 2) copy indices to the GPU
+       host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
+       idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
+       idx.copy_(host_idx, non_blocking=True)
+
+       # sort them so that masked_scatter_ lines up correctly
+       idx, _ = idx.sort()
+
+       # 3) gather the exact values to write
+       src = new_tokens[0, idx, 0]
+
+       # in-place fill every slot where flat == -1 with src, in order
+       flat.masked_scatter_(flat == -1, src)
+
+       # 4) reshape back
+       self.input_ids = flat.view(original_shape)
+
+   @nvtx_range("ad_nest_sequences")
-   def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+   def nest_sequences(self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False) -> None:
        """Create and store a flattened list of input_ids from the provided list of sequences.

+       When allow_realloc is True, the input_ids will be reallocated on the device.
        This interface will also update any relevant sequence information.
        """
        # set new sequence lengths
-       seq_lens = [len(ids) for ids in input_ids]
-       self.seq_len.zero_()
-       self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
+       self._update_sequence_lengths([len(ids) for ids in input_ids])

        # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
        dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
        # set new input_ids as new tensor from flattened input_ids
@@ -413,49 +486,50 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
            for lst in input_ids
            for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
        ]
-       self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
-
-       # set derivative properties
-       self._sequence_lengths = seq_lens
+       input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)

-       # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
-       if self.is_generate:
-           self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
+       if allow_realloc:
+           self.input_ids = input_ids_host.to(self.device).clone()
        else:
-           self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
+           self.input_ids_full = self.input_ids_full.flatten()
+           self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)

+       self.input_ids = self.maybe_reshape_for_generate(
+           self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
+       )
        # update position_ids
-       self._update_position_ids()
+       self._update_position_ids(allow_realloc=allow_realloc)
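The subtle step in `update_input_ids_with_new_tokens` above is `masked_scatter_`, which fills the `True` slots of the mask with consecutive source values in flattened order, which is why the indices are sorted first. A toy CPU demonstration with made-up values:

```python
import torch

# slots pre-filled with -1 act as placeholders for the incoming tokens
flat = torch.tensor([7, -1, 9, -1, -1])
src = torch.tensor([100, 200, 300])

# each True position in the mask consumes the next element of src, in order
flat.masked_scatter_(flat == -1, src)
print(flat)  # tensor([  7, 100,   9, 200, 300])
```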

@nvtx_range("ad_unnest_sequences")
def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
return list(torch.split(t_squeezed, self.sequence_lengths))

@nvtx_range("ad_update_pos")
def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
"""Update the starting position for each sequence in the cache.

        If ``reset=True``, ``input_pos`` will be reset to zero before updating.
"""
if not isinstance(seq_len, torch.Tensor):
-           seq_len = torch.tensor(seq_len, dtype=torch.int)
+           seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size

if reset:
-           self.input_pos[:bs] = seq_len.to(self.device)
+           self.input_pos[:bs].copy_(seq_len, non_blocking=True)
else:
self.input_pos[:bs] += seq_len.to(self.device)

# update position_ids
self._update_position_ids()

@nvtx_range("ad_assign_cache_loc")
def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
"""Set the cache location and pages_per_seq tensors from page assignments."""
cache_loc_flat = torch.tensor(
-           [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+           [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int, pin_memory=True
)
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)

-       pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
+       pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int, pin_memory=True)
self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
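A small worked example of the flattening `assign_cache_loc` performs (page numbers are hypothetical): two sequences owning pages `[0, 3]` and `[5]` produce a flat location list and a per-sequence page count.

```python
import torch

page_assignments = [[0, 3], [5]]
cache_loc_flat = torch.tensor(
    [p for pages in page_assignments for p in pages], dtype=torch.int
)  # tensor([0, 3, 5], dtype=torch.int32)
pages_per_seq = torch.tensor(
    [len(p) for p in page_assignments], dtype=torch.int
)  # tensor([2, 1], dtype=torch.int32)
```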

