Commit 52b3d77: Generation refactor
Parent: a2a9602
28 files changed: +739, -1567 lines

keras_nlp/layers/modeling/transformer_decoder.py

Lines changed: 25 additions & 1 deletion
@@ -250,6 +250,28 @@ def build(
         # Create layers based on input shape.
         self.built = True

+    def compute_self_attention_cache(
+        self,
+        decoder_sequence,
+    ):
+        x = decoder_sequence
+        if self.normalize_first:
+            x = self._self_attention_layer_norm(x)
+        key = self._self_attention_layer._key_dense(x)
+        value = self._self_attention_layer._value_dense(x)
+        return ops.stack((key, value), axis=1)
+
+    def compute_cross_attention_cache(
+        self,
+        encoder_sequence,
+    ):
+        x = encoder_sequence
+        if self.normalize_first:
+            x = self._cross_attention_layer_norm(x)
+        key = self._cross_attention_layer._key_dense(x)
+        value = self._cross_attention_layer._value_dense(x)
+        return ops.stack((key, value), axis=1)
+
     def __call__(
         self,
         decoder_sequence,
@@ -325,7 +347,9 @@ def call(
                 the layer has cross-attention.
         """

-        has_encoder_sequence = encoder_sequence is not None
+        has_encoder_sequence = (
+            encoder_sequence is not None or cross_attention_cache is not None
+        )

        has_cross_attention = self._cross_attention_layer is not None
        if not has_cross_attention and has_encoder_sequence:
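
The two new methods let callers precompute a layer's key/value projections and seed its attention caches before token-by-token decoding. A minimal sketch of how they might be used, with illustrative sizes and random inputs that are not part of this diff:

import numpy as np
from keras_nlp.layers import TransformerDecoder

# Illustrative sizes only.
batch_size, src_len, tgt_len, hidden_dim = 2, 16, 4, 64
decoder_layer = TransformerDecoder(intermediate_dim=128, num_heads=4)

encoder_sequence = np.random.uniform(size=(batch_size, src_len, hidden_dim)).astype("float32")
decoder_sequence = np.random.uniform(size=(batch_size, tgt_len, hidden_dim)).astype("float32")

# Call the layer once so its attention sublayers (and their key/value
# projections) are built before the cache helpers run.
decoder_layer(decoder_sequence, encoder_sequence)

# Each cache stacks key and value along axis 1, giving shape
# (batch_size, 2, sequence_length, num_heads, head_dim).
self_cache = decoder_layer.compute_self_attention_cache(decoder_sequence)
cross_cache = decoder_layer.compute_cross_attention_cache(encoder_sequence)

Because key and value are stacked along axis 1, downstream code can slice them back apart when it later calls the layer with the self_attention_cache and cross_attention_cache arguments shown in the diff.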

keras_nlp/layers/preprocessing/start_end_packer.py

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ def call(
         outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

         if self.return_padding_mask:
-            mask = tf.ones_like(x, dtype="bool")
+            mask = tf.ones_like(x, dtype="int32")
             mask = mask.to_tensor(shape=(batch_size, sequence_length))
             mask = tf.squeeze(mask, axis=0) if unbatched else mask
             return outputs, mask
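
For context, the one-line dtype change sits in the `return_padding_mask` branch of `StartEndPacker.call`: the mask is now built as an int32 ragged tensor before being densified. A quick usage sketch of that path, with made-up token id values:

import keras_nlp

packer = keras_nlp.layers.StartEndPacker(
    sequence_length=6,
    start_value=101,  # illustrative start token id
    end_value=102,    # illustrative end token id
    pad_value=0,
    return_padding_mask=True,
)

# Packs [5, 6, 7] to [101, 5, 6, 7, 102, 0]; padding_mask flags the
# non-padded positions.
token_ids, padding_mask = packer([5, 6, 7])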

keras_nlp/models/bart/bart_backbone.py

Lines changed: 0 additions & 1 deletion
@@ -254,5 +254,4 @@ def get_config(self):
                 "max_sequence_length": self.max_sequence_length,
             }
         )
-
         return config

keras_nlp/models/bart/bart_seq_2_seq_lm.py

Lines changed: 43 additions & 267 deletions
@@ -20,7 +20,6 @@
     BartSeq2SeqLMPreprocessor,
 )
 from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM
-from keras_nlp.utils.tensor_utils import any_equal


 @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -199,291 +198,68 @@ def __init__(
             **kwargs,
         )

-    def call_decoder_with_cache(
+    def build_cache(self, batch_size, max_length):
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        return ops.zeros(shape, dtype=self.compute_dtype)
+
+    def compute_cross_attention_cache(
+        self, encoder_token_ids, encoder_padding_mask
+    ):
+        """Does a forward pass on the encoder and returns the cross-attention caches."""
+        # Embedding layers.
+        tokens = self.backbone.token_embedding(encoder_token_ids)
+        positions = self.backbone.encoder_position_embedding(tokens)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.backbone.encoder_embeddings_add((tokens, positions))
+        x = self.backbone.encoder_embeddings_layer_norm(x)
+        x = self.backbone.encoder_embeddings_dropout(x)
+        # Transformer encoder layers.
+        for layer in self.backbone.encoder_transformer_layers:
+            x = layer(x, padding_mask=encoder_padding_mask)
+        # Compute a cross-attention cache for each decoder layer.
+        caches = []
+        for layer in self.backbone.decoder_transformer_layers:
+            caches.append(layer.compute_cross_attention_cache(x))
+        return ops.stack(caches, axis=1)
+
+    def call_with_cache(
         self,
-        encoder_hidden_states,
+        token_ids,
+        cache,
+        index,
+        *,
         encoder_padding_mask,
-        decoder_token_ids,
-        self_attention_cache=None,
-        self_attention_cache_update_index=None,
-        cross_attention_cache=None,
-        cross_attention_cache_update_index=None,
+        cross_attention_cache,
     ):
-        """Forward pass with a key/value caches for generative decoding..
-
-        `call_decoder_with_cache` adds an additional inference-time forward pass
-        for the model for seq2seq text generation. Unlike calling the model
-        directly, this method does two things to optimize text generation:
-
-        - Allows caching previous key/value tensors in the decoder's
-          self-attention layer to avoid recomputing the outputs of seen tokens.
-        - Allows caching key/value tensors in the decoder's cross-attention
-          layer to avoid recomputing the encoder outputs.
-
-        Args:
-            encoder_hidden_states: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length, hidden_dim)`. The
-                sequence of hidden states at the output of the encoder's last
-                layer.
-            encoder_padding_mask: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length)`. The padding mask for
-                the encoder input.
-            decoder_token_ids: a dense int Tensor of shape
-                `(batch_size, max_length)`. Input token ids to be fed to
-                the decoder.
-            self_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
-                The cached key/value tensors of previously seen tokens in the
-                decoder's self-attention layer.
-            self_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `self_attention_cache`. Usually, this is
-                the index of the current token being processed during decoding.
-            cross_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
-                The cached key/value tensors of the encoder outputs in the
-                decoder's cross-attention layer.
-            cross_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `cross_attention_cache`. Usually, this is
-                either `0` (compute the entire `cross_attention_cache`), or
-                `None` (reuse a previously computed `cross_attention_cache`).
-
-        Returns:
-            A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
-            tuple, where `logits` is the language model logits for the input
-            `decoder_token_ids`, `hidden_states` is the final hidden
-            representation of the input tokens, `self_attention_cache` is the
-            key/value cache in the decoder's self-attention layer and
-            `cross_attention_cache` is the key/value cache in the decoder's
-            cross-attention layer.
-        """
-        # Embedding layers.
-        tokens = self.backbone.token_embedding(decoder_token_ids)
+        tokens = self.backbone.token_embedding(token_ids)
         positions = self.backbone.decoder_position_embedding(
-            tokens,
-            start_index=self_attention_cache_update_index,
+            tokens, start_index=index
        )
        # Sum, normalize and apply dropout to embeddings.
        x = self.backbone.decoder_embeddings_add((tokens, positions))
        x = self.backbone.decoder_embeddings_layer_norm(x)
        x = self.backbone.decoder_embeddings_dropout(x)
-
-        # Every decoder layer has a separate cache for the self-attention layer
-        # and the cross-attention layer. We update all of them separately.
-        self_attention_caches = []
-        cross_attention_caches = []
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
         for i, layer in enumerate(self.backbone.decoder_transformer_layers):
-            current_self_attention_cache = self_attention_cache[:, i, ...]
+            current_self_attention_cache = cache[:, i, ...]
             current_cross_attention_cache = cross_attention_cache[:, i, ...]
-            (
-                x,
-                next_self_attention_cache,
-                next_cross_attention_cache,
-            ) = layer(
+            x, next_cache, _ = layer(
                 decoder_sequence=x,
-                encoder_sequence=encoder_hidden_states,
                 encoder_padding_mask=encoder_padding_mask,
                 self_attention_cache=current_self_attention_cache,
-                self_attention_cache_update_index=self_attention_cache_update_index,
+                self_attention_cache_update_index=index,
                 cross_attention_cache=current_cross_attention_cache,
-                cross_attention_cache_update_index=cross_attention_cache_update_index,
             )
-            if self_attention_cache_update_index is not None:
-                self_attention_caches.append(next_self_attention_cache)
-            if cross_attention_cache_update_index is not None:
-                cross_attention_caches.append(next_cross_attention_cache)
-
-        if self_attention_cache_update_index is not None:
-            self_attention_cache = ops.stack(self_attention_caches, axis=1)
-        if cross_attention_cache_update_index is not None:
-            cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
-
+            caches.append(next_cache)
+        cache = ops.stack(caches, axis=1)
         hidden_states = x
         logits = self.backbone.token_embedding(hidden_states, reverse=True)
         return (
             logits,
             hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
+            cache,
         )
-
-    def call_encoder(self, token_ids, padding_mask):
-        """Does a forward pass on the encoder and returns the encoder output."""
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.encoder_position_embedding(tokens)
-        x = self.backbone.decoder_embeddings_add((tokens, positions))
-        x = self.backbone.encoder_embeddings_layer_norm(x)
-        x = self.backbone.encoder_embeddings_dropout(x)
-        for transformer_layer in self.backbone.encoder_transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask)
-        return x
-
-    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
-        """Initializes empty self-attention cache and cross-attention cache."""
-        batch_size = ops.shape(encoder_token_ids)[0]
-        encoder_max_length = ops.shape(encoder_token_ids)[1]
-        decoder_max_length = ops.shape(decoder_token_ids)[1]
-
-        num_layers = self.backbone.num_layers
-        num_heads = self.backbone.num_heads
-        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
-
-        shape = [
-            batch_size,
-            num_layers,
-            2,
-            decoder_max_length,
-            num_heads,
-            head_dim,
-        ]
-        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        shape[3] = encoder_max_length
-        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        return (self_attention_cache, cross_attention_cache)
-
-    def _build_cache(
-        self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
-    ):
-        """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
-        encoder_hidden_states = self.call_encoder(
-            token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
-        )
-        self_attention_cache, cross_attention_cache = self._initialize_cache(
-            encoder_token_ids, decoder_token_ids
-        )
-
-        # Seed the self-attention cache and the cross-attention cache.
-        (
-            _,
-            hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self.call_decoder_with_cache(
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_padding_mask=encoder_padding_mask,
-            decoder_token_ids=decoder_token_ids,
-            self_attention_cache=self_attention_cache,
-            self_attention_cache_update_index=0,
-            cross_attention_cache=cross_attention_cache,
-            cross_attention_cache_update_index=0,
-        )
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        )
-
-    def generate_step(
-        self,
-        inputs,
-        stop_token_ids=None,
-    ):
-        """A compilable generation function for a batch of inputs.
-
-        This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. Inputs should have the same structure as
-        model inputs, a dictionary with keys `"encoder_token_ids"`,
-        `"encoder_padding_mask"`, `"decoder_token_ids"` and
-        `"decoder_padding_mask"`.
-
-        Args:
-            inputs: A dictionary with four keys - `"encoder_token_ids"`,
-                `"encoder_padding_mask"`, `"decoder_token_ids"` and
-                `"decoder_padding_mask"`, with batched tensor values.
-            stop_token_ids: Tuple of id's of end token's to stop on. If all
-                sequences have produced a new stop token, generation
-                will stop.
-        """
-        (
-            encoder_token_ids,
-            encoder_padding_mask,
-            decoder_token_ids,
-            decoder_padding_mask,
-        ) = (
-            inputs["encoder_token_ids"],
-            inputs["encoder_padding_mask"],
-            inputs["decoder_token_ids"],
-            inputs["decoder_padding_mask"],
-        )
-
-        batch_size = ops.shape(encoder_token_ids)[0]
-
-        # Create and seed cache with a single forward pass.
-        (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self._build_cache(
-            encoder_token_ids, encoder_padding_mask, decoder_token_ids
-        )
-        # Compute the lengths of all user inputted tokens ids.
-        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
-        # Start at the first index that has no user inputted id.
-        index = ops.min(row_lengths)
-
-        def next(prompt, cache, index):
-            # The cache index is the index of our previous token.
-            cache_index = index - 1
-            num_samples = ops.shape(prompt)[0]
-            prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
-
-            def repeat_tensor(x):
-                """Repeats tensors along batch axis to match dim for beam search."""
-                if ops.shape(x)[0] == num_samples:
-                    return x
-                return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
-
-            logits, hidden_states, cache, _ = self.call_decoder_with_cache(
-                encoder_hidden_states=repeat_tensor(encoder_hidden_states),
-                encoder_padding_mask=repeat_tensor(encoder_padding_mask),
-                decoder_token_ids=prompt,
-                self_attention_cache=cache,
-                self_attention_cache_update_index=cache_index,
-                cross_attention_cache=repeat_tensor(cross_attention_cache),
-                cross_attention_cache_update_index=None,
-            )
-            return (
-                ops.squeeze(logits, axis=1),
-                ops.squeeze(hidden_states, axis=1),
-                cache,
-            )
-
-        decoder_token_ids = self.sampler(
-            next=next,
-            prompt=decoder_token_ids,
-            cache=self_attention_cache,
-            index=index,
-            mask=decoder_padding_mask,
-            stop_token_ids=stop_token_ids,
-            hidden_states=hidden_states,
-            model=self,
-        )
-
-        # Compute an output padding mask with the token ids we updated.
-        if stop_token_ids is not None:
-            # Build a mask of `stop_token_ids` locations not in the original
-            # prompt (not in locations where `decoder_padding_mask` is True).
-            end_locations = any_equal(
-                decoder_token_ids,
-                stop_token_ids,
-                ops.logical_not(decoder_padding_mask),
-            )
-            end_locations = ops.cast(end_locations, "int32")
-            # Use cumsum to get ones in all locations after `end_locations`.
-            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
-            overflow = cumsum - end_locations
-            # Our padding mask is the inverse of these overflow locations.
-            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
-        else:
-            # Without early stopping, all locations will have been updated.
-            decoder_padding_mask = ops.ones_like(
-                decoder_token_ids, dtype="bool"
-            )
-
-        return {
-            "decoder_token_ids": decoder_token_ids,
-            "decoder_padding_mask": decoder_padding_mask,
-        }
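
Taken together, the BART task now exposes three smaller hooks, `build_cache`, `compute_cross_attention_cache`, and `call_with_cache`, in place of the monolithic cache plumbing deleted above, presumably so shared `Seq2SeqLM` generation code can drive them. A rough sketch of the intended call pattern, using an illustrative preset name and toy inputs that are not taken from this diff:

import numpy as np
import keras_nlp

# "bart_base_en" is used here only as an example preset name.
seq_2_seq_lm = keras_nlp.models.BartSeq2SeqLM.from_preset("bart_base_en")

batch_size, max_length = 1, 16
encoder_token_ids = np.ones((batch_size, 8), dtype="int32")
encoder_padding_mask = np.ones((batch_size, 8), dtype="bool")
decoder_token_ids = np.ones((batch_size, max_length), dtype="int32")

# 1. Allocate an empty per-layer self-attention cache for the decoder.
cache = seq_2_seq_lm.build_cache(batch_size, max_length)
# 2. Run the encoder once and precompute every decoder layer's
#    cross-attention key/value cache.
cross_attention_cache = seq_2_seq_lm.compute_cross_attention_cache(
    encoder_token_ids, encoder_padding_mask
)
# 3. Decode with the caches; here a single step at index 0.
logits, hidden_states, cache = seq_2_seq_lm.call_with_cache(
    decoder_token_ids[:, :1],
    cache,
    0,
    encoder_padding_mask=encoder_padding_mask,
    cross_attention_cache=cross_attention_cache,
)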
