
# Please refer to README.md in the same folder for more information.

+import logging
+from collections import defaultdict, deque
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple

from torch import nn

+logger = logging.getLogger(__name__)
+

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


class InputManager:
+    class NGramCache:
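+        """FIFO cache of recently seen n-grams: duplicates are ignored and the
+        oldest entry is evicted once max_size entries are stored."""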
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
    def __init__(
        self,
        n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
        dtype=torch.float16,
        minus_infinity=-torch.inf,
        cache_size=None,
+        lookahead_enabled: bool = False,
    ):
        if cache_size is None:
            cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

        self.seq_length = seq_length
        self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity

        if self.use_cache_list:
            self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
        if self.cache_pos == self.cache_size:
            self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
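+        # update_pos is the offset into the freshly computed new_k_caches /
+        # new_v_caches at which copying starts; lookahead decoding passes the
+        # slot of the matched verification branch so only its accepted tokens
+        # enter the cache.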
        # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
        if self.input_pos <= self.cache_size:
            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
            )
        if remaining_to_copy > 0:
            self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
            )

        self.input_pos += input_length
@@ -661,3 +689,270 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
            self.get_inputs(tokens[0:processed_tokens]),
            tokens[processed_tokens:],
        )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
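+        # Query slots are laid out as [lookahead windows | verification branches].
+        # A sketch of the mask for a hypothetical ngram_size=3, window_size=2,
+        # n_verifications=2 (8 slots; "X" marks positions set to 0.0, i.e. attended):
+        #   X . . . . . . .
+        #   X X . . . . . .
+        #   X . X . . . . .
+        #   X X . X . . . .
+        #   X . . . X . . .
+        #   X . . . X X . .
+        #   X . . . . . X .
+        #   X . . . . . X X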
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
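+        # Position offsets relative to input_pos for each query slot.  For the
+        # same hypothetical config (ngram_size=3, window_size=2, n_verifications=2)
+        # this gives [0, 1, 1, 2, 1, 2, 1, 2, 0, ...]: each lookahead window slot
+        # sits at offset i + j, and every verification branch repeats [1, 2].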
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def _validate_lookahead_config(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Validate the lookahead decoding configuration.
+        """
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
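+        # All lookahead windows and verification branches must fit in a single
+        # forward pass: (ngram_size - 1) * (window_size + n_verifications) may
+        # not exceed seq_length.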
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+    def _setup_lookahead_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Set up the attention mask for lookahead decoding and log debug information.
+        """
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+    def _populate_verification_branches(
+        self, x: List[int], cache, verification_offset: int, ngram_size: int
+    ) -> None:
+        """
+        Populate verification branches with tokens from the n-gram cache.
+        """
+        for i, ngram in enumerate(cache):
+            for j, token in enumerate(ngram):
+                x[verification_offset + i * (ngram_size - 1) + j] = token
+
+    def _collect_ngrams(
+        self,
+        x: List[int],
+        y: List[int],
+        ngram_caches: Dict[int, "InputManager.NGramCache"],
+        window_size: int,
+        ngram_size: int,
+    ) -> None:
+        """
+        Collect new n-grams from the current state and predictions.
+        """
+        for i in range(window_size):
+            key = x[i]
+            suffix = []
+            for j in range(1, ngram_size - 1):
+                suffix.append(x[i + j * window_size])
+            suffix.append(y[i + window_size * (ngram_size - 2)])
+            ngram_caches[key].add(suffix)
+
+    def _find_longest_match(
+        self,
+        x: List[int],
+        y: List[int],
+        verification_offset: int,
+        n_verifications: int,
+        ngram_size: int,
+    ) -> Tuple[List[int], Optional[int]]:
+        """
+        Find the longest matching sequence from verification branches.
+        Returns the matched tokens and the branch index.
+        """
+        longest_match = []
+        matched_branch = None
+
+        for i in range(n_verifications):
+            match = [y[0]]
+            j = 0
+            while (
+                j < ngram_size - 1
+                and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+            ):
+                match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                j += 1
+            if len(match) - 1 > len(longest_match):
+                longest_match = match[1:]
+                matched_branch = i
+
+        return longest_match, matched_branch
+
+    def _update_lookahead_branches(
+        self, x: List[int], y: List[int], ngram_size: int, window_size: int
+    ) -> None:
+        """
+        Update the lookahead branches with new predictions.
+        """
+        # Shift window contents up
+        for i in range(ngram_size - 2):
+            for j in range(window_size):
+                x[window_size * i + j] = x[window_size * (i + 1) + j]
+
+        # Fill the last window with new predictions
+        for j in range(window_size):
+            x[window_size * (ngram_size - 2) + j] = y[
+                window_size * (ngram_size - 2) + j
+            ]
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        # Validate configuration
+        self._validate_lookahead_config(ngram_size, window_size, n_verifications)
+
+        # Set up attention mask and position offsets
+        self._setup_lookahead_mask(ngram_size, window_size, n_verifications)
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        # Initialize state
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        # Main decoding loop
+        while len(new_tokens) < n + 1:
+            # Populate verification branches
+            cache = ngram_caches[x[0]]
+            self._populate_verification_branches(
+                x, cache, verification_offset, ngram_size
+            )
+
+            # Run model inference
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Process model output (greedy selection)
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams
+            self._collect_ngrams(x, y, ngram_caches, window_size, ngram_size)
+
+            # Find longest match from verification branches
+            longest_match, matched_branch = self._find_longest_match(
+                x, y, verification_offset, n_verifications, ngram_size
+            )
+
+            # Process match results
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                # Truncate at stop token if present
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branches
+            self._update_lookahead_branches(x, y, ngram_size, window_size)
+
+            # Update first token and check for stop condition
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
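+
+
+# A rough usage sketch (illustrative only; the InputManager constructor arguments
+# are elided, `model` is assumed to be loaded via load_model(), and `last_token`
+# is the token to continue from after the prompt has been prefilled):
+#
+#   manager = InputManager(..., lookahead_enabled=True)
+#   generated = manager.lookahead_decode(
+#       model,
+#       init_token=last_token,
+#       n=128,
+#       ngram_size=3,
+#       window_size=8,
+#       n_verifications=8,
+#       stop_tokens=[eos_token_id],
+#   )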