Commit 071f352

Generation refactor
1 parent 93e175f commit 071f352

25 files changed, +735 -1555 lines changed

keras_nlp/src/layers/modeling/transformer_decoder.py

Lines changed: 25 additions & 1 deletion
@@ -251,6 +251,28 @@ def build(
         # Create layers based on input shape.
         self.built = True
 
+    def compute_self_attention_cache(
+        self,
+        decoder_sequence,
+    ):
+        x = decoder_sequence
+        if self.normalize_first:
+            x = self._self_attention_layer_norm(x)
+        key = self._self_attention_layer._key_dense(x)
+        value = self._self_attention_layer._value_dense(x)
+        return ops.stack((key, value), axis=1)
+
+    def compute_cross_attention_cache(
+        self,
+        encoder_sequence,
+    ):
+        x = encoder_sequence
+        if self.normalize_first:
+            x = self._cross_attention_layer_norm(x)
+        key = self._cross_attention_layer._key_dense(x)
+        value = self._cross_attention_layer._value_dense(x)
+        return ops.stack((key, value), axis=1)
+
     def call(
         self,
         decoder_sequence,
@@ -314,7 +336,9 @@ def call(
            the layer has cross-attention.
        """
 
-        has_encoder_sequence = encoder_sequence is not None
+        has_encoder_sequence = (
+            encoder_sequence is not None or cross_attention_cache is not None
+        )
 
         has_cross_attention = self._cross_attention_layer is not None
         if not has_cross_attention and has_encoder_sequence:
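Note: the two new `compute_*_cache` methods project a sequence through the layer's key/value dense sublayers and stack the result along axis 1, so a caller can seed a key/value cache without running a full decoder step. A minimal sketch of the intended usage, assuming the public `keras_nlp.layers.TransformerDecoder` export and made-up shapes (the layer sizes, tensors, and variable names below are illustrative, not part of this commit):

import numpy as np
from keras_nlp.layers import TransformerDecoder

# Hypothetical shapes: batch=2, source length=16, target length=8, hidden dim=32.
decoder_layer = TransformerDecoder(intermediate_dim=64, num_heads=4)
encoder_hidden_states = np.random.uniform(size=(2, 16, 32)).astype("float32")
decoder_inputs = np.random.uniform(size=(2, 8, 32)).astype("float32")

# Call once with an encoder sequence so the cross-attention sublayers get built.
decoder_layer(decoder_inputs, encoder_hidden_states)

# Seed the cross-attention cache a single time from the encoder output. The
# result has shape (batch, 2, source_length, num_heads, head_dim), where axis 1
# stacks the key and value projections.
cross_cache = decoder_layer.compute_cross_attention_cache(encoder_hidden_states)

Combined with the relaxed `has_encoder_sequence` check above, this should let later calls pass `cross_attention_cache=cross_cache` with `encoder_sequence=None`, so the encoder projections are computed once per source sequence instead of once per decoding step.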

keras_nlp/src/layers/preprocessing/start_end_packer.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def call(
         outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
 
         if self.return_padding_mask:
-            mask = tf.ones_like(x, dtype="bool")
+            mask = tf.ones_like(x, dtype="int32")
             mask = mask.to_tensor(shape=(batch_size, sequence_length))
             mask = tf.squeeze(mask, axis=0) if unbatched else mask
             return outputs, mask
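Note: with this change the padding mask is built as `int32` before densifying, so `StartEndPacker` now returns an integer-valued mask in which padded positions come out as 0 rather than False. A small standalone sketch of the behavior in plain TensorFlow (the values below are illustrative, not part of the commit):

import tensorflow as tf

# Stand-in for the ragged `x` inside `StartEndPacker.call()`.
x = tf.ragged.constant([[5, 6, 7], [8, 9]])
batch_size, sequence_length = 2, 4

mask = tf.ones_like(x, dtype="int32")
mask = mask.to_tensor(shape=(batch_size, sequence_length))
print(mask.numpy())
# [[1 1 1 0]
#  [1 1 0 0]]

# Downstream code that expects a boolean mask can cast after padding.
print(tf.cast(mask, "bool").numpy())

Any caller that relied on the previous boolean dtype may need a cast like the one above.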

keras_nlp/src/models/bart/bart_backbone.py

Lines changed: 0 additions & 1 deletion
@@ -257,5 +257,4 @@ def get_config(self):
                 "max_sequence_length": self.max_sequence_length,
             }
         )
-
         return config

keras_nlp/src/models/bart/bart_seq_2_seq_lm.py

Lines changed: 43 additions & 267 deletions
@@ -21,7 +21,6 @@
     BartSeq2SeqLMPreprocessor,
 )
 from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM
-from keras_nlp.src.utils.tensor_utils import any_equal
 
 
 @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -200,291 +199,68 @@ def __init__(
             **kwargs,
         )
 
-    def call_decoder_with_cache(
+    def build_cache(self, batch_size, max_length):
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        return ops.zeros(shape, dtype=self.compute_dtype)
+
+    def compute_cross_attention_cache(
+        self, encoder_token_ids, encoder_padding_mask
+    ):
+        """Does a forward pass on the encoder and returns the cross-attention cache."""
+        # Embedding layers.
+        tokens = self.backbone.token_embedding(encoder_token_ids)
+        positions = self.backbone.encoder_position_embedding(tokens)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.backbone.encoder_embeddings_add((tokens, positions))
+        x = self.backbone.encoder_embeddings_layer_norm(x)
+        x = self.backbone.encoder_embeddings_dropout(x)
+        # Transformer encoder layers.
+        for layer in self.backbone.encoder_transformer_layers:
+            x = layer(x, padding_mask=encoder_padding_mask)
+        # Compute the cross-attention cache for each decoder layer.
+        caches = []
+        for layer in self.backbone.decoder_transformer_layers:
+            caches.append(layer.compute_cross_attention_cache(x))
+        return ops.stack(caches, axis=1)
+
def call_with_cache(
204230
self,
205-
encoder_hidden_states,
231+
token_ids,
232+
cache,
233+
index,
234+
*,
206235
encoder_padding_mask,
207-
decoder_token_ids,
208-
self_attention_cache=None,
209-
self_attention_cache_update_index=None,
210-
cross_attention_cache=None,
211-
cross_attention_cache_update_index=None,
236+
cross_attention_cache,
212237
):
213-
"""Forward pass with a key/value caches for generative decoding..
214-
215-
`call_decoder_with_cache` adds an additional inference-time forward pass
216-
for the model for seq2seq text generation. Unlike calling the model
217-
directly, this method does two things to optimize text generation:
218-
219-
- Allows caching previous key/value tensors in the decoder's
220-
self-attention layer to avoid recomputing the outputs of seen tokens.
221-
- Allows caching key/value tensors in the decoder's cross-attention
222-
layer to avoid recomputing the encoder outputs.
223-
224-
Args:
225-
encoder_hidden_states: a dense float Tensor of shape
226-
`(batch_size, encoder_sequence_length, hidden_dim)`. The
227-
sequence of hidden states at the output of the encoder's last
228-
layer.
229-
encoder_padding_mask: a dense float Tensor of shape
230-
`(batch_size, encoder_sequence_length)`. The padding mask for
231-
the encoder input.
232-
decoder_token_ids: a dense int Tensor of shape
233-
`(batch_size, max_length)`. Input token ids to be fed to
234-
the decoder.
235-
self_attention_cache: a dense float Tensor of shape
236-
`(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
237-
The cached key/value tensors of previously seen tokens in the
238-
decoder's self-attention layer.
239-
self_attention_cache_update_index: an int or int Tensor, the index
240-
at which to update the `self_attention_cache`. Usually, this is
241-
the index of the current token being processed during decoding.
242-
cross_attention_cache: a dense float Tensor of shape
243-
`(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
244-
The cached key/value tensors of the encoder outputs in the
245-
decoder's cross-attention layer.
246-
cross_attention_cache_update_index: an int or int Tensor, the index
247-
at which to update the `cross_attention_cache`. Usually, this is
248-
either `0` (compute the entire `cross_attention_cache`), or
249-
`None` (reuse a previously computed `cross_attention_cache`).
250-
251-
Returns:
252-
A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
253-
tuple, where `logits` is the language model logits for the input
254-
`decoder_token_ids`, `hidden_states` is the final hidden
255-
representation of the input tokens, `self_attention_cache` is the
256-
key/value cache in the decoder's self-attention layer and
257-
`cross_attention_cache` is the key/value cache in the decoder's
258-
cross-attention layer.
259-
"""
260-
# Embedding layers.
261-
tokens = self.backbone.token_embedding(decoder_token_ids)
238+
tokens = self.backbone.token_embedding(token_ids)
262239
positions = self.backbone.decoder_position_embedding(
263-
tokens,
264-
start_index=self_attention_cache_update_index,
240+
tokens, start_index=index
265241
)
266242
# Sum, normalize and apply dropout to embeddings.
267243
x = self.backbone.decoder_embeddings_add((tokens, positions))
268244
x = self.backbone.decoder_embeddings_layer_norm(x)
269245
x = self.backbone.decoder_embeddings_dropout(x)
270-
271-
# Every decoder layer has a separate cache for the self-attention layer
272-
# and the cross-attention layer. We update all of them separately.
273-
self_attention_caches = []
274-
cross_attention_caches = []
246+
# Each decoder layer has a cache; we update them separately.
247+
caches = []
275248
for i, layer in enumerate(self.backbone.decoder_transformer_layers):
276-
current_self_attention_cache = self_attention_cache[:, i, ...]
249+
current_self_attention_cache = cache[:, i, ...]
277250
current_cross_attention_cache = cross_attention_cache[:, i, ...]
278-
(
279-
x,
280-
next_self_attention_cache,
281-
next_cross_attention_cache,
282-
) = layer(
251+
x, next_cache, _ = layer(
283252
decoder_sequence=x,
284-
encoder_sequence=encoder_hidden_states,
285253
encoder_padding_mask=encoder_padding_mask,
286254
self_attention_cache=current_self_attention_cache,
287-
self_attention_cache_update_index=self_attention_cache_update_index,
255+
self_attention_cache_update_index=index,
288256
cross_attention_cache=current_cross_attention_cache,
289-
cross_attention_cache_update_index=cross_attention_cache_update_index,
290257
)
291-
if self_attention_cache_update_index is not None:
292-
self_attention_caches.append(next_self_attention_cache)
293-
if cross_attention_cache_update_index is not None:
294-
cross_attention_caches.append(next_cross_attention_cache)
295-
296-
if self_attention_cache_update_index is not None:
297-
self_attention_cache = ops.stack(self_attention_caches, axis=1)
298-
if cross_attention_cache_update_index is not None:
299-
cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
300-
258+
caches.append(next_cache)
259+
cache = ops.stack(caches, axis=1)
301260
hidden_states = x
302261
logits = self.backbone.token_embedding(hidden_states, reverse=True)
303262
return (
304263
logits,
305264
hidden_states,
306-
self_attention_cache,
307-
cross_attention_cache,
265+
cache,
308266
)
-
-    def call_encoder(self, token_ids, padding_mask):
-        """Does a forward pass on the encoder and returns the encoder output."""
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.encoder_position_embedding(tokens)
-        x = self.backbone.decoder_embeddings_add((tokens, positions))
-        x = self.backbone.encoder_embeddings_layer_norm(x)
-        x = self.backbone.encoder_embeddings_dropout(x)
-        for transformer_layer in self.backbone.encoder_transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask)
-        return x
-
-    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
-        """Initializes empty self-attention cache and cross-attention cache."""
-        batch_size = ops.shape(encoder_token_ids)[0]
-        encoder_max_length = ops.shape(encoder_token_ids)[1]
-        decoder_max_length = ops.shape(decoder_token_ids)[1]
-
-        num_layers = self.backbone.num_layers
-        num_heads = self.backbone.num_heads
-        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
-
-        shape = [
-            batch_size,
-            num_layers,
-            2,
-            decoder_max_length,
-            num_heads,
-            head_dim,
-        ]
-        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        shape[3] = encoder_max_length
-        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        return (self_attention_cache, cross_attention_cache)
-
-    def _build_cache(
-        self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
-    ):
-        """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
-        encoder_hidden_states = self.call_encoder(
-            token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
-        )
-        self_attention_cache, cross_attention_cache = self._initialize_cache(
-            encoder_token_ids, decoder_token_ids
-        )
-
-        # Seed the self-attention cache and the cross-attention cache.
-        (
-            _,
-            hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self.call_decoder_with_cache(
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_padding_mask=encoder_padding_mask,
-            decoder_token_ids=decoder_token_ids,
-            self_attention_cache=self_attention_cache,
-            self_attention_cache_update_index=0,
-            cross_attention_cache=cross_attention_cache,
-            cross_attention_cache_update_index=0,
-        )
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        )
-
-    def generate_step(
-        self,
-        inputs,
-        stop_token_ids=None,
-    ):
-        """A compilable generation function for a batch of inputs.
-
-        This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. Inputs should have the same structure as
-        model inputs, a dictionary with keys `"encoder_token_ids"`,
-        `"encoder_padding_mask"`, `"decoder_token_ids"` and
-        `"decoder_padding_mask"`.
-
-        Args:
-            inputs: A dictionary with four keys - `"encoder_token_ids"`,
-                `"encoder_padding_mask"`, `"decoder_token_ids"` and
-                `"decoder_padding_mask"`, with batched tensor values.
-            stop_token_ids: Tuple of id's of end token's to stop on. If all
-                sequences have produced a new stop token, generation
-                will stop.
-        """
-        (
-            encoder_token_ids,
-            encoder_padding_mask,
-            decoder_token_ids,
-            decoder_padding_mask,
-        ) = (
-            inputs["encoder_token_ids"],
-            inputs["encoder_padding_mask"],
-            inputs["decoder_token_ids"],
-            inputs["decoder_padding_mask"],
-        )
-
-        batch_size = ops.shape(encoder_token_ids)[0]
-
-        # Create and seed cache with a single forward pass.
-        (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self._build_cache(
-            encoder_token_ids, encoder_padding_mask, decoder_token_ids
-        )
-        # Compute the lengths of all user inputted tokens ids.
-        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
-        # Start at the first index that has no user inputted id.
-        index = ops.min(row_lengths)
-
-        def next(prompt, cache, index):
-            # The cache index is the index of our previous token.
-            cache_index = index - 1
-            num_samples = ops.shape(prompt)[0]
-            prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
-
-            def repeat_tensor(x):
-                """Repeats tensors along batch axis to match dim for beam search."""
-                if ops.shape(x)[0] == num_samples:
-                    return x
-                return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
-
-            logits, hidden_states, cache, _ = self.call_decoder_with_cache(
-                encoder_hidden_states=repeat_tensor(encoder_hidden_states),
-                encoder_padding_mask=repeat_tensor(encoder_padding_mask),
-                decoder_token_ids=prompt,
-                self_attention_cache=cache,
-                self_attention_cache_update_index=cache_index,
-                cross_attention_cache=repeat_tensor(cross_attention_cache),
-                cross_attention_cache_update_index=None,
-            )
-            return (
-                ops.squeeze(logits, axis=1),
-                ops.squeeze(hidden_states, axis=1),
-                cache,
-            )
-
-        decoder_token_ids = self.sampler(
-            next=next,
-            prompt=decoder_token_ids,
-            cache=self_attention_cache,
-            index=index,
-            mask=decoder_padding_mask,
-            stop_token_ids=stop_token_ids,
-            hidden_states=hidden_states,
-            model=self,
-        )
-
-        # Compute an output padding mask with the token ids we updated.
-        if stop_token_ids is not None:
-            # Build a mask of `stop_token_ids` locations not in the original
-            # prompt (not in locations where `decoder_padding_mask` is True).
-            end_locations = any_equal(
-                decoder_token_ids,
-                stop_token_ids,
-                ops.logical_not(decoder_padding_mask),
-            )
-            end_locations = ops.cast(end_locations, "int32")
-            # Use cumsum to get ones in all locations after `end_locations`.
-            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
-            overflow = cumsum - end_locations
-            # Our padding mask is the inverse of these overflow locations.
-            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
-        else:
-            # Without early stopping, all locations will have been updated.
-            decoder_padding_mask = ops.ones_like(
-                decoder_token_ids, dtype="bool"
-            )
-
-        return {
-            "decoder_token_ids": decoder_token_ids,
-            "decoder_padding_mask": decoder_padding_mask,
-        }
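Note: after this refactor the model exposes three generation hooks instead of a bespoke `generate_step`: `build_cache` (a zero-initialized self-attention cache), `compute_cross_attention_cache` (one encoder pass, with per-decoder-layer key/value stacks), and `call_with_cache` (one cached decoder step). The removed `generate_step` logic presumably moves into shared `Seq2SeqLM` code elsewhere in this commit. The eager-mode loop below is only a rough sketch of how the three hooks fit together with greedy sampling; the function, preset name, and dtypes are illustrative assumptions, not the library's actual sampler or `generate()` implementation.

import keras_nlp
from keras import ops

causal_lm = keras_nlp.models.BartSeq2SeqLM.from_preset("bart_base_en")

def greedy_decode(
    encoder_token_ids, encoder_padding_mask, decoder_token_ids, max_length
):
    """Illustrative greedy loop over the new cache API (not the real sampler)."""
    batch_size = ops.shape(decoder_token_ids)[0]
    prompt_length = ops.shape(decoder_token_ids)[1]
    # Zero-initialized self-attention cache covering the full output length.
    cache = causal_lm.build_cache(batch_size, max_length)
    # One encoder pass, reused at every step through the cross-attention cache.
    cross_cache = causal_lm.compute_cross_attention_cache(
        encoder_token_ids, encoder_padding_mask
    )
    # Seed the self-attention cache with a single pass over the whole prompt.
    logits, _, cache = causal_lm.call_with_cache(
        decoder_token_ids,
        cache,
        0,
        encoder_padding_mask=encoder_padding_mask,
        cross_attention_cache=cross_cache,
    )
    tokens = decoder_token_ids
    for index in range(prompt_length, max_length):
        # Greedily pick the next token from the logits at the last position.
        next_token = ops.expand_dims(
            ops.cast(ops.argmax(logits[:, -1, :], axis=-1), "int32"), axis=-1
        )
        tokens = ops.concatenate([tokens, next_token], axis=1)
        if index == max_length - 1:
            break
        # Feed only the newest token; `index` tells the layers where to write.
        logits, _, cache = causal_lm.call_with_cache(
            next_token,
            cache,
            index,
            encoder_padding_mask=encoder_padding_mask,
            cross_attention_cache=cross_cache,
        )
    return tokens

The real generation path also has to track padding masks and stop tokens (the bookkeeping the removed `generate_step` did with `any_equal` and cumsum); that is omitted here to keep the cache plumbing visible.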
