@@ -1902,7 +1902,97 @@ def floating_point_ops(
        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


- class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
+ class EmbeddingAccessMixin:
+     """
+     Base utilities to regroup getters and setters for embeddings.
+     Introduces the `_input_embed_layer` attribute, which indicates
+     where the input embeddings come from and where they
+     should be set.
+     """
+
+     _input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.
+
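+     # Example (hypothetical subclass, for illustration only): a model that keeps its
+     # embeddings under a different attribute name only needs to repoint `_input_embed_layer`:
+     #
+     #     class MyTextModel(PreTrainedModel):
+     #         _input_embed_layer = "word_embeddings"
+     #
+     #         def __init__(self, config):
+     #             super().__init__(config)
+     #             self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+     #
+     # `get_input_embeddings()` then resolves `self.word_embeddings` directly, and
+     # `set_input_embeddings(new_emb)` writes back to the same attribute.
+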
+     def get_input_embeddings(self) -> nn.Module:
+         """
+         Returns the model's input embeddings.
+
+         Returns:
+             `nn.Module`: A torch module mapping vocabulary to hidden states.
+         """
+
+         # 1) Check if the model has an attribute named by `_input_embed_layer` ('embed_tokens' by default,
+         #    the standard input embedding layer for most NLP models), and if so, return it.
+
+         name = getattr(self, "_input_embed_layer", "embed_tokens")
+
+         if (default_embedding := getattr(self, name, None)) is not None:
+             return default_embedding
+         # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+
+         if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
+             return self.model.embed_tokens
+
+         # 3) vanilla decoder‑only architectures
+         elif hasattr(self, "embed_tokens"):
+             return self.embed_tokens
+         else:
+             base_model = getattr(self, "base_model_prefix", None)
+             if base_model is not None:
+                 base_model = getattr(self, base_model, None)
+             if base_model is not None and base_model is not self:
+                 return base_model.get_input_embeddings()
+             raise NotImplementedError(
+                 f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
+                 "please override in the subclass."
+             )
+
+     def set_input_embeddings(self, value: nn.Module):
+         """Fallback setter that handles **~70 %** of models in the code‑base.
+
+         Order of attempts:
+         1. `self.model.embed_tokens`
+         2. `self.embed_tokens`
+         3. delegate to the *base model* if one exists
+         4. otherwise raise `NotImplementedError` so subclasses still can (and
+            should) override for exotic layouts.
+         """
+
+         # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+         name = getattr(self, "_input_embed_layer", "embed_tokens")
+         if hasattr(self, "model") and hasattr(self.model, name):
+             setattr(self.model, name, value)
+         # 2) as well as vanilla decoder‑only architectures
+         elif hasattr(self, name):
+             setattr(self, name, value)
+         # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
+         elif getattr(self, self.base_model_prefix, self) is not self:
+             base_model = getattr(self, self.base_model_prefix, self)
+             base_model.set_input_embeddings(value)
+         else:
+             raise NotImplementedError(
+                 f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
+             )
+
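+     # Usage sketch (hypothetical `model` and sizes, for illustration only): growing the
+     # vocabulary goes through the two accessors defined above, e.g.
+     #
+     #     old = model.get_input_embeddings()                      # nn.Embedding(vocab, hidden)
+     #     new = nn.Embedding(old.num_embeddings + 8, old.embedding_dim)
+     #     new.weight.data[: old.num_embeddings] = old.weight.data
+     #     model.set_input_embeddings(new)
+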
+     def get_output_embeddings(self):
+         if not hasattr(self, "lm_head"):
+             return None
+         try:
+             # Speech / vision backbones raise here, so we return None.
+             # Legit use of get_input_embs?
+             self.get_input_embeddings()
+         except NotImplementedError:
+             return None
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """
+         Sets the model's output embeddings, defaulting to assigning `new_embeddings` to `lm_head`.
+         """
+         if getattr(self, "lm_head", None) is not None:
+             self.lm_head = new_embeddings
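+
+     # Sketch (illustrative only, hypothetical `model`, `hidden_size` and `new_vocab_size`):
+     # with the default `lm_head` layout the two accessors above behave like
+     #
+     #     head = model.get_output_embeddings()    # `model.lm_head`, or None for head-less backbones
+     #     if head is not None:
+     #         model.set_output_embeddings(nn.Linear(hidden_size, new_vocab_size, bias=False))
+
+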
+ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
    r"""
    Base class for all models.

@@ -2004,6 +2094,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
    _supports_attention_backend = False
    _can_record_outputs = None

+     # This attribute sets the default parameter to be
+
    @property
    @torch._dynamo.allow_in_graph
    def can_record_outputs(self) -> dict[str, OutputRecorder]:
@@ -2267,6 +2359,101 @@ def _from_config(cls, config, **kwargs):

        return model

+     @classmethod
+     def _check_attn_implementation(cls, attn_implementation: Union[str, dict]) -> Union[str, dict]:
+         """
+         Checks that the requested attention implementation exists and tries to get the kernel from the hub
+         if `attn_implementation` matches the hf kernels pattern.
+         """
+         if isinstance(attn_implementation, str) and re.match(r"^[^/:]+/[^/:]+:[^/:]+$", attn_implementation):
+             if not is_kernels_available():
+                 raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
+
+             # Extract repo_id and kernel_name from the string
+             repo_id, kernel_name = attn_implementation.split(":")
+             kernel_name = kernel_name.strip()
+             repo_id = repo_id.strip()
+
+             try:
+                 kernel = get_kernel(repo_id)
+                 ALL_ATTENTION_FUNCTIONS.register(f"kernel_{repo_id.replace('/', '_')}", getattr(kernel, kernel_name))
+                 attn_implementation = f"kernel_{repo_id.replace('/', '_')}"
+             except FileNotFoundError as e:
+                 logger.warning(
+                     f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using eager attention implementation instead."
+                 )
+                 attn_implementation = None  # try to dispatch SDPA and fallback eager if not available
+             except AttributeError:
+                 raise ValueError(
+                     "the kernel function name or class specified in the attn_implementation argument is not valid. \
+                     Please check the documentation for the correct format, \
+                     and check that the kernel exports the class and the function correctly."
+                 )
+         if (
+             not isinstance(attn_implementation, dict)
+             and attn_implementation not in ["eager", None] + ALL_ATTENTION_FUNCTIONS.valid_keys()
+         ):
+             message = f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+             # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+             if cls._supports_flash_attn or getattr(cls, "_supports_flash_attn_2", False):
+                 message += (
+                     ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
+                     ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+                 )
+             if cls._supports_sdpa:
+                 message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
+             if cls._supports_flex_attn:
+                 message += ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
+             raise ValueError(message + ".")
+
+         return attn_implementation
+
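+     # Illustrative call (hypothetical class and repo/function names, not a real kernel repository):
+     # a string of the form "<org>/<repo>:<function>" is resolved through the `kernels` package, e.g.
+     #
+     #     attn = MyModel._check_attn_implementation("some-org/some-attn-kernels:my_attn_function")
+     #     # -> registers the kernel and returns "kernel_some-org_some-attn-kernels"
+     #
+     # while plain names like "eager", "sdpa" or "flex_attention" are validated against
+     # `ALL_ATTENTION_FUNCTIONS.valid_keys()` and returned unchanged.
+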
+     def set_attention_implementation(self, attn_implementation: Union[str, dict]):
+         """
+         Checks and dispatches to the requested attention implementation.
+         """
+         requested_attn_implementation = self._check_attn_implementation(attn_implementation)
+
+         # Composite models consisting of several PretrainedModels can specify the attention implementation as a dict
+         # where keys are sub-config names. But most people will specify one `str`, which is then dispatched
+         # to all sub-models.
+         # See https://github.com/huggingface/transformers/pull/32238
+         for key in self.config.sub_configs.keys():
+             sub_config = getattr(self.config, key)
+             curr_attn_implementation = (
+                 requested_attn_implementation
+                 if not isinstance(requested_attn_implementation, dict)
+                 else requested_attn_implementation.get(key, None)
+             )
+             # For models with a backbone, the sub-config might not be initialized. Set the requested attn
+             # if the config hasn't got any attn pre-set and the requested attn is not `None` (i.e. not the default attn)
+             if (
+                 sub_config is not None
+                 and sub_config._attn_implementation_internal is None
+                 and curr_attn_implementation is not None
+             ):
+                 sub_config._attn_implementation_internal = curr_attn_implementation
+
+         if requested_attn_implementation == "flash_attention_3" and self._flash_attn_3_can_dispatch():
+             self.config._attn_implementation = "flash_attention_3"
+         if requested_attn_implementation == "flash_attention_2" and self._flash_attn_2_can_dispatch():
+             self.config._attn_implementation = "flash_attention_2"
+         elif requested_attn_implementation == "flex_attention" and self._flex_attn_can_dispatch():
+             self.config._attn_implementation = "flex_attention"
+         elif (
+             requested_attn_implementation in [None, "sdpa"]
+             and not is_torch_xla_available()
+             and self._sdpa_can_dispatch(hard_check_only=requested_attn_implementation is not None)
+         ):
+             self.config._attn_implementation = "sdpa"
+         elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.valid_keys():
+             self.config._attn_implementation = requested_attn_implementation
+         elif isinstance(requested_attn_implementation, dict):
+             self.config._attn_implementation = requested_attn_implementation.get("", None)
+         else:
+             self.config._attn_implementation = "eager"
+
+         self.config._attn_implementation_autoset = True
+
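+     # Illustrative calls (hypothetical composite model with `text_config`/`vision_config` sub-configs):
+     # a plain string is dispatched to every sub-model, while a dict targets them individually, e.g.
+     #
+     #     model.set_attention_implementation("sdpa")
+     #     model.set_attention_implementation({"text_config": "flash_attention_2", "vision_config": "sdpa", "": "sdpa"})
+     #
+     # For a dict request, the "" key (when present) is what the top-level config falls back to.
+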
    @classmethod
    def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
        """
@@ -2769,41 +2956,6 @@ def disable_input_require_grads(self):
        """
        self._require_grads_hook.remove()

-     def get_input_embeddings(self) -> nn.Module:
-         """
-         Returns the model's input embeddings.
-
-         Returns:
-             `nn.Module`: A torch module mapping vocabulary to hidden states.
-         """
-         base_model = getattr(self, self.base_model_prefix, self)
-         if base_model is not self:
-             return base_model.get_input_embeddings()
-         else:
-             raise NotImplementedError
-
-     def set_input_embeddings(self, value: nn.Module):
-         """
-         Set model's input embeddings.
-
-         Args:
-             value (`nn.Module`): A module mapping vocabulary to hidden states.
-         """
-         base_model = getattr(self, self.base_model_prefix, self)
-         if base_model is not self:
-             base_model.set_input_embeddings(value)
-         else:
-             raise NotImplementedError
-
-     def get_output_embeddings(self) -> nn.Module:
-         """
-         Returns the model's output embeddings.
-
-         Returns:
-             `nn.Module`: A torch module mapping hidden states to vocabulary.
-         """
-         return None  # Overwrite for models with output embeddings
-
    def _init_weights(self, module):
        """
        Initialize the weights. This method should be overridden by derived class and is