
Commit 3c5aec1

galagam and suyoggupta authored
[#5048][enhance] AutoDeploy: Optimize prepare_inputs (#6634)
Optimize the prepare_inputs routine in AutoDeploy, as part of the effort to reduce the performance gap compared to the default backend. This PR includes two major fixes, plus some other minor tweaks:

1. Avoid back-and-forth data copies between host and device.
2. Optimize the position_ids update by separating the implementation for generation mode and context mode.

Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Gal Hubara Agam <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent ee19ca5 commit 3c5aec1
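
The recurring pattern behind fix (1), visible throughout the diff below, is to stage values in pinned (page-locked) host tensors and issue asynchronous host-to-device copies into pre-allocated device buffers, rather than building a device tensor from Python lists on every step. A minimal sketch of that pattern; the names `device_buf` and `stage_and_copy` are illustrative, not from this PR:

    import torch

    device = "cuda"
    max_num_tokens = 8192

    # allocate once: a persistent device buffer big enough for any step
    device_buf = torch.zeros(max_num_tokens, dtype=torch.long, device=device)

    def stage_and_copy(values):
        # pinned host memory lets the H2D memcpy run asynchronously
        host = torch.tensor(values, dtype=torch.long, pin_memory=True)
        # copy into a slice of the persistent buffer: no device allocation and
        # no host sync; same-stream kernels are ordered after the copy
        device_buf[: len(values)].copy_(host, non_blocking=True)
        return device_buf[: len(values)]

    tokens = stage_and_copy([3, 1, 4, 1, 5])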

File tree: 4 files changed, +175 −69 lines changed

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def forward(self, *args, **kwargs) -> Any:
 
         # copy inputs to input buffers
         for i, input_tensor in enumerate(args_batched):
-            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+            self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True)
 
         # run forward pass via graph
         self.graphs[combined_shape].replay()
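
Context for this one-line change: a captured CUDA graph always replays the same memory addresses, so fresh inputs must be copied into the pre-allocated input buffers in place, and `copy_(..., non_blocking=True)` avoids an extra synchronization when the source is pinned or already on-device. A simplified, hedged sketch of the capture/replay cycle (the real backend keys multiple graphs by batched shape; names here are illustrative):

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.zeros(8, 16, device="cuda")  # address baked into the graph

    # warm up on a side stream so lazy initialization happens outside capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    def run(x):
        # replay consumes whatever static_in currently holds,
        # so the new batch is copied in-place before replaying
        static_in[: x.shape[0]].copy_(x, non_blocking=True)
        g.replay()
        return static_out[: x.shape[0]]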

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 154 additions & 51 deletions
@@ -18,6 +18,8 @@
 from torch.export import Dim
 from torch.fx import Node
 
+from tensorrt_llm._utils import nvtx_range
+
 
 @dataclass
 class CacheConfig:
@@ -87,11 +89,13 @@ class SequenceInfo:
     # Similarly, if a batch is composed of generate-only requests,
     # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
     max_num_tokens: Optional[int] = None
+    # device is the device on which the sequence info is stored.
+    device: str = "cuda"
 
     ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
     # input_ids MUST ALWAYS BE THE FIRST FIELD
-    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
-    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
+    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
 
     seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
     input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
@@ -110,24 +114,44 @@ def __post_init__(self):
         # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
         # (max_batch_size, max_seq_len) input in trtllm runtime.
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
-        max_seq_len_adjusted = self.max_seq_len + 1
+        self.max_seq_len_adjusted = self.max_seq_len + 1
 
         if self.max_num_tokens is None or self.max_num_tokens < 1:
-            self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+            self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
-        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
+        total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
         # Num pages can not be less than max_batch_size.
         self._num_pages = max(
             self.max_batch_size,
             (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
         )
-        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
-        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-        self.input_pos = torch.empty_like(self.seq_len)
-        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
-        self.pages_per_seq = torch.empty_like(self.seq_len)
+        # Ensure that the device is set before initializing the tensors.
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+        # Consumers of the sequence info args require input_ids and position_ids to be truncated.
+        # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor
+        # creation in every forward pass.
+        self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids_full = torch.zeros(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
+
+        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
+        self.input_pos = torch.empty_like(self.seq_len, device=self.device)
+
+        # Allocated host tensors for sequence lengths and input positions so that
+        # position_ids calculation can be done on host.
+        self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
+        self.input_pos_host = torch.empty_like(self.seq_len_host)
+
+        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
+        self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
+
+        self.previous_batch_indices_cuda = torch.empty(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
         assert self.num_pages >= self.max_batch_size, (
             "num_pages must be greater than max_batch_size"
         )
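
The rationale for the `_full` buffers added above: allocating `input_ids`/`position_ids` once at `max_num_tokens` and handing out truncated views each step trades a bounded amount of memory for zero per-iteration device allocation and a stable destination for async copies. A small before/after sketch under assumed shapes, not the actual class:

    import torch

    device = "cuda"
    max_num_tokens = 8192
    ids_full = torch.ones(max_num_tokens, dtype=torch.int, device=device)

    def per_step_old(token_ids):
        # old pattern: fresh tensor + blocking H2D transfer every forward pass
        return torch.tensor(token_ids, dtype=torch.int).to(device)

    def per_step_new(token_ids):
        # new pattern: pinned staging + async copy into the persistent buffer,
        # then return a truncated view of it
        host = torch.tensor(token_ids, dtype=torch.int, pin_memory=True)
        ids_full[: len(token_ids)].copy_(host, non_blocking=True)
        return ids_full[: len(token_ids)]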
@@ -140,13 +164,12 @@ def __post_init__(self):
         # indicator if extra args are activated that are needed for cached attention backends
         self._is_cached_attn = False
 
+        # total number of tokens in the current batch
+        self.num_tokens: int = 0
+
         # call reset once to initialize the tensors
         self.reset()
 
-    @property
-    def device(self) -> torch.device:
-        return self.input_pos.device
-
     @property
     def args(self) -> Tuple[torch.Tensor, ...]:
         args = []
@@ -156,11 +179,14 @@ def args(self) -> Tuple[torch.Tensor, ...]:
                 args.append(val)
             if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
                 break
+
         return tuple(args)
 
     @property
     def _num_uncached_attn_args(self) -> int:
-        """Return the number of original graph arguments expected by the model."""
+        """Return the number of original graph arguments expected by the model.
+        This is 2 because we have input_ids and position_ids as the original graph arguments.
+        """
         return 2
 
     @property
@@ -185,7 +211,7 @@ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
         dynamic_shapes = ({}, {})
         if self.max_batch_size > 1:
             dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
-        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
+        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
         # set up shape for position_ids (same as input_ids)
         dynamic_shapes[1].update(dynamic_shapes[0])
         # set up shape for extra args
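
Background on the `Dim` bound change: `torch.export` declares dynamic dimensions with explicit upper bounds, and the bound must cover the largest sequence length the runtime can actually feed in, hence `max_seq_len_adjusted` (the +1 WAR above). A minimal, hedged example of the API; the `Toy` module is invented for illustration:

    import torch
    from torch.export import Dim, export

    class Toy(torch.nn.Module):
        def forward(self, input_ids):
            return input_ids * 2

    # dim 1 is dynamic with an explicit upper bound; exporting with a bound
    # smaller than what the runtime later supplies would fail at run time
    seq = Dim("seq_len", max=1025)  # e.g. max_seq_len + 1
    ep = export(Toy(), (torch.zeros(1, 8, dtype=torch.int),), dynamic_shapes=({1: seq},))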
@@ -204,7 +230,7 @@ def sequence_lengths(self) -> List[int]:
 
     @property
     def input_positions(self) -> List[int]:
-        return self.input_pos[: self.num_sequences].tolist()
+        return self.input_pos_host[: self.num_sequences].tolist()
 
     @property
     def is_generate(self) -> bool:
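
Why `input_positions` now reads the host mirror: `.tolist()` on a CUDA tensor triggers a device-to-host transfer and an implicit synchronization on every call, whereas on a host tensor it is pure CPU work. A tiny illustration with assumed tensors:

    import torch

    input_pos_cuda = torch.arange(4, dtype=torch.int, device="cuda")
    input_pos_host = torch.arange(4, dtype=torch.int)  # mirror kept in sync

    positions = input_pos_cuda.tolist()  # implicit D2H copy + sync every call
    positions = input_pos_host.tolist()  # no device round-trip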
@@ -334,14 +360,19 @@ def reset(self) -> None:
         """
         # reset input_pos
         self.input_pos.zero_()
+        self.input_pos_host.zero_()
 
         # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-        self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
 
         # reset cache information
         self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
         self.pages_per_seq.fill_(1)
 
+        # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
     def set_example_sequence(self) -> None:
         """Set an example sequence useful for testing and export purposes."""
         self.reset()
@@ -352,7 +383,7 @@ def set_example_sequence(self) -> None:
             dtype=torch.int,
             device=self.device,
         )
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)
 
         # unflatten if we are not yet using cached+flattened attention
         if not self._is_cached_attn:
@@ -370,7 +401,7 @@ def _set_max_num_tokens_sample(self) -> None:
             device=self.device,
         )
         self.pages_per_seq.fill_(seq_len // self.page_size)
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)
 
     def set_generate_only_batch(self) -> None:
         """Set an example sequence for generate-only batch.
@@ -379,32 +410,96 @@ def set_generate_only_batch(self) -> None:
         mode. So we don't need to do anything mode-specific here.
         """
         self.reset()
-        self.nest_sequences([[1]] * self.max_batch_size)
-
-    def _update_position_ids(self) -> None:
-        # set new position_ids as new tensor from input_pos and seq_len via torch.arange
-        position_ids_list = [
-            num
-            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
-            for num in range(in_pos, in_pos + seq_len)
-        ]
-        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
 
+    def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
         # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
         if self.is_generate:
-            self.position_ids = self.position_ids.view(-1, 1)
+            return tensor.view(-1, 1, *tensor.shape[1:])
         else:
-            self.position_ids = self.position_ids.view(1, -1)
+            return tensor.view(1, -1, *tensor.shape[1:])
+
+    @nvtx_range("ad_update_position_ids")
+    def _update_position_ids(self, allow_realloc: bool = False) -> None:
+        # set new position_ids from input_pos and seq_len
+        # Make sure this is done on host to avoid host-device copies.
+        with nvtx_range("prepare_list"):
+            # Optimize for the common case where all seq_len values are 1 (generation mode)
+            if torch.all(self.seq_len_host == 1):
+                # Fast path: when all seq_len are 1, position_ids is just input_pos_host
+                position_ids_host = (
+                    self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
+                )
+            else:
+                # General case - can probably be optimized too, but overall impact will be minor.
+                position_ids_list = []
+                for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
+                    position_ids_list.extend(range(in_pos, in_pos + seq_len))
+                position_ids_host = torch.tensor(
+                    position_ids_list, dtype=torch.long, pin_memory=True
+                )
+        with nvtx_range("copy_to_device"):
+            if allow_realloc:
+                # Create a new position_ids tensor on the device
+                self.position_ids = position_ids_host.to(self.device).clone()
+            else:
+                self.position_ids_full = self.position_ids_full.flatten()
+                self.position_ids_full[: self.num_tokens].copy_(
+                    position_ids_host, non_blocking=True
+                )
+        with nvtx_range("maybe_reshape"):
+            self.position_ids = self.maybe_reshape_for_generate(
+                self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
+            )
+
+    @nvtx_range("ad_update_sequence_lengths")
+    def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
+        self._sequence_lengths = sequence_lengths
+        self.num_tokens = sum(self._sequence_lengths)
+        self.seq_len.zero_()
+        self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
+        self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
+
+    def update_input_ids_with_new_tokens(
+        self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
+    ) -> None:
+        """Update the input_ids with new tokens.
+
+        This function will update the input_ids with new tokens and previous batch indices.
+        """
+        # 1) flatten once
+        original_shape = self.input_ids.shape
+        flat = self.input_ids.flatten()
+
+        # copy indices to the GPU
+        host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
+        idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
+        idx.copy_(host_idx, non_blocking=True)
+
+        # sort them so that masked_scatter_ lines up correctly
+        idx, _ = idx.sort()
+
+        # gather the exact values you want to write
+        src = new_tokens[0, idx, 0]
+
+        # in-place fill every slot where flat == -1 with src, in order
+        flat.masked_scatter_(flat == -1, src)
+
+        # 4) reshape back
+        self.input_ids = flat.view(original_shape)
 
-    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+    @nvtx_range("ad_nest_sequences")
+    def nest_sequences(
+        self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
+    ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
 
+        When allow_realloc is True, the input_ids will be reallocated on the device.
         This i/f will also update any relevant sequence information.
         """
         # set new sequence lengths
-        seq_lens = [len(ids) for ids in input_ids]
-        self.seq_len.zero_()
-        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
+        self._update_sequence_lengths([len(ids) for ids in input_ids])
+
         # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
         dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
         # set new input_ids as new tensor from flattened input_ids
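
The `masked_scatter_` call in `update_input_ids_with_new_tokens` above relies on a specific semantic: source elements are written into the `True` positions of the mask in row-major order, which is why the indices are sorted first, so that the i-th placeholder slot receives the token for the i-th smallest batch index. A standalone sketch of those semantics, with invented values:

    import torch

    # flattened input_ids where -1 marks slots awaiting fresh tokens
    flat = torch.tensor([11, -1, 22, -1, -1])

    # new token values already ordered by ascending slot position,
    # which is what the idx.sort() in the diff guarantees
    src = torch.tensor([101, 102, 103])

    flat.masked_scatter_(flat == -1, src)
    print(flat)  # tensor([ 11, 101,  22, 102, 103])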
@@ -413,49 +508,57 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
             for lst in input_ids
             for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
         ]
-        self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
-
-        # set derivative properties
-        self._sequence_lengths = seq_lens
+        input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)
 
-        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
-        if self.is_generate:
-            self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
+        if allow_realloc:
+            self.input_ids = input_ids_host.to(self.device).clone()
         else:
-            self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
+            self.input_ids_full = self.input_ids_full.flatten()
+            self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)
 
+        self.input_ids = self.maybe_reshape_for_generate(
+            self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
+        )
         # update position_ids
-        self._update_position_ids()
+        self._update_position_ids(allow_realloc=allow_realloc)
 
+    @nvtx_range("ad_unnest_sequences")
     def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.sequence_lengths))
 
+    @nvtx_range("ad_update_pos")
     def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
         """Update the starting position for each sequence in the cache.
 
         If ``reset=True``, ``input_pos`` will be reset to zero before updating.
         """
         if not isinstance(seq_len, torch.Tensor):
-            seq_len = torch.tensor(seq_len, dtype=torch.int)
+            seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
         bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
 
         if reset:
-            self.input_pos[:bs] = seq_len.to(self.device)
+            self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
         else:
-            self.input_pos[:bs] += seq_len.to(self.device)
+            self.input_pos_host[:bs] += seq_len
 
         # update position_ids
         self._update_position_ids()
+        self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)
 
+    @nvtx_range("ad_assign_cache_loc")
     def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
         """Set the cache location and pages_per_seq tensors from page assignments."""
         cache_loc_flat = torch.tensor(
-            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+            [p_idx for pages in page_assignments for p_idx in pages],
+            dtype=torch.int,
+            pin_memory=True,
         )
         self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
 
-        pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
+        pages_per_seq = torch.tensor(
+            [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
+        )
         self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
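A general caveat on the `non_blocking=True` copies used throughout this diff: a host-to-device copy from pinned memory returns before the transfer completes. Kernels queued later on the same CUDA stream are ordered after it, so device-side consumers are safe, but the host must not reuse the pinned staging tensor (or read device results back) until the stream catches up. A minimal sketch of the rule, with invented buffer names:

    import torch

    staging = torch.empty(1024, dtype=torch.int, pin_memory=True)
    dev = torch.empty(1024, dtype=torch.int, device="cuda")

    staging[:4] = torch.tensor([1, 2, 3, 4], dtype=torch.int)
    dev[:4].copy_(staging[:4], non_blocking=True)

    out = dev[:4] * 2  # safe: same-stream kernels run after the copy

    # unsafe without a sync: overwriting staging could race the in-flight copy
    torch.cuda.current_stream().synchronize()
    staging[:4] = torch.tensor([5, 6, 7, 8], dtype=torch.int)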