diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py
index a553fcc0d8b..ae98c327b45 100644
--- a/examples/apple/coreml/llama/llama_transformer.py
+++ b/examples/apple/coreml/llama/llama_transformer.py
@@ -6,6 +6,8 @@
 
 # Please refer to README.md in the same folder for more information.
 
+import logging
+from collections import defaultdict, deque
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@
 from torch import nn
 
+logger = logging.getLogger(__name__)
+
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
 
 
 class InputManager:
+    class NGramCache:
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
         max_batch_size: int,
         n_kv_heads: int,
         max_seq_length: int,
         head_dim: int,
         use_cache_list: bool,
         seq_length: int,
         dtype=torch.float16,
         minus_infinity=-torch.inf,
         cache_size=None,
+        lookahead_enabled: bool = False,
     ):
         if cache_size is None:
             cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(
         self.seq_length = seq_length
         self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity
 
         if self.use_cache_list:
             self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
         if self.cache_pos == self.cache_size:
             self.cache_pos = 0
 
-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
         # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
         if self.input_pos <= self.cache_size:
             self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                 0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
             )
         if remaining_to_copy > 0:
             self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
             )
 
         self.input_pos += input_length
@@ -661,3 +689,270 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
             self.get_inputs(tokens[0:processed_tokens]),
             tokens[processed_tokens:],
         )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
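+        # Verification branches sit after the (ngram_size - 1) lookahead windows;
+        # each branch gets its own causal (ngram_size - 1) x (ngram_size - 1) block.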
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def _validate_lookahead_config(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Validate the lookahead decoding configuration.
+        """
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+    def _setup_lookahead_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Set up the attention mask for lookahead decoding and log debug information.
+        """
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+    def _populate_verification_branches(
+        self, x: List[int], cache, verification_offset: int, ngram_size: int
+    ) -> None:
+        """
+        Populate verification branches with tokens from the n-gram cache.
+        """
+        for i, ngram in enumerate(cache):
+            for j, token in enumerate(ngram):
+                x[verification_offset + i * (ngram_size - 1) + j] = token
+
+    def _collect_ngrams(
+        self,
+        x: List[int],
+        y: List[int],
+        ngram_caches: Dict[int, "InputManager.NGramCache"],
+        window_size: int,
+        ngram_size: int,
+    ) -> None:
+        """
+        Collect new n-grams from the current state and predictions.
+        """
+        for i in range(window_size):
+            key = x[i]
+            suffix = []
+            for j in range(1, ngram_size - 1):
+                suffix.append(x[i + j * window_size])
+            suffix.append(y[i + window_size * (ngram_size - 2)])
+            ngram_caches[key].add(suffix)
+
+    def _find_longest_match(
+        self,
+        x: List[int],
+        y: List[int],
+        verification_offset: int,
+        n_verifications: int,
+        ngram_size: int,
+    ) -> Tuple[List[int], Optional[int]]:
+        """
+        Find the longest matching sequence from verification branches.
+        Returns the matched tokens and the branch index.
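+        If no verification branch extends the prediction, an empty list and
+        None are returned.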
+        """
+        longest_match = []
+        matched_branch = None
+
+        for i in range(n_verifications):
+            match = [y[0]]
+            j = 0
+            while (
+                j < ngram_size - 1
+                and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+            ):
+                match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                j += 1
+            if len(match) - 1 > len(longest_match):
+                longest_match = match[1:]
+                matched_branch = i
+
+        return longest_match, matched_branch
+
+    def _update_lookahead_branches(
+        self, x: List[int], y: List[int], ngram_size: int, window_size: int
+    ) -> None:
+        """
+        Update the lookahead branches with new predictions.
+        """
+        # Shift window contents up
+        for i in range(ngram_size - 2):
+            for j in range(window_size):
+                x[window_size * i + j] = x[window_size * (i + 1) + j]
+
+        # Fill the last window with new predictions
+        for j in range(window_size):
+            x[window_size * (ngram_size - 2) + j] = y[
+                window_size * (ngram_size - 2) + j
+            ]
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        # Validate configuration
+        self._validate_lookahead_config(ngram_size, window_size, n_verifications)
+
+        # Setup attention mask and position offsets
+        self._setup_lookahead_mask(ngram_size, window_size, n_verifications)
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        # Initialize state
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        # Main decoding loop
+        while len(new_tokens) < n + 1:
+            # Populate verification branches
+            cache = ngram_caches[x[0]]
+            self._populate_verification_branches(
+                x, cache, verification_offset, ngram_size
+            )
+
+            # Run model inference
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Process model output (greedy selection)
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams
+            self._collect_ngrams(x, y, ngram_caches, window_size, ngram_size)
+
+            # Find longest match from verification branches
+            longest_match, matched_branch = self._find_longest_match(
+                x, y, verification_offset, n_verifications, ngram_size
+            )
+
+            # Process match results
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                # Truncate at stop token if present
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branches
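+            # (shift every lookahead window up one level and refill the last
+            # window with this step's greedy predictions)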
+            self._update_lookahead_branches(x, y, ngram_size, window_size)
+
+            # Update first token and check for stop condition
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
diff --git a/examples/apple/coreml/llama/run_lookahead.py b/examples/apple/coreml/llama/run_lookahead.py
new file mode 100644
index 00000000000..1d48c2b07e8
--- /dev/null
+++ b/examples/apple/coreml/llama/run_lookahead.py
@@ -0,0 +1,284 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+from collections import defaultdict
+
+import sentencepiece as spm
+
+import torch
+from executorch.examples.apple.coreml.llama.llama_transformer import (
+    InputManager,
+    load_model,
+)
+
+from executorch.examples.models.llama.runner.generation import next_token
+from executorch.examples.models.llama.tokenizer import tiktoken
+
+from executorch.runtime import Runtime
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # Try sentence piece
+        try:
+            print("Trying to load sentencepiece")
+            sp = spm.SentencePieceProcessor()
+            sp.load(model_path)
+            self.tokenizer = sp
+        except:
+            print("Trying to load tiktoken")
+            self.tokenizer = tiktoken.Tokenizer(model_path)
+
+    def encode(self, text, bos, eos):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            bos_string = "<s>" if bos else ""
+            eos_string = "</s>" if eos else ""
+            return self.tokenizer.encode(f"{bos_string}{text}{eos_string}")
+        return self.tokenizer.encode(text, bos=bos, eos=eos)
+
+    def decode(self, tokens):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            return self.tokenizer.decode(tokens)
+        return self.tokenizer.decode(tokens)
+
+    def decode_token(self, token):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            return f"{self.tokenizer.decode([token])} "
+        return self.tokenizer.decode_token(token)
+
+    def stop_tokens(self):
+        if isinstance(self.tokenizer, spm.SentencePieceProcessor):
+            return [self.tokenizer.eos_id()]
+        return self.tokenizer.stop_tokens
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model",
+        help="model.pte",
+    )
+    parser.add_argument(
+        "-t",
+        "--tokenizer",
+        help="tokenizer.model path",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Once upon a time,",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.6,
+    )
+    parser.add_argument(
+        "--top_p",
+        type=float,
+        default=0.9,
+    )
+    parser.add_argument(
+        "--use_eager",
+        action="store_true",
+    )
+    parser.add_argument(
+        "-p",
+        "--params",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint",
+        type=str,
+        default=None,
+    )
+    parser.add_argument("--dtype", type=str, choices=["fp16", "fp32"], default=None)
+    parser.add_argument(
+        "--seq_length",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--cache_size",
+        type=int,
+        default=None,
+    )
+    # Lookahead decoding parameters
+    parser.add_argument(
+        "--ngram_size",
+        type=int,
+        default=3,
+        help="Size of ngrams for lookahead decoding",
+    )
+    parser.add_argument(
+        "--window_size",
+        type=int,
+        default=4,
+        help="Window size for lookahead decoding",
+    )
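+    # Note: lookahead decoding requires
+    # (ngram_size - 1) * (window_size + n_verifications) <= seq_length of the
+    # exported model; InputManager._validate_lookahead_config checks this.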
+    parser.add_argument(
+        "--n_verifications",
+        type=int,
+        default=4,
+        help="Number of verifications for lookahead decoding",
+    )
+    parser.add_argument(
+        "--ngrams_seed",
+        type=str,
+        default=None,
+        help="Seed for ngrams cache in lookahead decoding",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=32,
+        help="Maximum number of tokens to generate",
+    )
+
+    args = parser.parse_args()
+
+    tokenizer = Tokenizer(args.tokenizer)
+
+    runtime = Runtime.get()
+    if args.use_eager:
+        assert args.params is not None
+        assert args.checkpoint is not None
+        assert args.dtype is not None
+        assert args.max_seq_length is not None
+        assert args.seq_length is not None
+
+        max_seq_length = args.max_seq_length
+        seq_length = args.seq_length
+        model = load_model(
+            args.checkpoint,
+            args.params,
+            max_seq_length=max_seq_length,
+            use_cache_list=True,
+        )
+        n_layers = model.params.n_layers
+        max_batch_size = model.params.max_batch_size
+        n_kv_heads = model.params.n_kv_heads
+        head_dim = model.params.head_dim
+        cache_size = args.cache_size
+
+        float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
+        model.eval()
+        model.to(float_dtype)
+    else:
+        program = runtime.load_program(args.model)
+        method = program.load_method("forward")
+
+        metadata = method.metadata
+        print("Method metadata: ", metadata, "\n\n")
+
+        assert (
+            metadata.num_inputs() == 6
+        ), "Do not export with --use_cache_list for use in pybindings"
+        n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
+            metadata.input_tensor_meta(3).sizes()
+        )
+        float_dtype = {5: torch.float16, 6: torch.float32}[
+            metadata.input_tensor_meta(3).dtype()
+        ]
+
+        seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()
+
+    input_manager = InputManager(
+        n_layers=n_layers,
+        max_batch_size=max_batch_size,
+        n_kv_heads=n_kv_heads,
+        max_seq_length=max_seq_length,
+        head_dim=head_dim,
+        use_cache_list=True,
+        seq_length=seq_length,
+        dtype=float_dtype,
+        minus_infinity=-30000.0,
+        cache_size=cache_size,
+        lookahead_enabled=True,
+    )
+
+    print(f"Prompt: {args.prompt}")
+    tokens = tokenizer.encode(args.prompt, bos=True, eos=False)
+    logits = None
+
+    while len(tokens) > 0 and (input_manager.input_pos + seq_length < max_seq_length):
+        inputs, remaining_tokens = input_manager.get_inputs_and_remaining_tokens(tokens)
+        processed_tokens = len(tokens) - len(remaining_tokens)
+
+        if args.use_eager:
+            model_inputs = (
+                inputs[0],  # tokens
+                inputs[1],  # input_pos
+                inputs[3],  # k_caches
+                inputs[4],  # v_caches
+                inputs[5],  # attn_mask
+                inputs[2],  # input_length
+            )
+            logits, k, v = model(*model_inputs)
+        else:
+            logits, k, v = method.execute(inputs)
+
+        input_manager.update(
+            input_length=processed_tokens, new_k_caches=k, new_v_caches=v
+        )
+        tokens = remaining_tokens
+
+    ngram_caches = None
+    if args.ngrams_seed is not None:
+        ngram_caches = defaultdict(
+            lambda: InputManager.NGramCache(args.n_verifications)
+        )
+        seed_tokens = tokenizer.encode(args.ngrams_seed, bos=False, eos=False)
+        for i in range(len(seed_tokens) - args.ngram_size + 1):
+            key = seed_tokens[i]
+            suffix = seed_tokens[i + 1 : i + args.ngram_size]
+            ngram_caches[key].add(suffix)
+
+    if input_manager.input_pos < max_seq_length and logits is not None:
+        last_token_logits = logits[0, processed_tokens - 1, :]
+        init_token = next_token(last_token_logits.unsqueeze(0), 0, 0)
+
+        print("\nGenerating with lookahead decoding...")
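+        # Both paths share InputManager.lookahead_decode; the non-eager path
+        # passes a thin wrapper around method.execute in place of the eager model.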
+        if args.use_eager:
+            new_tokens = input_manager.lookahead_decode(
+                model=model,
+                init_token=init_token,
+                n=args.max_tokens,
+                ngram_size=args.ngram_size,
+                window_size=args.window_size,
+                n_verifications=args.n_verifications,
+                stop_tokens=tokenizer.stop_tokens(),
+                ngram_caches=ngram_caches,
+            )
+        else:
+            new_tokens = input_manager.lookahead_decode(
+                model=lambda *inputs: method.execute(inputs),
+                init_token=init_token,
+                n=args.max_tokens,
+                ngram_size=args.ngram_size,
+                window_size=args.window_size,
+                n_verifications=args.n_verifications,
+                stop_tokens=tokenizer.stop_tokens(),
+                ngram_caches=ngram_caches,
+            )
+
+        print("\nGenerated text:")
+        print(tokenizer.decode(new_tokens))
+    else:
+        print("Failed to generate text")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index ad69f159e7c..f788b8f5032 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -306,3 +306,15 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
         freqs_cos = self.freqs_cos[:seq_len]
         freqs_sin = self.freqs_sin[:seq_len]
         return freqs_cos, freqs_sin
+
+    def get_freqs_using_indices(self, indices: torch.Tensor):
+        """
+        Get the precomputed frequencies for the given input indices.
+
+        Args:
+            indices (torch.Tensor): The input indices tensor.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input indices.
+        """
+        return self.freqs_cos[indices], self.freqs_sin[indices]
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 03a9289924e..43f1f2de374 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -418,6 +418,8 @@ def lookahead_decode(  # noqa: C901
                 ]
 
             x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
 
         logger.info(
             f"Generated {len(new_tokens) - 1} tokens with {inference_cnt} inference(s)."