@@ -21,7 +21,6 @@
     BartSeq2SeqLMPreprocessor,
 )
 from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM
-from keras_nlp.src.utils.tensor_utils import any_equal


 @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
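The only change in this first hunk is dropping the `any_equal` import: its single use in this file is the stop-token masking inside the `generate_step` method removed further down. As a standalone illustration of what that masking computes, here is a small sketch in plain Python; the `any_equal` helper below is an assumption inferred from its call site, not the library's definition, and plain lists stand in for `keras.ops` tensors.

# Standalone sketch of the stop-token masking arithmetic used by the removed
# `generate_step` below. Plain Python stands in for `keras.ops`, and this
# `any_equal` is an assumed equivalent inferred from its call site.
import itertools


def any_equal(inputs, values, padding_mask):
    # True where a token matches any stop id and the position is maskable.
    return [
        any(token == value for value in values) and maskable
        for token, maskable in zip(inputs, padding_mask)
    ]


decoder_token_ids = [0, 5, 9, 7, 2, 2, 2]  # `2` is the stop token id.
prompt_mask = [True, True, True, False, False, False, False]

# Stop tokens only count at positions outside the original prompt.
end_locations = [
    int(x)
    for x in any_equal(decoder_token_ids, (2,), [not m for m in prompt_mask])
]
# Cumsum minus the end locations is nonzero strictly after the first stop
# token; the output padding mask is the inverse of those overflow positions.
cumsum = list(itertools.accumulate(end_locations))
overflow = [c - e for c, e in zip(cumsum, end_locations)]
decoder_padding_mask = [o == 0 for o in overflow]
print(decoder_padding_mask)  # [True, True, True, True, True, False, False]

In the real code the same arithmetic runs on batched tensors with `ops.cumsum` and `ops.logical_not`, and it leaves this file together with `generate_step` in the hunk below.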
@@ -200,291 +199,68 @@ def __init__(
             **kwargs,
         )

-    def call_decoder_with_cache(
+    def build_cache(self, batch_size, max_length):
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        return ops.zeros(shape, dtype=self.compute_dtype)
+
+    def compute_cross_attention_cache(
+        self, encoder_token_ids, encoder_padding_mask
+    ):
+        """Does a forward pass on the encoder and returns the cross-attention cache."""
+        # Embedding layers.
+        tokens = self.backbone.token_embedding(encoder_token_ids)
+        positions = self.backbone.encoder_position_embedding(tokens)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.backbone.encoder_embeddings_add((tokens, positions))
+        x = self.backbone.encoder_embeddings_layer_norm(x)
+        x = self.backbone.encoder_embeddings_dropout(x)
+        # Transformer encoder layers.
+        for layer in self.backbone.encoder_transformer_layers:
+            x = layer(x, padding_mask=encoder_padding_mask)
+        # Compute the cross-attention cache for each decoder layer.
+        caches = []
+        for layer in self.backbone.decoder_transformer_layers:
+            caches.append(layer.compute_cross_attention_cache(x))
+        return ops.stack(caches, axis=1)
+
+    def call_with_cache(
         self,
-        encoder_hidden_states,
+        token_ids,
+        cache,
+        index,
+        *,
         encoder_padding_mask,
-        decoder_token_ids,
-        self_attention_cache=None,
-        self_attention_cache_update_index=None,
-        cross_attention_cache=None,
-        cross_attention_cache_update_index=None,
+        cross_attention_cache,
     ):
-        """Forward pass with a key/value caches for generative decoding..
-
-        `call_decoder_with_cache` adds an additional inference-time forward pass
-        for the model for seq2seq text generation. Unlike calling the model
-        directly, this method does two things to optimize text generation:
-
-        - Allows caching previous key/value tensors in the decoder's
-          self-attention layer to avoid recomputing the outputs of seen tokens.
-        - Allows caching key/value tensors in the decoder's cross-attention
-          layer to avoid recomputing the encoder outputs.
-
-        Args:
-            encoder_hidden_states: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length, hidden_dim)`. The
-                sequence of hidden states at the output of the encoder's last
-                layer.
-            encoder_padding_mask: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length)`. The padding mask for
-                the encoder input.
-            decoder_token_ids: a dense int Tensor of shape
-                `(batch_size, max_length)`. Input token ids to be fed to
-                the decoder.
-            self_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
-                The cached key/value tensors of previously seen tokens in the
-                decoder's self-attention layer.
-            self_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `self_attention_cache`. Usually, this is
-                the index of the current token being processed during decoding.
-            cross_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
-                The cached key/value tensors of the encoder outputs in the
-                decoder's cross-attention layer.
-            cross_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `cross_attention_cache`. Usually, this is
-                either `0` (compute the entire `cross_attention_cache`), or
-                `None` (reuse a previously computed `cross_attention_cache`).
-
-        Returns:
-            A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
-            tuple, where `logits` is the language model logits for the input
-            `decoder_token_ids`, `hidden_states` is the final hidden
-            representation of the input tokens, `self_attention_cache` is the
-            key/value cache in the decoder's self-attention layer and
-            `cross_attention_cache` is the key/value cache in the decoder's
-            cross-attention layer.
-        """
-        # Embedding layers.
-        tokens = self.backbone.token_embedding(decoder_token_ids)
+        tokens = self.backbone.token_embedding(token_ids)
         positions = self.backbone.decoder_position_embedding(
-            tokens,
-            start_index=self_attention_cache_update_index,
+            tokens, start_index=index
         )
         # Sum, normalize and apply dropout to embeddings.
         x = self.backbone.decoder_embeddings_add((tokens, positions))
         x = self.backbone.decoder_embeddings_layer_norm(x)
         x = self.backbone.decoder_embeddings_dropout(x)
-
-        # Every decoder layer has a separate cache for the self-attention layer
-        # and the cross-attention layer. We update all of them separately.
-        self_attention_caches = []
-        cross_attention_caches = []
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
         for i, layer in enumerate(self.backbone.decoder_transformer_layers):
-            current_self_attention_cache = self_attention_cache[:, i, ...]
+            current_self_attention_cache = cache[:, i, ...]
             current_cross_attention_cache = cross_attention_cache[:, i, ...]
-            (
-                x,
-                next_self_attention_cache,
-                next_cross_attention_cache,
-            ) = layer(
+            x, next_cache, _ = layer(
                 decoder_sequence=x,
-                encoder_sequence=encoder_hidden_states,
                 encoder_padding_mask=encoder_padding_mask,
                 self_attention_cache=current_self_attention_cache,
-                self_attention_cache_update_index=self_attention_cache_update_index,
+                self_attention_cache_update_index=index,
                 cross_attention_cache=current_cross_attention_cache,
-                cross_attention_cache_update_index=cross_attention_cache_update_index,
             )
-            if self_attention_cache_update_index is not None:
-                self_attention_caches.append(next_self_attention_cache)
-            if cross_attention_cache_update_index is not None:
-                cross_attention_caches.append(next_cross_attention_cache)
-
-        if self_attention_cache_update_index is not None:
-            self_attention_cache = ops.stack(self_attention_caches, axis=1)
-        if cross_attention_cache_update_index is not None:
-            cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
-
+            caches.append(next_cache)
+        cache = ops.stack(caches, axis=1)
         hidden_states = x
         logits = self.backbone.token_embedding(hidden_states, reverse=True)
         return (
             logits,
             hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
+            cache,
         )
-
-    def call_encoder(self, token_ids, padding_mask):
-        """Does a forward pass on the encoder and returns the encoder output."""
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.encoder_position_embedding(tokens)
-        x = self.backbone.decoder_embeddings_add((tokens, positions))
-        x = self.backbone.encoder_embeddings_layer_norm(x)
-        x = self.backbone.encoder_embeddings_dropout(x)
-        for transformer_layer in self.backbone.encoder_transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask)
-        return x
-
-    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
-        """Initializes empty self-attention cache and cross-attention cache."""
-        batch_size = ops.shape(encoder_token_ids)[0]
-        encoder_max_length = ops.shape(encoder_token_ids)[1]
-        decoder_max_length = ops.shape(decoder_token_ids)[1]
-
-        num_layers = self.backbone.num_layers
-        num_heads = self.backbone.num_heads
-        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
-
-        shape = [
-            batch_size,
-            num_layers,
-            2,
-            decoder_max_length,
-            num_heads,
-            head_dim,
-        ]
-        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        shape[3] = encoder_max_length
-        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        return (self_attention_cache, cross_attention_cache)
-
-    def _build_cache(
-        self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
-    ):
-        """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
-        encoder_hidden_states = self.call_encoder(
-            token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
-        )
-        self_attention_cache, cross_attention_cache = self._initialize_cache(
-            encoder_token_ids, decoder_token_ids
-        )
-
-        # Seed the self-attention cache and the cross-attention cache.
-        (
-            _,
-            hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self.call_decoder_with_cache(
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_padding_mask=encoder_padding_mask,
-            decoder_token_ids=decoder_token_ids,
-            self_attention_cache=self_attention_cache,
-            self_attention_cache_update_index=0,
-            cross_attention_cache=cross_attention_cache,
-            cross_attention_cache_update_index=0,
-        )
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        )
-
-    def generate_step(
-        self,
-        inputs,
-        stop_token_ids=None,
-    ):
-        """A compilable generation function for a batch of inputs.
-
-        This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. Inputs should have the same structure as
-        model inputs, a dictionary with keys `"encoder_token_ids"`,
-        `"encoder_padding_mask"`, `"decoder_token_ids"` and
-        `"decoder_padding_mask"`.
-
-        Args:
-            inputs: A dictionary with four keys - `"encoder_token_ids"`,
-                `"encoder_padding_mask"`, `"decoder_token_ids"` and
-                `"decoder_padding_mask"`, with batched tensor values.
-            stop_token_ids: Tuple of id's of end token's to stop on. If all
-                sequences have produced a new stop token, generation
-                will stop.
-        """
-        (
-            encoder_token_ids,
-            encoder_padding_mask,
-            decoder_token_ids,
-            decoder_padding_mask,
-        ) = (
-            inputs["encoder_token_ids"],
-            inputs["encoder_padding_mask"],
-            inputs["decoder_token_ids"],
-            inputs["decoder_padding_mask"],
-        )
-
-        batch_size = ops.shape(encoder_token_ids)[0]
-
-        # Create and seed cache with a single forward pass.
-        (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self._build_cache(
-            encoder_token_ids, encoder_padding_mask, decoder_token_ids
-        )
-        # Compute the lengths of all user inputted tokens ids.
-        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
-        # Start at the first index that has no user inputted id.
-        index = ops.min(row_lengths)
-
-        def next(prompt, cache, index):
-            # The cache index is the index of our previous token.
-            cache_index = index - 1
-            num_samples = ops.shape(prompt)[0]
-            prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
-
-            def repeat_tensor(x):
-                """Repeats tensors along batch axis to match dim for beam search."""
-                if ops.shape(x)[0] == num_samples:
-                    return x
-                return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
-
-            logits, hidden_states, cache, _ = self.call_decoder_with_cache(
-                encoder_hidden_states=repeat_tensor(encoder_hidden_states),
-                encoder_padding_mask=repeat_tensor(encoder_padding_mask),
-                decoder_token_ids=prompt,
-                self_attention_cache=cache,
-                self_attention_cache_update_index=cache_index,
-                cross_attention_cache=repeat_tensor(cross_attention_cache),
-                cross_attention_cache_update_index=None,
-            )
-            return (
-                ops.squeeze(logits, axis=1),
-                ops.squeeze(hidden_states, axis=1),
-                cache,
-            )
-
-        decoder_token_ids = self.sampler(
-            next=next,
-            prompt=decoder_token_ids,
-            cache=self_attention_cache,
-            index=index,
-            mask=decoder_padding_mask,
-            stop_token_ids=stop_token_ids,
-            hidden_states=hidden_states,
-            model=self,
-        )
-
-        # Compute an output padding mask with the token ids we updated.
-        if stop_token_ids is not None:
-            # Build a mask of `stop_token_ids` locations not in the original
-            # prompt (not in locations where `decoder_padding_mask` is True).
-            end_locations = any_equal(
-                decoder_token_ids,
-                stop_token_ids,
-                ops.logical_not(decoder_padding_mask),
-            )
-            end_locations = ops.cast(end_locations, "int32")
-            # Use cumsum to get ones in all locations after `end_locations`.
-            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
-            overflow = cumsum - end_locations
-            # Our padding mask is the inverse of these overflow locations.
-            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
-        else:
-            # Without early stopping, all locations will have been updated.
-            decoder_padding_mask = ops.ones_like(
-                decoder_token_ids, dtype="bool"
-            )
-
-        return {
-            "decoder_token_ids": decoder_token_ids,
-            "decoder_padding_mask": decoder_padding_mask,
-        }
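Taken together, the diff replaces the model-specific caching and generation code with three smaller hooks: `build_cache` allocates an empty self-attention cache, `compute_cross_attention_cache` runs the encoder once and stacks per-layer cross-attention caches, and `call_with_cache` advances the decoder while updating the self-attention cache at `index`. The generation driver that calls these hooks is not part of this diff (it presumably moves to the shared `Seq2SeqLM` base class), so the following greedy-decoding sketch only illustrates how the new API could be wired up; `greedy_generate`, the flat input dictionary, and the int32 token dtype are assumptions, not the library's implementation.

# Illustrative greedy decoding loop over the new cache API. This is not the
# `Seq2SeqLM` base-class implementation; it only shows how the three hooks
# added above could fit together.
from keras import ops


def greedy_generate(model, inputs, max_length):
    encoder_token_ids = inputs["encoder_token_ids"]
    encoder_padding_mask = inputs["encoder_padding_mask"]
    # Assumed shape (batch_size, max_length), already padded.
    decoder_token_ids = inputs["decoder_token_ids"]
    batch_size = ops.shape(encoder_token_ids)[0]

    # One encoder pass seeds the per-layer cross-attention caches.
    cross_attention_cache = model.compute_cross_attention_cache(
        encoder_token_ids, encoder_padding_mask
    )
    # An all-zeros self-attention cache, filled in one position per step.
    cache = model.build_cache(batch_size, max_length)

    # Assume position 0 of each row holds the start token; decode the rest.
    for index in range(max_length - 1):
        current = ops.slice(decoder_token_ids, [0, index], [batch_size, 1])
        logits, _, cache = model.call_with_cache(
            current,
            cache,
            index,
            encoder_padding_mask=encoder_padding_mask,
            cross_attention_cache=cross_attention_cache,
        )
        next_token = ops.cast(ops.argmax(logits[:, -1, :], axis=-1), "int32")
        decoder_token_ids = ops.slice_update(
            decoder_token_ids,
            [0, index + 1],
            ops.expand_dims(next_token, axis=1),
        )
    return decoder_token_ids

The real driver also threads a sampler, padding masks, and stop-token handling (compare the removed `generate_step` above); this sketch keeps only the cache plumbing.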