 from torch.export import Dim
 from torch.fx import Node
 
+from tensorrt_llm._utils import nvtx_range
+
 
 @dataclass
 class CacheConfig:
@@ -87,11 +89,13 @@ class SequenceInfo:
     # Similarly, if a batch is composed of generate-only requests,
     # then the maximum number of sequences possible in the batch is min(max_batch_size, max_num_tokens).
     max_num_tokens: Optional[int] = None
+    # device is the device on which the sequence info is stored.
+    device: str = "cuda"
 
     ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
     # input_ids MUST ALWAYS BE THE FIRST FIELD
-    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
-    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long))
+    input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
+    position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
 
     seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
     input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
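
> Note on the default shapes: `input_ids` and `position_ids` drop their second dimension because they now hold one flattened token stream for the whole batch rather than a `(batch, seq)` grid. A minimal standalone sketch of that layout (illustrative only; the variable names here are hypothetical, not from this diff):

```python
import torch

# Three sequences flattened into the single 1-D stream that the new
# input_ids/position_ids fields assume.
sequences = [[11, 12, 13], [21, 22], [31]]
seq_len = [len(s) for s in sequences]                        # [3, 2, 1]
input_ids = torch.tensor([t for s in sequences for t in s])  # [11, 12, 13, 21, 22, 31]

# Per-sequence views are recovered with torch.split, mirroring unnest_sequences.
print(torch.split(input_ids, seq_len))  # (tensor([11, 12, 13]), tensor([21, 22]), tensor([31]))
```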
@@ -110,24 +114,44 @@ def __post_init__(self):
         # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
         # (max_batch_size, max_seq_len) input in trtllm runtime.
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
-        max_seq_len_adjusted = self.max_seq_len + 1
+        self.max_seq_len_adjusted = self.max_seq_len + 1
 
         if self.max_num_tokens is None or self.max_num_tokens < 1:
-            self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+            self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
-        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
+        total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
         # Num pages can not be less than max_batch_size.
         self._num_pages = max(
             self.max_batch_size,
             (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
         )
-        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
-        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-        self.input_pos = torch.empty_like(self.seq_len)
-        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
-        self.pages_per_seq = torch.empty_like(self.seq_len)
+        # Ensure that the device is set before initializing the tensors.
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
+        # Consumers of the sequence info args require input_ids and position_ids to be truncated.
+        # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor
+        # creation in every forward pass.
+        self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids_full = torch.zeros(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
+
+        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
+        self.input_pos = torch.empty_like(self.seq_len, device=self.device)
+
+        # Allocate host tensors for sequence lengths and input positions so that
+        # the position_ids calculation can be done on host.
+        self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
+        self.input_pos_host = torch.empty_like(self.seq_len_host)
+
+        self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
+        self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
+
+        self.previous_batch_indices_cuda = torch.empty(
+            self.max_num_tokens, dtype=torch.long, device=self.device
+        )
         assert self.num_pages >= self.max_batch_size, (
             "num_pages must be greater than max_batch_size"
         )
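
> The `*_host` mirrors added here enable the async host-to-device copy pattern used throughout the rest of this diff: a `copy_` with `non_blocking=True` from a pinned (page-locked) CPU tensor does not stall the host and is ordered on the current CUDA stream. A minimal sketch of the pattern, assuming a CUDA-enabled build:

```python
import torch

device = torch.device("cuda")

# Device tensor preallocated once, as in __post_init__ above.
seq_len_dev = torch.empty(8, dtype=torch.int, device=device)

# Pinned host staging buffer; pinning is what makes the copy truly asynchronous.
seq_len_host = torch.tensor([3, 2, 1, 1, 1, 1, 1, 1], dtype=torch.int, pin_memory=True)

# Enqueue the H2D copy without blocking the host. Kernels launched later on the
# same stream observe the updated values, so no explicit sync is needed here.
seq_len_dev.copy_(seq_len_host, non_blocking=True)
```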
@@ -140,13 +164,12 @@ def __post_init__(self):
         # indicator if extra args are activated that are needed for cached attention backends
         self._is_cached_attn = False
 
+        # total number of tokens in the current batch
+        self.num_tokens: int = 0
+
         # call reset once to initialize the tensors
         self.reset()
 
-    @property
-    def device(self) -> torch.device:
-        return self.input_pos.device
-
     @property
     def args(self) -> Tuple[torch.Tensor, ...]:
         args = []
@@ -156,11 +179,14 @@ def args(self) -> Tuple[torch.Tensor, ...]:
             args.append(val)
             if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
                 break
+
         return tuple(args)
 
     @property
     def _num_uncached_attn_args(self) -> int:
-        """Return the number of original graph arguments expected by the model."""
+        """Return the number of original graph arguments expected by the model.
+        This is 2 because we have input_ids and position_ids as the original graph arguments.
+        """
         return 2
 
     @property
@@ -185,7 +211,7 @@ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
         dynamic_shapes = ({}, {})
         if self.max_batch_size > 1:
             dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
-        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
+        dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
         # set up shape for position_ids (same as input_ids)
         dynamic_shapes[1].update(dynamic_shapes[0])
         # set up shape for extra args
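
> The dynamic dimension bound moves to `max_seq_len_adjusted` so the exported graph admits the `max_seq_len + 1` WAR above. For reference, a minimal sketch of how `torch.export.Dim` bounds a dynamic dimension (toy module, hypothetical sizes; not the model exported here):

```python
import torch
from torch.export import Dim

class Toy(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return input_ids + 1

# Mark dim 1 of the first argument as dynamic with an upper bound,
# analogous to dynamic_shapes[0][1] above.
seq_dim = Dim("seq_len", max=1025)  # e.g. max_seq_len + 1
ep = torch.export.export(
    Toy(),
    (torch.zeros(1, 8, dtype=torch.int),),
    dynamic_shapes=({1: seq_dim},),
)
```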
@@ -204,7 +230,7 @@ def sequence_lengths(self) -> List[int]:
 
     @property
     def input_positions(self) -> List[int]:
-        return self.input_pos[: self.num_sequences].tolist()
+        return self.input_pos_host[: self.num_sequences].tolist()
 
     @property
     def is_generate(self) -> bool:
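
> Reading from `input_pos_host` keeps this property off the GPU: `.tolist()` on a CUDA tensor performs a device-to-host copy plus a synchronization on every call, while the host mirror is a plain CPU read. A small sketch of the difference, assuming CUDA is available:

```python
import torch

input_pos = torch.arange(4, dtype=torch.int, device="cuda")
input_pos_host = input_pos.cpu()  # maintained alongside the device tensor

pos_a = input_pos[:2].tolist()       # old path: D2H copy + sync per call
pos_b = input_pos_host[:2].tolist()  # new path: pure host read
assert pos_a == pos_b
```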
@@ -334,14 +360,19 @@ def reset(self) -> None:
         """
         # reset input_pos
         self.input_pos.zero_()
+        self.input_pos_host.zero_()
 
         # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-        self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
 
         # reset cache information
         self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
         self.pages_per_seq.fill_(1)
 
+        # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+        self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
+
     def set_example_sequence(self) -> None:
         """Set an example sequence useful for testing and export purposes."""
         self.reset()
@@ -352,7 +383,7 @@ def set_example_sequence(self) -> None:
             dtype=torch.int,
             device=self.device,
         )
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)
 
         # unflatten if we are not yet using cached+flattened attention
         if not self._is_cached_attn:
@@ -370,7 +401,7 @@ def _set_max_num_tokens_sample(self) -> None:
             device=self.device,
         )
         self.pages_per_seq.fill_(seq_len // self.page_size)
-        self.nest_sequences(input_ids)
+        self.nest_sequences(input_ids, allow_realloc=True)
 
     def set_generate_only_batch(self) -> None:
         """Set an example sequence for generate-only batch.
@@ -379,32 +410,96 @@ def set_generate_only_batch(self) -> None:
         mode. So we don't need to do anything mode-specific here.
         """
         self.reset()
-        self.nest_sequences([[1]] * self.max_batch_size)
-
-    def _update_position_ids(self) -> None:
-        # set new position_ids as new tensor from input_pos and seq_len via torch.arange
-        position_ids_list = [
-            num
-            for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
-            for num in range(in_pos, in_pos + seq_len)
-        ]
-        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+        self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
 
+    def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
         # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
         if self.is_generate:
-            self.position_ids = self.position_ids.view(-1, 1)
+            return tensor.view(-1, 1, *tensor.shape[1:])
         else:
-            self.position_ids = self.position_ids.view(1, -1)
+            return tensor.view(1, -1, *tensor.shape[1:])
+
+    @nvtx_range("ad_update_position_ids")
+    def _update_position_ids(self, allow_realloc: bool = False) -> None:
+        # set new position_ids from input_pos and seq_len
+        # Make sure this is done on host to avoid host-device copies.
+        with nvtx_range("prepare_list"):
+            # Optimize for the common case where all seq_len values are 1 (generation mode)
+            if torch.all(self.seq_len_host == 1):
+                # Fast path: when all seq_len are 1, position_ids is just input_pos_host
+                position_ids_host = (
+                    self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
+                )
+            else:
+                # General case - can probably be optimized too, but overall impact will be minor.
+                position_ids_list = []
+                for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
+                    position_ids_list.extend(range(in_pos, in_pos + seq_len))
+                position_ids_host = torch.tensor(
+                    position_ids_list, dtype=torch.long, pin_memory=True
+                )
        with nvtx_range("copy_to_device"):
+            if allow_realloc:
+                # Create a new position_ids tensor on the device
+                self.position_ids = position_ids_host.to(self.device).clone()
+            else:
+                self.position_ids_full = self.position_ids_full.flatten()
+                self.position_ids_full[: self.num_tokens].copy_(
+                    position_ids_host, non_blocking=True
+                )
+        with nvtx_range("maybe_reshape"):
+            self.position_ids = self.maybe_reshape_for_generate(
+                self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
+            )
+
+    @nvtx_range("ad_update_sequence_lengths")
+    def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
+        self._sequence_lengths = sequence_lengths
+        self.num_tokens = sum(self._sequence_lengths)
+        self.seq_len.zero_()
+        self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
+        self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
+
+    def update_input_ids_with_new_tokens(
+        self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
+    ) -> None:
+        """Update the input_ids with new tokens.
+
+        This function will update the input_ids with new tokens and previous batch indices.
+        """
+        # 1) flatten once
+        original_shape = self.input_ids.shape
+        flat = self.input_ids.flatten()
+
+        # 2) copy indices to the GPU
+        host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
+        idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
+        idx.copy_(host_idx, non_blocking=True)
+
+        # 3) sort them so that masked_scatter_ lines up correctly
+        idx, _ = idx.sort()
+
+        # 4) gather the exact values to write
+        src = new_tokens[0, idx, 0]
+
+        # 5) in-place fill every slot where flat == -1 with src, in order
+        flat.masked_scatter_(flat == -1, src)
+
+        # 6) reshape back
+        self.input_ids = flat.view(original_shape)
 
-    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+    @nvtx_range("ad_nest_sequences")
+    def nest_sequences(
+        self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
+    ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
 
+        When allow_realloc is True, the input_ids will be reallocated on the device.
         This i/f will also update any relevant sequence information.
         """
         # set new sequence lengths
-        seq_lens = [len(ids) for ids in input_ids]
-        self.seq_len.zero_()
-        self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
+        self._update_sequence_lengths([len(ids) for ids in input_ids])
+
         # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
         dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
         # set new input_ids as new tensor from flattened input_ids
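
> The sort in `update_input_ids_with_new_tokens` matters: `masked_scatter_` consumes the source tensor in order and writes its elements into the mask's `True` positions in row-major order, so the gathered values must be arranged left-to-right like the `-1` placeholder slots they fill. A small standalone demonstration:

```python
import torch

# Placeholder slots (-1) mark where new tokens must land, as in input_ids above.
flat = torch.tensor([10, -1, 30, -1, 50])
src = torch.tensor([111, 333])  # ordered to match the -1 slots left-to-right

# Fills the i-th True position of the mask with the i-th element of src.
flat.masked_scatter_(flat == -1, src)
print(flat)  # tensor([ 10, 111,  30, 333,  50])
```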
@@ -413,49 +508,57 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
             for lst in input_ids
             for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
         ]
-        self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
-
-        # set derivative properties
-        self._sequence_lengths = seq_lens
+        input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)
 
-        # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
-        if self.is_generate:
-            self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
+        if allow_realloc:
+            self.input_ids = input_ids_host.to(self.device).clone()
         else:
-            self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
+            self.input_ids_full = self.input_ids_full.flatten()
+            self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)
 
+        self.input_ids = self.maybe_reshape_for_generate(
+            self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
+        )
         # update position_ids
-        self._update_position_ids()
+        self._update_position_ids(allow_realloc=allow_realloc)
 
+    @nvtx_range("ad_unnest_sequences")
     def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.sequence_lengths))
 
+    @nvtx_range("ad_update_pos")
     def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
         """Update the starting position for each sequence in the cache.
 
         If ``reset=True``, ``input_pos`` will be reset to zero before updating.
         """
         if not isinstance(seq_len, torch.Tensor):
-            seq_len = torch.tensor(seq_len, dtype=torch.int)
+            seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
         bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
 
         if reset:
-            self.input_pos[:bs] = seq_len.to(self.device)
+            self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
         else:
-            self.input_pos[:bs] += seq_len.to(self.device)
+            self.input_pos_host[:bs] += seq_len
 
         # update position_ids
         self._update_position_ids()
+        self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)
 
+    @nvtx_range("ad_assign_cache_loc")
     def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
         """Set the cache location and pages_per_seq tensors from page assignments."""
         cache_loc_flat = torch.tensor(
-            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+            [p_idx for pages in page_assignments for p_idx in pages],
+            dtype=torch.int,
+            pin_memory=True,
        )
         self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
 
-        pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
+        pages_per_seq = torch.tensor(
+            [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
+        )
         self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
 
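
> For clarity, the flattening that `assign_cache_loc` performs, shown standalone with hypothetical page assignments (pinning dropped so the sketch also runs on CPU-only builds):

```python
import torch

# Two sequences: the first owns pages 0-2, the second pages 3-4.
page_assignments = [[0, 1, 2], [3, 4]]

cache_loc_flat = torch.tensor(
    [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
)
pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)

print(cache_loc_flat.tolist())  # [0, 1, 2, 3, 4]
print(pages_per_seq.tolist())   # [3, 2]
```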