Skip to content

Commit bba28a0

Browse files
authored
[rls2.8] fix whisper failure with tp (#3780)
1 parent ea628af commit bba28a0

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed

intel_extension_for_pytorch/transformers/generation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from .greedy_search import _greedy_search
33
from .sample import _sample
44
from .beam_sample import _beam_sample, _beam_sample_legacy
5-
from .utils import whisper_generate
5+
from .utils import whisper_generate, _prepare_generation_config

intel_extension_for_pytorch/transformers/generation/utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Dict, Callable, List, Optional, Tuple, Union
77
import torch
88
import torch.nn.functional as F
9+
import copy
910

1011
trans_version = transformers.__version__
1112

@@ -634,3 +635,85 @@ def whisper_generate(
634635
outputs["segments"] = final_segments
635636

636637
return outputs
638+
639+
640+
def is_torchdynamo_compiling():
    """Best-effort check for whether we are executing under TorchDynamo.

    Tries the modern ``torch.compiler.is_compiling()`` API first; if that is
    unavailable (older torch), falls back to the legacy
    ``torch._dynamo.is_compiling()``. Returns ``False`` when neither API can
    be reached, so callers can always treat the result as a plain bool.
    """
    # Preferred, public API (torch >= 2.1).
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        pass
    # Legacy private API kept for older torch releases.
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return dynamo.is_compiling()
    except Exception:
        # No dynamo available at all -> definitely not compiling.
        return False
652+
653+
654+
def _prepare_generation_config(
    self, generation_config, use_model_defaults=None, **kwargs: Dict
):
    """
    Prepares the base generation config, then applies any generation configuration options from kwargs. This
    function handles retrocompatibility with respect to configuration files.

    Returns a ``(generation_config, model_kwargs)`` tuple: the (possibly
    deep-copied and kwargs-updated) generation config, and the leftover
    kwargs that were not consumed by ``generation_config.update``.

    NOTE(review): ``use_model_defaults`` is accepted for signature
    compatibility with upstream transformers but is never read in this
    body — confirm against the transformers version being patched.
    """
    # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
    # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
    # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    using_model_generation_config = False
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
        # the following conditions must be met
        # 1) the generation config must have been created from the model config (`_from_model_config` field);
        # 2) the generation config must have seen no modification since its creation (the hash is the same);
        # 3) there are non-default generation parameters in the model config.
        # 4) the user must have set new generation parameters in the model config.
        # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
        if (
            not is_torchdynamo_compiling()
            and self.generation_config._from_model_config  # 1)
            and self.generation_config._original_object_hash
            == hash(self.generation_config)  # 2)
            and len(self.config._get_non_default_generation_parameters()) > 0  # 3)
        ):
            # Rebuild a generation config from the (user-modified) model config.
            new_generation_config = transformers.generation.configuration_utils.GenerationConfig.from_model_config(
                self.config
            )
            if new_generation_config != self.generation_config:  # 4)
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed in v5."
                    " Please use and modify the model generation configuration (see"
                    " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
                    UserWarning,
                )
                self.generation_config = new_generation_config

        generation_config = self.generation_config
        using_model_generation_config = True

    # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
    # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
    # exception will be raised in `_validate_model_kwargs`
    if not is_torchdynamo_compiling():
        # Deep-copy so the caller's config object is never mutated by `.update`.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
        if not using_model_generation_config:
            if generation_config.bos_token_id is None:
                generation_config.bos_token_id = self.generation_config.bos_token_id
            if generation_config.eos_token_id is None:
                generation_config.eos_token_id = self.generation_config.eos_token_id
            if generation_config.pad_token_id is None:
                generation_config.pad_token_id = self.generation_config.pad_token_id
            if generation_config.decoder_start_token_id is None:
                generation_config.decoder_start_token_id = (
                    self.generation_config.decoder_start_token_id
                )
    else:
        # Under compilation: no copy, no in-place update — pass kwargs straight
        # through (they will be validated/rejected later).
        model_kwargs = kwargs

    return generation_config, model_kwargs

intel_extension_for_pytorch/transformers/optimize.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def model_convert_reference(_model):
127127
_beam_sample,
128128
_beam_sample_legacy,
129129
whisper_generate,
130+
_prepare_generation_config,
130131
)
131132

132133
# model wise optimization for MHA module
@@ -1068,6 +1069,9 @@ def model_convert_reference(_model):
10681069
_model, "_postprocess_outputs", _postprocess_outputs_whisper
10691070
)
10701071
convert_function(_model, "generate", whisper_generate)
1072+
convert_function(
1073+
_model, "_prepare_generation_config", _prepare_generation_config
1074+
)
10711075
convert_function(_model, "forward", WhisperForConditionalGeneration_forward)
10721076
convert_function(_model.model, "forward", WhisperModel_forward)
10731077
convert_function(_model.model.decoder, "forward", WhisperDecoderLayer_forward)

0 commit comments

Comments
 (0)