Skip to content

Commit f96b5f2

Browse files
committed
docs update
1 parent d19adad commit f96b5f2

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

keras_nlp/models/causal_lm.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,43 @@ def __init__(self, *args, **kwargs):
3636
self.generate_function = None
3737

3838
def build_cache(self, batch_size, max_length):
39+
"""Builds an empty cache for use with `call_with_cache`.
40+
41+
Args:
42+
batch_size: int. The size of the batch for generation.
43+
max_length: int. The maximum sequence length for the cache.
44+
45+
Returns:
46+
A cache Tensor, the exact shape will depend on the model.
47+
"""
3948
raise NotImplementedError
4049

4150
def call_with_cache(self, token_ids, cache, index):
51+
"""Forward pass with cache for generation.
52+
53+
`call_with_cache` adds an additional forward pass for the model for
54+
autoregressive inference. Unlike calling the model directly, this method
55+
allows caching previous key/value results in multi-head attention layer,
56+
and avoids recomputing the outputs of seen tokens.
57+
58+
Args:
59+
token_ids: a dense int Tensor with shape `(batch_size, n)`, where
60+
`n` is some sequence length less than or equal to the max
61+
length of the cache. Usually `n` is either the full cache
62+
length, to "prefill" the prompt cache values, or `1`, to predict
63+
single token id.
64+
cache: a dense float Tensor. The cache of key and value projections
65+
used in the attention layers of the model. The exact shape will
66+
depend on the model.
67+
index: int, or int Tensor. The index of the first token of
68+
`token_ids` in the entire generated sequence.
69+
70+
Returns:
71+
A `(logits, hidden_states, cache)` tuple. Where `logits` is the
72+
language model logits for the input token_ids, `hidden_states` is
73+
the final hidden representation of the input tokens, and `cache` is
74+
the updated decoding cache.
75+
"""
4276
raise NotImplementedError
4377

4478
def compile(

0 commit comments

Comments
 (0)