     BartSeq2SeqLMPreprocessor,
 )
 from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM
-from keras_nlp.utils.tensor_utils import any_equal
 
 
 @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
@@ -199,291 +198,68 @@ def __init__(
             **kwargs,
         )
 
-    def call_decoder_with_cache(
+    def build_cache(self, batch_size, max_length):
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        return ops.zeros(shape, dtype=self.compute_dtype)
+
+    def compute_cross_attention_cache(
+        self, encoder_token_ids, encoder_padding_mask
+    ):
+        """Does a forward pass on the encoder and returns the cross-attention cache."""
+        # Embedding layers.
+        tokens = self.backbone.token_embedding(encoder_token_ids)
+        positions = self.backbone.encoder_position_embedding(tokens)
+        # Sum, normalize and apply dropout to embeddings.
+        x = self.backbone.encoder_embeddings_add((tokens, positions))
+        x = self.backbone.encoder_embeddings_layer_norm(x)
+        x = self.backbone.encoder_embeddings_dropout(x)
+        # Transformer encoder layers.
+        for layer in self.backbone.encoder_transformer_layers:
+            x = layer(x, padding_mask=encoder_padding_mask)
+        # Compute a cross-attention cache for each decoder layer.
+        caches = []
+        for layer in self.backbone.decoder_transformer_layers:
+            caches.append(layer.compute_cross_attention_cache(x))
+        return ops.stack(caches, axis=1)
+
+    def call_with_cache(
         self,
-        encoder_hidden_states,
+        token_ids,
+        cache,
+        index,
+        *,
         encoder_padding_mask,
-        decoder_token_ids,
-        self_attention_cache=None,
-        self_attention_cache_update_index=None,
-        cross_attention_cache=None,
-        cross_attention_cache_update_index=None,
+        cross_attention_cache,
     ):
-        """Forward pass with a key/value caches for generative decoding..
-
-        `call_decoder_with_cache` adds an additional inference-time forward pass
-        for the model for seq2seq text generation. Unlike calling the model
-        directly, this method does two things to optimize text generation:
-
-        - Allows caching previous key/value tensors in the decoder's
-          self-attention layer to avoid recomputing the outputs of seen tokens.
-        - Allows caching key/value tensors in the decoder's cross-attention
-          layer to avoid recomputing the encoder outputs.
-
-        Args:
-            encoder_hidden_states: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length, hidden_dim)`. The
-                sequence of hidden states at the output of the encoder's last
-                layer.
-            encoder_padding_mask: a dense float Tensor of shape
-                `(batch_size, encoder_sequence_length)`. The padding mask for
-                the encoder input.
-            decoder_token_ids: a dense int Tensor of shape
-                `(batch_size, max_length)`. Input token ids to be fed to
-                the decoder.
-            self_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
-                The cached key/value tensors of previously seen tokens in the
-                decoder's self-attention layer.
-            self_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `self_attention_cache`. Usually, this is
-                the index of the current token being processed during decoding.
-            cross_attention_cache: a dense float Tensor of shape
-                `(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
-                The cached key/value tensors of the encoder outputs in the
-                decoder's cross-attention layer.
-            cross_attention_cache_update_index: an int or int Tensor, the index
-                at which to update the `cross_attention_cache`. Usually, this is
-                either `0` (compute the entire `cross_attention_cache`), or
-                `None` (reuse a previously computed `cross_attention_cache`).
-
-        Returns:
-            A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
-            tuple, where `logits` is the language model logits for the input
-            `decoder_token_ids`, `hidden_states` is the final hidden
-            representation of the input tokens, `self_attention_cache` is the
-            key/value cache in the decoder's self-attention layer and
-            `cross_attention_cache` is the key/value cache in the decoder's
-            cross-attention layer.
-        """
-        # Embedding layers.
-        tokens = self.backbone.token_embedding(decoder_token_ids)
+        tokens = self.backbone.token_embedding(token_ids)
         positions = self.backbone.decoder_position_embedding(
-            tokens,
-            start_index=self_attention_cache_update_index,
+            tokens, start_index=index
         )
         # Sum, normalize and apply dropout to embeddings.
         x = self.backbone.decoder_embeddings_add((tokens, positions))
         x = self.backbone.decoder_embeddings_layer_norm(x)
         x = self.backbone.decoder_embeddings_dropout(x)
-
-        # Every decoder layer has a separate cache for the self-attention layer
-        # and the cross-attention layer. We update all of them separately.
-        self_attention_caches = []
-        cross_attention_caches = []
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
         for i, layer in enumerate(self.backbone.decoder_transformer_layers):
-            current_self_attention_cache = self_attention_cache[:, i, ...]
+            current_self_attention_cache = cache[:, i, ...]
             current_cross_attention_cache = cross_attention_cache[:, i, ...]
-            (
-                x,
-                next_self_attention_cache,
-                next_cross_attention_cache,
-            ) = layer(
+            x, next_cache, _ = layer(
                 decoder_sequence=x,
-                encoder_sequence=encoder_hidden_states,
                 encoder_padding_mask=encoder_padding_mask,
                 self_attention_cache=current_self_attention_cache,
-                self_attention_cache_update_index=self_attention_cache_update_index,
+                self_attention_cache_update_index=index,
                 cross_attention_cache=current_cross_attention_cache,
-                cross_attention_cache_update_index=cross_attention_cache_update_index,
             )
-            if self_attention_cache_update_index is not None:
-                self_attention_caches.append(next_self_attention_cache)
-            if cross_attention_cache_update_index is not None:
-                cross_attention_caches.append(next_cross_attention_cache)
-
-        if self_attention_cache_update_index is not None:
-            self_attention_cache = ops.stack(self_attention_caches, axis=1)
-        if cross_attention_cache_update_index is not None:
-            cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
-
+            caches.append(next_cache)
+        cache = ops.stack(caches, axis=1)
         hidden_states = x
         logits = self.backbone.token_embedding(hidden_states, reverse=True)
         return (
             logits,
             hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
+            cache,
         )
-
-    def call_encoder(self, token_ids, padding_mask):
-        """Does a forward pass on the encoder and returns the encoder output."""
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.encoder_position_embedding(tokens)
-        x = self.backbone.decoder_embeddings_add((tokens, positions))
-        x = self.backbone.encoder_embeddings_layer_norm(x)
-        x = self.backbone.encoder_embeddings_dropout(x)
-        for transformer_layer in self.backbone.encoder_transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask)
-        return x
-
-    def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
-        """Initializes empty self-attention cache and cross-attention cache."""
-        batch_size = ops.shape(encoder_token_ids)[0]
-        encoder_max_length = ops.shape(encoder_token_ids)[1]
-        decoder_max_length = ops.shape(decoder_token_ids)[1]
-
-        num_layers = self.backbone.num_layers
-        num_heads = self.backbone.num_heads
-        head_dim = self.backbone.hidden_dim // self.backbone.num_heads
-
-        shape = [
-            batch_size,
-            num_layers,
-            2,
-            decoder_max_length,
-            num_heads,
-            head_dim,
-        ]
-        self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        shape[3] = encoder_max_length
-        cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
-
-        return (self_attention_cache, cross_attention_cache)
-
-    def _build_cache(
-        self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
-    ):
-        """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
-        encoder_hidden_states = self.call_encoder(
-            token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
-        )
-        self_attention_cache, cross_attention_cache = self._initialize_cache(
-            encoder_token_ids, decoder_token_ids
-        )
-
-        # Seed the self-attention cache and the cross-attention cache.
-        (
-            _,
-            hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self.call_decoder_with_cache(
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_padding_mask=encoder_padding_mask,
-            decoder_token_ids=decoder_token_ids,
-            self_attention_cache=self_attention_cache,
-            self_attention_cache_update_index=0,
-            cross_attention_cache=cross_attention_cache,
-            cross_attention_cache_update_index=0,
-        )
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        )
-
-    def generate_step(
-        self,
-        inputs,
-        stop_token_ids=None,
-    ):
-        """A compilable generation function for a batch of inputs.
-
-        This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. Inputs should have the same structure as
-        model inputs, a dictionary with keys `"encoder_token_ids"`,
-        `"encoder_padding_mask"`, `"decoder_token_ids"` and
-        `"decoder_padding_mask"`.
-
-        Args:
-            inputs: A dictionary with four keys - `"encoder_token_ids"`,
-                `"encoder_padding_mask"`, `"decoder_token_ids"` and
-                `"decoder_padding_mask"`, with batched tensor values.
-            stop_token_ids: Tuple of id's of end token's to stop on. If all
-                sequences have produced a new stop token, generation
-                will stop.
-        """
-        (
-            encoder_token_ids,
-            encoder_padding_mask,
-            decoder_token_ids,
-            decoder_padding_mask,
-        ) = (
-            inputs["encoder_token_ids"],
-            inputs["encoder_padding_mask"],
-            inputs["decoder_token_ids"],
-            inputs["decoder_padding_mask"],
-        )
-
-        batch_size = ops.shape(encoder_token_ids)[0]
-
-        # Create and seed cache with a single forward pass.
-        (
-            hidden_states,
-            encoder_hidden_states,
-            self_attention_cache,
-            cross_attention_cache,
-        ) = self._build_cache(
-            encoder_token_ids, encoder_padding_mask, decoder_token_ids
-        )
-        # Compute the lengths of all user inputted tokens ids.
-        row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
-        # Start at the first index that has no user inputted id.
-        index = ops.min(row_lengths)
-
-        def next(prompt, cache, index):
-            # The cache index is the index of our previous token.
-            cache_index = index - 1
-            num_samples = ops.shape(prompt)[0]
-            prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
-
-            def repeat_tensor(x):
-                """Repeats tensors along batch axis to match dim for beam search."""
-                if ops.shape(x)[0] == num_samples:
-                    return x
-                return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
-
-            logits, hidden_states, cache, _ = self.call_decoder_with_cache(
-                encoder_hidden_states=repeat_tensor(encoder_hidden_states),
-                encoder_padding_mask=repeat_tensor(encoder_padding_mask),
-                decoder_token_ids=prompt,
-                self_attention_cache=cache,
-                self_attention_cache_update_index=cache_index,
-                cross_attention_cache=repeat_tensor(cross_attention_cache),
-                cross_attention_cache_update_index=None,
-            )
-            return (
-                ops.squeeze(logits, axis=1),
-                ops.squeeze(hidden_states, axis=1),
-                cache,
-            )
-
-        decoder_token_ids = self.sampler(
-            next=next,
-            prompt=decoder_token_ids,
-            cache=self_attention_cache,
-            index=index,
-            mask=decoder_padding_mask,
-            stop_token_ids=stop_token_ids,
-            hidden_states=hidden_states,
-            model=self,
-        )
-
-        # Compute an output padding mask with the token ids we updated.
-        if stop_token_ids is not None:
-            # Build a mask of `stop_token_ids` locations not in the original
-            # prompt (not in locations where `decoder_padding_mask` is True).
-            end_locations = any_equal(
-                decoder_token_ids,
-                stop_token_ids,
-                ops.logical_not(decoder_padding_mask),
-            )
-            end_locations = ops.cast(end_locations, "int32")
-            # Use cumsum to get ones in all locations after `end_locations`.
-            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
-            overflow = cumsum - end_locations
-            # Our padding mask is the inverse of these overflow locations.
-            decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
-        else:
-            # Without early stopping, all locations will have been updated.
-            decoder_padding_mask = ops.ones_like(
-                decoder_token_ids, dtype="bool"
-            )
-
-        return {
-            "decoder_token_ids": decoder_token_ids,
-            "decoder_padding_mask": decoder_padding_mask,
-        }
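Note on the cache layout introduced by `build_cache`: keys and values for every decoder self-attention layer are packed into one tensor of shape `(batch_size, num_layers, 2, max_length, num_heads, head_dim)`, with the `2` axis holding key and value. A quick sanity check of the shape arithmetic in plain Python, using BART-base-like numbers (6 decoder layers, 12 heads, hidden size 768) assumed here purely for illustration:

    # Illustration only; substitute your backbone's actual config values.
    num_layers, num_heads, hidden_dim = 6, 12, 768  # assumed, BART-base-like
    head_dim = hidden_dim // num_heads              # 64
    batch_size, max_length = 2, 128
    shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
    print(shape)  # [2, 6, 2, 128, 12, 64]

Judging by the docstring removed above, the stacked cross-attention cache returned by `compute_cross_attention_cache` uses the same layout, with the encoder sequence length in place of `max_length`.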
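This file no longer defines `generate_step` or the private cache helpers, so generation is presumably driven by the shared `Seq2SeqLM` task through the three hooks added here. A rough, hypothetical sketch of how they could compose to seed the caches (this is not the base-class code, which is outside this diff):

    from keras import ops  # assumes Keras 3; keras_nlp routes ops through its own shim

    def seed_seq_2_seq_cache(
        lm, encoder_token_ids, encoder_padding_mask, decoder_token_ids
    ):
        """Hypothetical helper: build both caches and seed them with one decoder pass."""
        batch_size = ops.shape(decoder_token_ids)[0]
        max_length = ops.shape(decoder_token_ids)[1]
        # Empty per-layer key/value cache for decoder self-attention.
        cache = lm.build_cache(batch_size, max_length)
        # One encoder pass, folded into a per-layer cross-attention cache.
        cross_attention_cache = lm.compute_cross_attention_cache(
            encoder_token_ids, encoder_padding_mask
        )
        # Seed the self-attention cache with the full decoder prompt at index 0.
        logits, hidden_states, cache = lm.call_with_cache(
            decoder_token_ids,
            cache,
            0,
            encoder_padding_mask=encoder_padding_mask,
            cross_attention_cache=cross_attention_cache,
        )
        return logits, hidden_states, cache, cross_attention_cache

During token-by-token decoding, a sampler would then call `call_with_cache` with a single-token slice and the running index, much as the removed `next` closure did with `call_decoder_with_cache`.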
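For reference, the stop-token masking inside the removed `generate_step` marks stop tokens that appear outside the original prompt and then masks everything after the first one (`any_equal` effectively generalizes a single `==` comparison to a tuple of stop ids). The same arithmetic in plain NumPy, with made-up token ids and a made-up stop id of 2:

    import numpy as np

    decoder_token_ids = np.array([5, 7, 9, 4, 2, 6, 2])        # 2 is the assumed stop id
    prompt_mask = np.array([1, 1, 0, 0, 0, 0, 0], dtype=bool)  # original prompt locations
    end_locations = ((decoder_token_ids == 2) & ~prompt_mask).astype("int32")
    # end_locations -> [0 0 0 0 1 0 1]: stop ids found outside the prompt.
    cumsum = np.cumsum(end_locations)          # [0 0 0 0 1 1 2]
    overflow = cumsum - end_locations          # [0 0 0 0 0 1 1]: positions after the first stop.
    new_padding_mask = ~overflow.astype(bool)  # [T T T T T F F]: keep through the first stop id.

This is why the returned `decoder_padding_mask` keeps the prompt and the generated tokens up to and including the first stop token, and nothing after it.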