@@ -37,7 +37,7 @@ class SequenceInfo:
     between arguments that are originally part of the model/graph and arguments that are needed for
     the attention operator when we switch to cached+flattened attention.
 
-    # ORIGINAL MODEL ARGUMENTS ## ###################################################################
+    ### ORIGINAL MODEL ARGUMENTS ###################################################################
     - input_ids: [id_0, ..., id_{s_total-1}]
       flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
     - position_ids: [pos_0, ..., pos_{s_total-1}]
@@ -47,7 +47,18 @@ class SequenceInfo:
     NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len]
     before we switch to cached+flattened attention.
 
-    # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ##############
+    ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
+    These are extra arguments that can be provided to the interface; they are stored as follows:
+    - _extra_args: dictionary of extra arguments with their currently active values.
+    - _extra_example_inputs: dictionary of example inputs for the extra arguments.
+    - _extra_none_inputs: dictionary of placeholder ("none") inputs for the extra arguments.
+    NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
+    cannot represent them via ``None`` since fx graphs require a fixed input type. Instead,
+    we require a special placeholder tensor to represent the ``None`` input.
+    - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes
+      of the extra arguments.
+
+    ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
     - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
       Describes how long each sequence is. For example,
       input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_0+s_1] will
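> Note: to make the placeholder convention above concrete, here is a minimal sketch of what the three dictionaries hold for a single registered argument. The name `pixel_values` and the instance `si` are illustrative assumptions; only the roles of the dictionaries come from this docstring.

```python
# sketch, assuming `si` is a SequenceInfo with a registered extra arg "pixel_values"
si._extra_example_inputs["pixel_values"]  # example tensor used for tracing/export
si._extra_none_inputs["pixel_values"]     # fixed-type placeholder meaning "not provided"
si._extra_args["pixel_values"]            # active value: the caller's tensor or the placeholder
```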
@@ -128,26 +139,24 @@ def __post_init__(self):
         self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
         self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
         self._uncached_arg_names = ["input_ids", "position_ids"]
+        # dynamic shapes are initialized lazily since Dim is not picklable for multi-processing
+        self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
+
+        # EXTRA TENSOR FIELDS
+        self._extra_args: Dict[str, torch.Tensor] = {}
+        self._extra_example_inputs: Dict[str, torch.Tensor] = {}
+        self._extra_none_inputs: Dict[str, torch.Tensor] = {}
+        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
+        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
 
         # CACHED TENSOR FIELDS (for cached attention backends)
         self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
         self.input_pos = torch.empty_like(self.seq_len)
         self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
         self.pages_per_seq = torch.empty_like(self.seq_len)
         self._cached_arg_names = ["seq_len", "input_pos", "cache_loc", "pages_per_seq"]
-
-        # DYNAMIC SHAPES
-        # --> initialized lazily since Dim is not picklable for multi-processing
-        self._uncached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
         self._cached_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
         ############################################################################################
 
-        ### EXTRA ARGS #############################################################################
-        self._extra_args: Dict[str, torch.Tensor] = {}
-        self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
-        self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
-        ############################################################################################
-
         # call reset once to initialize the tensors
         self.reset()
 
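> Note: the dynamic-shape fields stay `None` until first use because `torch.export.Dim` objects are not picklable for multi-processing (the rationale in the comment deleted above). A hedged sketch of what a `DynamicShapeCallback` could look like, assuming a `DynamicShape` is a dict mapping dimension index to a `Dim` (consistent with the `lambda: {}` default registered in `add_extra_arg`):

```python
from torch.export import Dim


def pixel_values_dynamic_shape():
    # built lazily inside the callback so no Dim object has to cross a
    # process boundary when SequenceInfo is pickled
    return {0: Dim("num_images", min=1, max=16)}
```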
@@ -345,9 +354,13 @@ def to(self, *args, **kwargs) -> None:
         for k in self._uncached_arg_names + self._cached_arg_names:
             setattr(self, k, getattr(self, k).to(*args, **kwargs))
 
-        for k, v in self._extra_args.items():
-            if isinstance(v, torch.Tensor):
-                self._extra_args[k] = v.to(*args, **kwargs)
+        def _move_dict(d: Dict[str, torch.Tensor]) -> None:
+            for k, v in d.items():
+                d[k] = v.to(*args, **kwargs)
+
+        _move_dict(self._extra_args)
+        _move_dict(self._extra_example_inputs)
+        _move_dict(self._extra_none_inputs)
 
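> Note: the net effect of the `_move_dict` refactor is that one device move now covers the example and none placeholders too, e.g. (assuming an instance `si`):

```python
si.to("cuda")  # moves uncached/cached args plus _extra_args,
               # _extra_example_inputs, and _extra_none_inputs in one call
```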
     def reset(self) -> None:
         """Reset the sequence information.
@@ -369,16 +382,63 @@ def set_example_sequence(self) -> None:
369382 """Set an example sequence useful for testing and export purposes."""
370383 self .reset ()
371384 bs , seq_len = min (2 , self .max_batch_size ), min (4 , self .max_seq_len )
372- input_ids = torch .ones (
385+ input_ids = torch .ones ( # noqa
373386 bs ,
374387 seq_len ,
375388 dtype = torch .int ,
376389 device = self .device ,
377390 )
378- self .nest_sequences (input_ids )
+
+        # TODO (lucaslie): it seems we have hit a roadblock using generic example inputs for
+        # export with VLMs. We should probably switch to having the factory provide an example
+        # input that is then tokenized inside the factory.
+        # WHY: for VLMs we need to hit the special tokens representing images; there is no way
+        # to do that with a generic example input.
+        # fmt: off
+        input_ids2 = [[
+            200000, 200005, 1556, 200006, 368, 200080, 200090, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200081, 200080,
+            200090, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092, 200092,
+            200092, 200081, 74777, 290, 5326, 43, 200008, 200005, 140680,
+            200006, 368
+        ] for _ in range(2)]
+        # fmt: on
+
+        self.nest_sequences(input_ids2, **self._extra_example_inputs)
 
     def set_max_num_tokens_sample(self) -> None:
         """Set an example sequence with max_num_tokens."""
+        # TODO: understand what this implies for extra arguments
         self.reset()
         seq_len = self.max_num_tokens // self.max_batch_size
         input_ids = torch.ones(
@@ -480,6 +540,7 @@ def nest_sequences(
         position_ids: Optional[Sequence[Sequence[int]]] = None,
         input_pos: Optional[Union[torch.Tensor, Sequence[int], int]] = None,
         page_assignments: Optional[Sequence[Sequence[int]]] = None,
+        **extra_args: Union[torch.Tensor, Sequence[torch.Tensor]],
     ) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
 
@@ -488,6 +549,7 @@ def nest_sequences(
             position_ids: List of sequences of position_ids for each token.
             input_pos: Absolute starting position in the cache for each sequence.
             page_assignments: List of sequences of page assignments for each sequence.
+            extra_args: Extra arguments to be stored in the interface.
 
         This interface will ensure that all sequence info args are updated accordingly.
         """
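> Note: a sketch of the new calling convention; `seq_info` and `pixel_values` are illustrative names. Per the handling added below, a `Sequence[torch.Tensor]` (one entry per sequence in the batch) is concatenated along dim 0, a single tensor is stored as-is, and an omitted argument falls back to its registered none placeholder.

```python
import torch

seq_info.nest_sequences(
    [[1, 2, 3], [4, 5]],  # two sequences of input_ids
    pixel_values=[        # per-sequence tensors -> torch.cat'ed into one tensor
        torch.zeros(1, 3, 336, 336),
        torch.zeros(2, 3, 336, 336),
    ],
)
```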
@@ -542,36 +604,53 @@ def nest_sequences(
         if page_assignments is not None:
             self._assign_pages_per_seq(page_assignments)
 
+        # go through all extra arguments and update them
+        for name, none_input in self._extra_none_inputs.items():
+            if name in extra_args:
+                arg = extra_args.pop(name)
+                if not isinstance(arg, torch.Tensor):
+                    # a sequence of per-sequence tensors: flatten by concatenation
+                    if len(arg) > 1:
+                        arg = torch.cat(arg)
+                    else:
+                        arg = arg[0]
+                self._extra_args[name] = arg.to(self.device)
+            else:
+                # not provided by the caller: fall back to the registered none placeholder
+                self._extra_args[name] = none_input
+
+        assert not extra_args, f"Extra arguments {extra_args.keys()} were not registered"
621+
545622 def unnest_sequences (self , t_nested : torch .Tensor ) -> List [torch .Tensor ]:
546623 t_squeezed = t_nested .squeeze (1 ) if self .is_generate else t_nested .squeeze (0 )
547624 return list (torch .split (t_squeezed , self .sequence_lengths ))
548625
549626 def add_extra_arg (
550627 self ,
551628 name : str ,
552- value : torch .Tensor ,
629+ example_input : torch .Tensor ,
630+ none_input : torch .Tensor ,
553631 dynamic_shape_callback : Optional [DynamicShapeCallback ] = None ,
554632 ) -> None :
555633 """Add an extra argument to the sequence info object.
556634
557635 Args:
558636 name: The name of the extra argument.
-            value: Example input value of the extra argument.
+            example_input: Example input value of the extra argument.
+            none_input: Placeholder tensor representing the ``None`` input for the argument.
             dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
 
         Note that the extra argument is expected to be a tensor.
         """
-        self._extra_args[name] = value.to(self.device)
+        assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
+
+        self._extra_args[name] = example_input.to(self.device)
+        self._extra_example_inputs[name] = example_input.to(self.device)
+        self._extra_none_inputs[name] = none_input.to(self.device)
+
         if dynamic_shape_callback is None:
             self._extra_dynamic_shapes_callbacks[name] = lambda: {}
         else:
             self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
 
-    def set_extra_arg(self, name: str, value: torch.Tensor) -> None:
-        """Set an extra argument to the sequence info."""
-        # TODO (lucaslie): assume fixed shape for now
-        self._extra_args[name].copy_(value.to(self.device), non_blocking=True)
-
 
 Constant = Union[int, float, str, None]
 
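> Note: a registration sketch under the new `add_extra_arg` signature; the `pixel_values` name and shapes are assumptions for illustration. The removed `set_extra_arg` in-place update (and its fixed-shape TODO) is superseded by passing values directly to `nest_sequences`.

```python
import torch

seq_info.add_extra_arg(
    "pixel_values",
    example_input=torch.zeros(1, 3, 336, 336),  # traced during export
    none_input=torch.zeros(0, 3, 336, 336),     # placeholder standing in for None
)

# before: seq_info.set_extra_arg("pixel_values", pix)  # removed API, copy_() in place
# after:  values are re-bound on every step via nest_sequences
seq_info.nest_sequences(input_ids, pixel_values=pix)
```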