
# Please refer to README.md in the same folder for more information.

+import logging
+from collections import defaultdict, deque
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple

from torch import nn

+logger = logging.getLogger(__name__)
+

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


class InputManager:
+    class NGramCache:
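+        # Bounded FIFO of distinct n-grams: adding a duplicate is a no-op, and
+        # once max_size entries are held, the oldest n-gram is evicted.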
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
    def __init__(
        self,
        n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
        dtype=torch.float16,
        minus_infinity=-torch.inf,
        cache_size=None,
+        lookahead_enabled: bool = False,
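+        # Gates the lookahead_decode() path added below.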
    ):
        if cache_size is None:
            cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

        self.seq_length = seq_length
        self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity

        if self.use_cache_list:
            self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
        if self.cache_pos == self.cache_size:
            self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
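+        # update_pos is the offset into new_k_caches/new_v_caches to copy from,
+        # letting callers commit only a verified slice (see lookahead_decode).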
        # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
        if self.input_pos <= self.cache_size:
            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
            )
        if remaining_to_copy > 0:
            self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
            )

        self.input_pos += input_length
@@ -661,3 +689,194 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
            self.get_inputs(tokens[0:processed_tokens]),
            tokens[processed_tokens:],
        )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
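+        # Layout of the seq_length-long input: the first
+        # window_size * (ngram_size - 1) positions hold the lookahead windows
+        # (position 0 doubles as the current token), followed by n_verifications
+        # branches of (ngram_size - 1) tokens each.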
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
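+        # Each verification branch attends causally within itself; the final
+        # assignment below also lets every branch attend to the current token.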
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
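+        # RoPE position offsets relative to input_pos: token j of lookahead
+        # level i sits at offset i + j; verification-branch tokens use offsets
+        # 1..ngram_size - 1.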
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
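+        # Greedy lookahead decoding: each forward pass advances the lookahead
+        # windows (which seed the n-gram pool) and checks n_verifications cached
+        # n-grams against the model's outputs, so several tokens can be accepted
+        # per inference instead of one.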
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
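+        # x is the model's input buffer: slot 0 holds the current token; the
+        # remaining slots hold lookahead-window and verification-branch tokens.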
+        inference_count = 0
+
+        while len(new_tokens) < n + 1:
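+            # Fill the verification branches with cached n-grams keyed by the
+            # current token.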
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
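+            # A single forward pass scores the current token, every lookahead
+            # window, and every verification branch in parallel.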
+            inference_count += 1
+
+            # Greedy only
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams.
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
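+            # A branch keeps matching while its stored tokens track the tokens
+            # actually accepted; the longest verified run wins.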
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
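+                # Commit the verified tokens' KV entries from the matched branch
+                # slice of the new caches (hence update_pos=branch_offset).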
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branch.
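+            # Shift each window up one level and refill the deepest level with
+            # the model's latest predictions.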
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
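
For reference, a standalone sketch of the new `NGramCache` semantics (deduplication plus FIFO eviction) and how `lookahead_decode` keys one cache per token. The class body is copied from the diff above; the driver values are made-up illustrations, not part of this commit.

```python
from collections import defaultdict, deque
from typing import List


class NGramCache:
    # Mirror of InputManager.NGramCache from the diff above.
    def __init__(self, max_size: int):
        self.cache = deque()
        self.max_size = max_size

    def add(self, ngram: List[int]):
        if ngram in self.cache:
            return  # duplicates are ignored
        if len(self.cache) == self.max_size:
            self.cache.popleft()  # evict the oldest n-gram
        self.cache.append(ngram)

    def __iter__(self):
        return iter(self.cache)

    def __str__(self):
        return str(self.cache)


# One cache per key token, sized to the number of verification branches,
# mirroring the defaultdict set up in lookahead_decode.
ngram_caches = defaultdict(lambda: NGramCache(2))
ngram_caches[7].add([1, 2])
ngram_caches[7].add([1, 2])  # no-op: already cached
ngram_caches[7].add([3, 4])
ngram_caches[7].add([5, 6])  # full: evicts [1, 2]
print(ngram_caches[7])       # deque([[3, 4], [5, 6]])
```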