Lookahead decoding eager implementation #12491

Merged · 1 commit · Jul 17, 2025
examples/apple/coreml/llama/llama_transformer.py (301 changes: 298 additions & 3 deletions)
@@ -6,6 +6,8 @@

# Please refer to README.md in the same folder for more information.

import logging
from collections import defaultdict, deque
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@

from torch import nn

logger = logging.getLogger(__name__)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


class InputManager:
    class NGramCache:
        """A bounded FIFO cache of n-gram continuations for a single key token."""

        def __init__(self, max_size: int):
            self.cache = deque()
            self.max_size = max_size

        def add(self, ngram: List[int]):
            # Skip duplicates; evict the oldest entry once full.
            if ngram in self.cache:
                return
            if len(self.cache) == self.max_size:
                self.cache.popleft()
            self.cache.append(ngram)

        def __iter__(self):
            return iter(self.cache)

        def __str__(self):
            return str(self.cache)

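    # Illustrative sketch, not part of the diff: the cache keeps at most
    # max_size candidate continuations per key and evicts FIFO.
    #
    #   cache = InputManager.NGramCache(max_size=2)
    #   cache.add([5, 9])  # deque([[5, 9]])
    #   cache.add([5, 9])  # duplicate, ignored
    #   cache.add([7, 2])  # deque([[5, 9], [7, 2]])
    #   cache.add([1, 4])  # full; evicts [5, 9] -> deque([[7, 2], [1, 4]])
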
    def __init__(
        self,
        n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
        dtype=torch.float16,
        minus_infinity=-torch.inf,
        cache_size=None,
        lookahead_enabled: bool = False,
    ):
        if cache_size is None:
            cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

        self.seq_length = seq_length
        self.use_cache_list = use_cache_list
        self.lookahead_enabled = lookahead_enabled
        self.minus_infinity = minus_infinity

        if self.use_cache_list:
            self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
        if self.cache_pos == self.cache_size:
            self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
        # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
        if self.input_pos <= self.cache_size:
            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
            )
        if remaining_to_copy > 0:
            self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
            )

        self.input_pos += input_length
@@ -661,3 +689,270 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
            self.get_inputs(tokens[0:processed_tokens]),
            tokens[processed_tokens:],
        )

    def _get_lookahead_decoding_mask(
        self, ngram_size: int, window_size: int, n_verifications: int
    ) -> torch.Tensor:
        """
        Build the lookahead attention mask: causal lookahead windows plus an
        independent causal block per verification branch, all anchored to the
        current token at position 0.
        """
        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
        mask[0][0] = 0.0

        lookahead_submask = torch.triu(
            torch.full((window_size, window_size), self.minus_infinity),
            diagonal=1,
        )
        for i in range(ngram_size - 1):
            offset = window_size * i
            mask[offset : offset + window_size, :window_size] = lookahead_submask
            for j in range(1, i + 1):
                mask[
                    offset : offset + window_size,
                    window_size * j : window_size * (j + 1),
                ].fill_diagonal_(0.0)

        verification_offset = max(window_size * (ngram_size - 1), 1)
        verification_submask = torch.triu(
            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
            diagonal=1,
        )
        for i in range(n_verifications):
            mask[
                verification_offset
                + i * (ngram_size - 1) : verification_offset
                + (i + 1) * (ngram_size - 1),
                verification_offset
                + i * (ngram_size - 1) : verification_offset
                + (i + 1) * (ngram_size - 1),
            ] = verification_submask
        mask[verification_offset:, :1] = 0.0

        return mask

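    # Illustrative sketch, not part of the diff: for ngram_size=3,
    # window_size=3, n_verifications=2 (rows 0-5 are two lookahead windows,
    # rows 6-9 are two verification branches), the mask admits
    # ("X" = attend, "." = masked):
    #
    #   row 0: X . . . . . . . . .
    #   row 1: X X . . . . . . . .
    #   row 2: X X X . . . . . . .
    #   row 3: X . . X . . . . . .
    #   row 4: X X . . X . . . . .
    #   row 5: X X X . . X . . . .
    #   row 6: X . . . . . X . . .
    #   row 7: X . . . . . X X . .
    #   row 8: X . . . . . . . X .
    #   row 9: X . . . . . . . X X
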
    def _get_lookahead_position_offsets(
        self, ngram_size: int, window_size: int, n_verifications: int
    ) -> torch.Tensor:
        """
        Compute per-slot position offsets, added to input_pos for RoPE.
        """
        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
        idx = 0
        if window_size > 0:
            for i in range(ngram_size - 1):
                for j in range(window_size):
                    pos_offsets[idx] = i + j
                    idx += 1
        else:
            pos_offsets[0] = 0
            idx += 1

        # Verification branches: [1, 2, ..., ngram_size - 1].
        for _ in range(n_verifications):
            for j in range(1, ngram_size):
                pos_offsets[idx] = j
                idx += 1

        return pos_offsets

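    # Illustrative sketch, not part of the diff: for ngram_size=3,
    # window_size=3, n_verifications=2, the offsets are
    # [0, 1, 2, 1, 2, 3, 1, 2, 1, 2]: each lookahead window is shifted one
    # position deeper, and every verification branch restarts at offset 1,
    # immediately after the current token.
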
    def _validate_lookahead_config(
        self, ngram_size: int, window_size: int, n_verifications: int
    ) -> None:
        """
        Validate the lookahead decoding configuration.
        """
        if not self.lookahead_enabled:
            raise RuntimeError("Lookahead decoding is not enabled")

        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
            raise RuntimeError(
                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
            )

    def _setup_lookahead_mask(
        self, ngram_size: int, window_size: int, n_verifications: int
    ) -> None:
        """
        Set up the attention mask for lookahead decoding and log debug information.
        """
        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
            ngram_size, window_size, n_verifications
        )
        logger.debug("Lookahead decoding mask: ")
        for i in range(self.seq_length):
            logger.debug(
                " ".join(
                    ("X" if x == 0.0 else " ")
                    for x in self.attn_mask[i][self.cache_size :]
                )
            )

    def _populate_verification_branches(
        self, x: List[int], cache, verification_offset: int, ngram_size: int
    ) -> None:
        """
        Populate verification branches with tokens from the n-gram cache.
        """
        for i, ngram in enumerate(cache):
            for j, token in enumerate(ngram):
                x[verification_offset + i * (ngram_size - 1) + j] = token

    def _collect_ngrams(
        self,
        x: List[int],
        y: List[int],
        ngram_caches: Dict[int, "InputManager.NGramCache"],
        window_size: int,
        ngram_size: int,
    ) -> None:
        """
        Collect new n-grams from the current state and predictions.
        """
        for i in range(window_size):
            key = x[i]
            suffix = []
            for j in range(1, ngram_size - 1):
                suffix.append(x[i + j * window_size])
            suffix.append(y[i + window_size * (ngram_size - 2)])
            ngram_caches[key].add(suffix)

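    # Illustrative sketch, not part of the diff: with ngram_size=3 and
    # window_size=2, window slot i contributes the 2-token suffix
    # [x[i + window_size], y[i + window_size]] keyed by x[i], pairing each
    # window token with its deeper lookahead token and newest prediction.
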
    def _find_longest_match(
        self,
        x: List[int],
        y: List[int],
        verification_offset: int,
        n_verifications: int,
        ngram_size: int,
    ) -> Tuple[List[int], Optional[int]]:
        """
        Find the longest matching sequence from verification branches.
        Returns the matched tokens and the branch index.
        """
        longest_match = []
        matched_branch = None

        for i in range(n_verifications):
            match = [y[0]]
            j = 0
            while (
                j < ngram_size - 1
                and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
            ):
                match.append(y[verification_offset + (ngram_size - 1) * i + j])
                j += 1
            if len(match) - 1 > len(longest_match):
                longest_match = match[1:]
                matched_branch = i

        return longest_match, matched_branch

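    # Note, not part of the diff: a fully matched branch yields up to
    # ngram_size - 1 extra tokens, so one forward pass can emit as many as
    # ngram_size tokens in total (the greedy token plus the matched suffix).
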
    def _update_lookahead_branches(
        self, x: List[int], y: List[int], ngram_size: int, window_size: int
    ) -> None:
        """
        Update the lookahead branches with new predictions.
        """
        # Shift window contents up
        for i in range(ngram_size - 2):
            for j in range(window_size):
                x[window_size * i + j] = x[window_size * (i + 1) + j]

        # Fill the last window with new predictions
        for j in range(window_size):
            x[window_size * (ngram_size - 2) + j] = y[
                window_size * (ngram_size - 2) + j
            ]

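    # Illustrative sketch, not part of the diff: with ngram_size=3 and
    # window_size=2, x = [w1a, w1b, w2a, w2b, ...] becomes
    # [w2a, w2b, y[2], y[3], ...]: each window shifts up one level and the
    # last window is refilled from the model's latest predictions.
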
    def lookahead_decode(
        self,
        model,
        init_token: int,
        n: int,
        ngram_size: int,
        window_size: int,
        n_verifications: int,
        stop_tokens: Optional[List[int]] = None,
        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
    ) -> List[int]:
        """
        Generate n (or slightly more) tokens with lookahead decoding, emitting
        up to ngram_size tokens per forward pass by verifying cached n-grams.
        """
        # Validate configuration
        self._validate_lookahead_config(ngram_size, window_size, n_verifications)

        # Set up attention mask and position offsets
        self._setup_lookahead_mask(ngram_size, window_size, n_verifications)
        offsets = self._get_lookahead_position_offsets(
            ngram_size, window_size, n_verifications
        )

        # Initialize state
        stop_tokens = stop_tokens or []
        verification_offset = window_size * (ngram_size - 1)
        if ngram_caches is None:
            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))

        new_tokens = [init_token]
        x = [init_token] * self.seq_length
        inference_count = 0

        # Main decoding loop
        while len(new_tokens) < n + 1:
            # Populate verification branches
            cache = ngram_caches[x[0]]
            self._populate_verification_branches(
                x, cache, verification_offset, ngram_size
            )

            # Run model inference
            logits, new_k, new_v = model(
                tokens=torch.tensor([x], dtype=torch.int64),
                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
                k_caches=self.k_caches,
                v_caches=self.v_caches,
                attn_mask=self.attn_mask,
                input_len=torch.tensor([len(x)], dtype=torch.long),
                rope_indices=self.input_pos + offsets,
            )
            inference_count += 1

            # Process model output (greedy selection)
            y = logits[0].argmax(dim=-1).tolist()
            new_tokens.append(y[0])
            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
            if new_tokens[-1] in stop_tokens:
                break

            # Collect new n-grams
            self._collect_ngrams(x, y, ngram_caches, window_size, ngram_size)

            # Find the longest match from the verification branches
            longest_match, matched_branch = self._find_longest_match(
                x, y, verification_offset, n_verifications, ngram_size
            )

            # Process match results
            if matched_branch is not None:
                logger.debug(
                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
                )
                # Truncate at stop token if present
                for stop in stop_tokens:
                    if stop in longest_match:
                        longest_match = longest_match[: longest_match.index(stop) + 1]

                new_tokens.extend(longest_match)
                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
                self.update(
                    input_length=len(longest_match),
                    new_k_caches=new_k,
                    new_v_caches=new_v,
                    update_pos=branch_offset,
                )
            else:
                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)

            # Update lookahead branches
            self._update_lookahead_branches(x, y, ngram_size, window_size)

            # Update the first token and check for stop condition
            x[0] = new_tokens[-1]
            if new_tokens[-1] in stop_tokens:
                break

        logger.info(
            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
        )
        return new_tokens
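
For context, a minimal usage sketch (hypothetical, not part of this diff; `prompt_tokens` and `eos_token_id` are assumptions, and the constructor call elides parameters hidden by the collapsed hunks above):

    # Build an InputManager with lookahead enabled; the hunks above elide
    # several required constructor parameters, shown here as "...".
    mgr = InputManager(
        n_layers=...,
        max_seq_length=...,
        seq_length=...,
        use_cache_list=True,
        lookahead_enabled=True,
    )
    # Prefill the prompt first (e.g. via get_inputs_and_remaining_tokens and
    # update), then speculate from the last prompt token.
    tokens = mgr.lookahead_decode(
        model,
        init_token=prompt_tokens[-1],
        n=128,                 # number of tokens to generate
        ngram_size=3,          # length of candidate n-grams
        window_size=8,         # slots per lookahead level
        n_verifications=8,     # candidate branches checked per step
        stop_tokens=[eos_token_id],
    )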