@@ -36,9 +36,43 @@ def __init__(self, *args, **kwargs):
         self.generate_function = None

     def build_cache(self, batch_size, max_length):
+        """Builds an empty cache for use with `call_with_cache`.
+
+        Args:
+            batch_size: int. The size of the batch for generation.
+            max_length: int. The maximum sequence length for the cache.
+
+        Returns:
+            A cache Tensor; the exact shape will depend on the model.
+        """
         raise NotImplementedError

     def call_with_cache(self, token_ids, cache, index):
+        """Forward pass with cache for generation.
+
+        `call_with_cache` adds an additional forward pass to the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method allows caching previous key/value results in the multi-head
+        attention layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, n)`, where
+                `n` is some sequence length less than or equal to the max
+                length of the cache. Usually `n` is either the full cache
+                length, to "prefill" the prompt cache values, or `1`, to
+                predict a single token id.
+            cache: a dense float Tensor. The cache of key and value
+                projections used in the attention layers of the model. The
+                exact shape will depend on the model.
+            index: int, or int Tensor. The index of the first token of
+                `token_ids` in the entire generated sequence.
+
+        Returns:
+            A `(logits, hidden_states, cache)` tuple, where `logits` is the
+            language model logits for the input `token_ids`, `hidden_states`
+            is the final hidden representation of the input tokens, and
+            `cache` is the updated decoding cache.
+        """
         raise NotImplementedError

     def compile(
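
For a concrete sense of the `build_cache` contract, here is a minimal sketch (not this commit's implementation), assuming a common key/value cache layout of `(batch_size, num_layers, 2, max_length, num_heads, head_dim)`, where axis 2 holds the key and value projections. The model dimensions are hypothetical placeholders.

```python
import numpy as np

# Hypothetical model dimensions, for illustration only.
NUM_LAYERS = 2
NUM_HEADS = 4
HEAD_DIM = 8

def build_cache(batch_size, max_length):
    # One key and one value tensor per attention layer (axis 2),
    # pre-allocated to the full generation length and zero-filled
    # before the prefill pass populates it.
    shape = (batch_size, NUM_LAYERS, 2, max_length, NUM_HEADS, HEAD_DIM)
    return np.zeros(shape, dtype="float32")
```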
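And here is a sketch of the prefill-then-decode loop that these two methods enable, with a toy stand-in model so the control flow is runnable. The `ToyCausalLM` internals and the greedy sampling are illustrative assumptions, not the library's actual `generate` implementation.

```python
import numpy as np

VOCAB_SIZE = 100

class ToyCausalLM:
    """Hypothetical stand-in for a subclass implementing both methods.

    It skips real attention and just records token ids in a flat cache,
    returning random logits, so the loop below can actually run.
    """

    def build_cache(self, batch_size, max_length):
        return np.zeros((batch_size, max_length), dtype="float32")

    def call_with_cache(self, token_ids, cache, index):
        batch_size, n = token_ids.shape
        # Write the new tokens at their absolute positions in the cache.
        cache[:, index : index + n] = token_ids
        logits = np.random.rand(batch_size, n, VOCAB_SIZE)
        hidden_states = np.random.rand(batch_size, n, 8)
        return logits, hidden_states, cache


def generate(model, prompt_ids, max_length):
    batch_size, prompt_length = prompt_ids.shape
    cache = model.build_cache(batch_size, max_length)
    # Prefill: one pass over the full prompt, starting at index 0.
    logits, _, cache = model.call_with_cache(prompt_ids, cache, index=0)
    token_ids = prompt_ids
    for index in range(prompt_length, max_length):
        # Decode: feed only the newest token; cached key/values cover
        # everything before `index`, so seen tokens are not recomputed.
        next_token = logits[:, -1, :].argmax(axis=-1, keepdims=True)
        logits, _, cache = model.call_with_cache(next_token, cache, index=index)
        token_ids = np.concatenate([token_ids, next_token], axis=-1)
    return token_ids


print(generate(ToyCausalLM(), np.array([[5, 8, 13]]), max_length=8))
```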