Commit cb1c1f8

Lookahead decoding eager implementation (#12491)
Summary: Implement reference lookahead decoding for the CoreML implementation.
Reviewed By: billmguo
Test Plan: Imported from GitHub, without a `Test Plan:` line.
Rollback Plan:
Differential Revision: D78323399
Pulled By: viveknayakatmeta
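For orientation, a minimal usage sketch of the `InputManager.lookahead_decode` entry point this commit adds; `model`, `input_manager`, `prompt_tokens`, `tokenizer`, `eos_token_id`, and the concrete parameter values are assumptions for illustration, not part of the commit:

```python
# Hypothetical caller sketch. Assumes `model` is the eager llama_transformer model and
# `input_manager` is an InputManager built with lookahead_enabled=True and a seq_length
# large enough for this configuration (see the check at the top of lookahead_decode).
generated = input_manager.lookahead_decode(
    model=model,
    init_token=prompt_tokens[-1],   # last prompt token (assumed to be available)
    n=128,                          # number of new tokens to generate
    ngram_size=4,                   # illustrative lookahead settings
    window_size=8,
    n_verifications=8,
    stop_tokens=[eos_token_id],     # assumed tokenizer EOS id
)
print(tokenizer.decode(generated))  # `tokenizer` is assumed
```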
1 parent ef48cc2 commit cb1c1f8

4 files changed (+520, -3 lines)


examples/apple/coreml/llama/llama_transformer.py

Lines changed: 222 additions & 3 deletions
@@ -6,6 +6,8 @@
 
 # Please refer to README.md in the same folder for more information.
 
+import logging
+from collections import defaultdict, deque
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@
 
 from torch import nn
 
+logger = logging.getLogger(__name__)
+
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
 
 
 class InputManager:
+    class NGramCache:
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
         dtype=torch.float16,
         minus_infinity=-torch.inf,
         cache_size=None,
+        lookahead_enabled: bool = False,
     ):
         if cache_size is None:
             cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(
 
         self.seq_length = seq_length
         self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity
 
         if self.use_cache_list:
             self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
         if self.cache_pos == self.cache_size:
             self.cache_pos = 0
 
-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
         # Copy as much new cache data into cache as possible without wrapping
         amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
         if self.input_pos <= self.cache_size:
             self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                 0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
             )
         if remaining_to_copy > 0:
             self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches
             )
 
         self.input_pos += input_length
@@ -661,3 +689,194 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
             self.get_inputs(tokens[0:processed_tokens]),
             tokens[processed_tokens:],
         )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        while len(new_tokens) < n + 1:
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Greedy only
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams.
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                # for j in range(ngram_size - 1):
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branch.
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
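As a sizing note on the guard at the top of `lookahead_decode`: the lookahead window and all verification branches must fit in a single forward pass of length `seq_length`. A small worked example with illustrative numbers (not taken from the commit):

```python
# Each of the (ngram_size - 1) lookahead levels occupies window_size positions, and
# each of the n_verifications branches occupies (ngram_size - 1) positions, so a pass
# needs (ngram_size - 1) * (window_size + n_verifications) slots within seq_length.
ngram_size, window_size, n_verifications = 4, 8, 8            # illustrative values
required = (ngram_size - 1) * (window_size + n_verifications)  # 3 * 16 = 48
seq_length = 64                                                # assumed export-time setting
assert required <= seq_length, f"need seq_length >= {required}, got {seq_length}"
```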
