
Commit 69b1582

Refactor embedding input/output getter/setter (#39339)
* simplify common get/set
* remove some noise
* change some 5 years old modeling utils
* update examples
* fix copies
* revert some changes
* fixes, gah
* format
* move to Mixin
* remove smolvlm specific require grad
* skip
* force defaults
* remodularise some stuff
* remodularise more stuff
* add safety for audio models
* style
* have a correct fallback, you daft donkey
* remove this argh
* change heuristic for audio models
* fixup
* revert
* this works
* revert again
* 🧠
* aaah ESM has two modelings aaah
* add informative but short comment
* add `input_embed_layer` mixin attribute
* style
* walrus has low precedence
* modular fix
* this was breaking parser
1 parent 2da97f0 commit 69b1582

File tree

163 files changed: +235 / -2388 lines


examples/modular-transformers/modeling_my_new_model2.py

Lines changed: 0 additions & 12 deletions
@@ -333,12 +333,6 @@ def __init__(self, config: MyNewModel2Config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
     @check_model_inputs
     @auto_docstring
     def forward(

@@ -433,12 +427,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
     @can_return_tuple
     @auto_docstring
     def forward(

examples/modular-transformers/modeling_new_task_model.py

Lines changed: 0 additions & 6 deletions
@@ -389,12 +389,6 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.model.set_input_embeddings(value)
 
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
     def set_decoder(self, decoder):
         self.model.set_decoder(decoder)
 

examples/modular-transformers/modeling_super.py

Lines changed: 0 additions & 6 deletions
@@ -332,12 +332,6 @@ def __init__(self, config: SuperConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
     @check_model_inputs
     @auto_docstring
     def forward(

src/transformers/modeling_utils.py

Lines changed: 188 additions & 36 deletions
@@ -1902,7 +1902,97 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
-class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
+class EmbeddingAccessMixin:
+    """
+    Base utilities to regroup getters and setters for embeddings.
+    Introduces the `_input_embed_layer` attribute, which indicates
+    where the input embeddings come from and where they
+    should be set.
+    """
+
+    _input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.
+
+    def get_input_embeddings(self) -> nn.Module:
+        """
+        Returns the model's input embeddings.
+
+        Returns:
+            `nn.Module`: A torch module mapping vocabulary to hidden states.
+        """
+
+        # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
+        # for most NLP models), and if so, return it.
+
+        name = getattr(self, "_input_embed_layer", "embed_tokens")
+
+        if (default_embedding := getattr(self, name, None)) is not None:
+            return default_embedding
+        # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+
+        if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
+            return self.model.embed_tokens
+
+        # 3) vanilla decoder-only architectures
+        elif hasattr(self, "embed_tokens"):
+            return self.embed_tokens
+        else:
+            base_model = getattr(self, "base_model_prefix", None)
+            if base_model is not None:
+                base_model = getattr(self, base_model, None)
+            if base_model is not None and base_model is not self:
+                return base_model.get_input_embeddings()
+            raise NotImplementedError(
+                f"`get_input_embeddings` not auto-handled for {self.__class__.__name__}; "
+                "please override in the subclass."
+            )
+
+    def set_input_embeddings(self, value: nn.Module):
+        """Fallback setter that handles **~70 %** of models in the code-base.
+
+        Order of attempts:
+        1. `self.model.embed_tokens`
+        2. `self.embed_tokens`
+        3. delegate to the *base model* if one exists
+        4. otherwise raise `NotImplementedError` so subclasses still can (and
+           should) override for exotic layouts.
+        """
+
+        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+        name = getattr(self, "_input_embed_layer", "embed_tokens")
+        if hasattr(self, "model") and hasattr(self.model, name):
+            setattr(self.model, name, value)
+        # 2) as well as vanilla decoder-only architectures
+        elif hasattr(self, name):
+            setattr(self, name, value)
+        # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
+        elif getattr(self, self.base_model_prefix, self) is not self:
+            base_model = getattr(self, self.base_model_prefix, self)
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError(
+                f"`set_input_embeddings` not auto-handled for {self.__class__.__name__}; please override in the subclass."
+            )
+
+    def get_output_embeddings(self):
+        if not hasattr(self, "lm_head"):
+            return None
+        try:
+            # Speech / vision backbones raise here, so we return None.
+            # Legit use of get_input_embs?
+            self.get_input_embeddings()
+        except NotImplementedError:
+            return None
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        """
+        Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.
+        """
+        if getattr(self, "lm_head"):
+            self.lm_head = new_embeddings
+
+
+class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
     r"""
     Base class for all models.
 
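
For orientation (this is editorial illustration, not part of the commit): a minimal, self-contained sketch of the fallback order the mixin implements, assuming `EmbeddingAccessMixin` is importable from `transformers.modeling_utils` as added above. The toy classes are hypothetical and only exist to exercise the lookup paths.

import torch.nn as nn

from transformers.modeling_utils import EmbeddingAccessMixin


class TinyDecoder(nn.Module, EmbeddingAccessMixin):
    def __init__(self):
        super().__init__()
        # Found directly through the default `_input_embed_layer = "embed_tokens"`.
        self.embed_tokens = nn.Embedding(128, 16)


class TinyForCausalLM(nn.Module, EmbeddingAccessMixin):
    def __init__(self):
        super().__init__()
        self.model = TinyDecoder()
        self.lm_head = nn.Linear(16, 128, bias=False)


class TinyBackboneWithCustomName(nn.Module, EmbeddingAccessMixin):
    # Models whose embedding layer has a different name point the mixin at it.
    _input_embed_layer = "word_embeddings"

    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(64, 16)


lm = TinyForCausalLM()
# The wrapper has no `embed_tokens` of its own, so the getter falls through to
# `self.model.embed_tokens` (the encoder/decoder and VLM case).
assert lm.get_input_embeddings() is lm.model.embed_tokens
# The setter takes the same route.
lm.set_input_embeddings(nn.Embedding(256, 16))
# `lm_head` exists and the input embeddings resolve, so the output getter returns it.
assert lm.get_output_embeddings() is lm.lm_head
# The overridden `_input_embed_layer` is honored.
assert isinstance(TinyBackboneWithCustomName().get_input_embeddings(), nn.Embedding)
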
@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
     _supports_attention_backend = False
     _can_record_outputs = None
 
+    # This attribute sets the default parameter to be
+
     @property
     @torch._dynamo.allow_in_graph
     def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2267,6 +2359,101 @@ def _from_config(cls, config, **kwargs):
 
         return model
 
+    @classmethod
+    def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
+        """
+        Checks that the requested attention implementation exists and tries to get the kernel from hub
+        if `attn_implementation` matches hf kernels pattern.
+        """
+        if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
+            if not is_kernels_available():
+                raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
+
+            # Extract repo_id and kernel_name from the string
+            repo_id, kernel_name = attn_implementation.split(":")
+            kernel_name = kernel_name.strip()
+            repo_id = repo_id.strip()
+
+            try:
+                kernel = get_kernel(repo_id)
+                ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
+                attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
+            except FileNotFoundError as e:
+                logger.warning(
+                    f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using eager attention implementation instead."
+                )
+                attn_implementation = None  # try to dispatch SDPA and fallback eager if not available
+            except AttributeError:
+                raise ValueError(
+                    "the kernel function name or class specified in the attn_implementation argument is not valid. \
+                        Please check the documentation for the correct format, \
+                        and check that the kernel exports the class and the function correctly."
+                )
+        if (
+            not isinstance(attn_implementation, dict)
+            and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
+        ):
+            message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+            # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+            if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
+                message += (
+                    ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
+                    ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+                )
+            if cls._supports_sdpa:
+                message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
+            if cls._supports_flex_attn:
+                message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
+            raise ValueError(message + ".")
+
+        return attn_implementation
+
+    def set_attention_implementation(self, attn_implementation: Union[str, dict]):
+        """
+        Checks and dispatches to the requested attention implementation.
+        """
+        requested_attn_implementation = self._check_attn_implementation(attn_implementation)
+
+        # Composite models consisting of several PretrainedModels can specify attention implementation as a dict where
+        # keys are sub-config names. But most people will specify one `str` which means that should dispatch it for all sub-models.
+        # See https://github.com/huggingface/transformers/pull/32238
+        for key in self.config.sub_configs.keys():
+            sub_config = getattr(self.config, key)
+            curr_attn_implementation = (
+                requested_attn_implementation
+                if not isinstance(requested_attn_implementation, dict)
+                else requested_attn_implementation.get(key, None)
+            )
+            # For models with a backbone, the sub-config might not be initialized. Set the requested attn
+            # if the config hasn't got any attn pre-set and the requested attn is not `None` (i.e. not the default attn)
+            if (
+                sub_config is not None
+                and sub_config._attn_implementation_internal is None
+                and curr_attn_implementation is not None
+            ):
+                sub_config._attn_implementation_internal = curr_attn_implementation
+
+        if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
+            self.config._attn_implementation = "flash_attention_3"
+        if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
+            self.config._attn_implementation = "flash_attention_2"
+        elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
+            self.config._attn_implementation = "flex_attention"
+        elif (
+            requested_attn_implementation in [None, "sdpa"]
+            and not is_torch_xla_available()
+            and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
+        ):
+            self.config._attn_implementation = "sdpa"
+        elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
+            self.config._attn_implementation = requested_attn_implementation
+        elif isinstance(requested_attn_implementation, dict):
+            self.config._attn_implementation = requested_attn_implementation.get("", None)
+        else:
+            self.config._attn_implementation = "eager"
+
+        self.config._attn_implementation_autoset = True
+
     @classmethod
     def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
         """
@@ -2769,41 +2956,6 @@ def disable_input_require_grads(self):
         """
         self._require_grads_hook.remove()
 
-    def get_input_embeddings(self) -> nn.Module:
-        """
-        Returns the model's input embeddings.
-
-        Returns:
-            `nn.Module`: A torch module mapping vocabulary to hidden states.
-        """
-        base_model = getattr(self, self.base_model_prefix, self)
-        if base_model is not self:
-            return base_model.get_input_embeddings()
-        else:
-            raise NotImplementedError
-
-    def set_input_embeddings(self, value: nn.Module):
-        """
-        Set model's input embeddings.
-
-        Args:
-            value (`nn.Module`): A module mapping vocabulary to hidden states.
-        """
-        base_model = getattr(self, self.base_model_prefix, self)
-        if base_model is not self:
-            base_model.set_input_embeddings(value)
-        else:
-            raise NotImplementedError
-
-    def get_output_embeddings(self) -> nn.Module:
-        """
-        Returns the model's output embeddings.
-
-        Returns:
-            `nn.Module`: A torch module mapping hidden states to vocabulary.
-        """
-        return None  # Overwrite for models with output embeddings
-
     def _init_weights(self, module):
         """
         Initialize the weights. This method should be overridden by derived class and is

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 0 additions & 30 deletions
@@ -356,12 +356,6 @@ def __init__(self, config: ArceeConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
     @check_model_inputs
     @auto_docstring
     def forward(

@@ -438,18 +432,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
     def set_decoder(self, decoder):
         self.model = decoder
 
@@ -533,12 +515,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
        self.post_init()
 
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
     @can_return_tuple
     @auto_docstring
     def forward(

@@ -685,12 +661,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
     @can_return_tuple
     @auto_docstring
    def forward(
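
The net effect for models like Arcee is that the public accessors keep working, now served by the inherited `EmbeddingAccessMixin` fallback instead of per-model boilerplate. A rough sanity check, editorial and not part of the commit (the checkpoint id is a placeholder for any causal LM checkpoint you actually have available):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("arcee-ai/some-arcee-checkpoint")  # placeholder id
embeddings = model.get_input_embeddings()   # resolved via the mixin's `self.model.embed_tokens` fallback
lm_head = model.get_output_embeddings()     # still returns `lm_head`
model.set_input_embeddings(embeddings)      # routed back to `model.model.embed_tokens`
print(type(embeddings).__name__, type(lm_head).__name__)
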
