|
6 | 6 | from typing import Any, Dict, Callable, List, Optional, Tuple, Union
|
7 | 7 | import torch
|
8 | 8 | import torch.nn.functional as F
|
| 9 | +import copy |
9 | 10 |
|
trans_version = transformers.__version__  # transformers version string (e.g. "4.40.0"), cached at import time
|
11 | 12 |
|
@@ -634,3 +635,85 @@ def whisper_generate(
|
634 | 635 | outputs["segments"] = final_segments
|
635 | 636 |
|
636 | 637 | return outputs
|
| 638 | + |
| 639 | + |
def is_torchdynamo_compiling():
    """Report whether TorchDynamo (``torch.compile``) is currently tracing.

    Tries the modern ``torch.compiler.is_compiling`` entry point first and,
    when that is unavailable (older torch builds), falls back to the legacy
    ``torch._dynamo.is_compiling``. Any failure — torch missing, attribute
    absent — is treated as "not compiling" so callers can safely take the
    eager-mode path.
    """
    try:
        import torch

        compiling = torch.compiler.is_compiling()
    except Exception:
        try:
            # Legacy location on older torch releases.
            from torch import _dynamo

            compiling = _dynamo.is_compiling()
        except Exception:
            compiling = False
    return compiling
| 652 | + |
| 653 | + |
def _prepare_generation_config(
    self, generation_config, use_model_defaults=None, **kwargs: Dict
):
    """
    Prepares the base generation config, then applies any generation configuration options from kwargs. This
    function handles retrocompatibility with respect to configuration files.

    Args:
        generation_config: An explicit ``GenerationConfig`` to use. When ``None``, falls back to
            ``self.generation_config`` (possibly rebuilt from ``self.config`` for legacy setups, see below).
        use_model_defaults: NOTE(review): accepted for signature compatibility but currently unused in this
            body — confirm whether upstream transformers' ``use_model_defaults`` fallback should apply here.
        **kwargs: Ad-hoc generation options. Outside of ``torch.compile``, keys matching ``GenerationConfig``
            attributes are folded into a deep copy of the config via ``generation_config.update(**kwargs)``
            and the leftover keys are returned as model kwargs; under compilation, ``kwargs`` is returned
            unchanged and the config object is used as-is (no deepcopy).

    Returns:
        A ``(generation_config, model_kwargs)`` tuple.
    """
    # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
    # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
    # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    using_model_generation_config = False
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # the following conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same);
        # 3) there are non-default generation parameters in the model config.
        # 4) the user must have set new generation parameters in the model config.
        # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
        if (
            not is_torchdynamo_compiling()
            and self.generation_config._from_model_config  # 1)
            and self.generation_config._original_object_hash
            == hash(self.generation_config)  # 2)
            and len(self.config._get_non_default_generation_parameters()) > 0  # 3)
        ):
            # Rebuild a generation config from the (user-modified) model config and adopt it if it differs.
            new_generation_config = transformers.generation.configuration_utils.GenerationConfig.from_model_config(
                self.config
            )
            if new_generation_config != self.generation_config:  # 4)
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed in v5."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                    UserWarning,
                )
                # Side effect: the rebuilt config becomes the model's default for subsequent calls.
                self.generation_config = new_generation_config

        generation_config = self.generation_config
        using_model_generation_config = True

    # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
    # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
    # exception will be raised in `_validate_model_kwargs`
    if not is_torchdynamo_compiling():
        # Deep-copy so the caller's kwargs never mutate `self.generation_config` in place.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
        if not using_model_generation_config:
            if generation_config.bos_token_id is None:
                generation_config.bos_token_id = self.generation_config.bos_token_id
            if generation_config.eos_token_id is None:
                generation_config.eos_token_id = self.generation_config.eos_token_id
            if generation_config.pad_token_id is None:
                generation_config.pad_token_id = self.generation_config.pad_token_id
            if generation_config.decoder_start_token_id is None:
                generation_config.decoder_start_token_id = (
                    self.generation_config.decoder_start_token_id
                )
    else:
        # Compiled path: no copy, no `.update` — kwargs pass through untouched.
        model_kwargs = kwargs

    return generation_config, model_kwargs
0 commit comments