 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import check_model_inputs
 from .configuration_my_new_model2 import MyNewModel2Config


@@ -149,7 +149,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -200,8 +200,8 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -254,22 +254,19 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
@@ -282,12 +279,7 @@ def forward(
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        return outputs
+        return hidden_states


 @auto_docstring
@@ -304,6 +296,10 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
     _supports_quantized_cache = True
     _supports_static_cache = True
     _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": MyNewModel2DecoderLayer,
+        "attentions": MyNewModel2Attention,
+    }

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -343,7 +339,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -353,26 +349,12 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

@@ -394,6 +376,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         # embed positions
@@ -408,42 +391,21 @@ def forward(
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = decoder_layer(
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
-                output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
                 **kwargs,
             )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
         hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
         )


@@ -488,8 +450,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -505,8 +466,7 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            **kwargs,
         )
         hidden_states = transformer_outputs.last_hidden_state
         logits = self.score(hidden_states)
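
Taken together, the diff swaps the explicit `output_attentions` / `output_hidden_states` plumbing for flags passed through `**kwargs: Unpack[TransformersKwargs]`: the `@check_model_inputs` decorator reads those flags and records per-layer tensors according to the `_can_record_outputs` mapping, so the decoder layers and the model loop no longer collect them by hand. A minimal usage sketch under that assumption (the `MyNewModel2*` classes are the generated example model, not a published API, and the config fields shown are assumed Llama/Gemma-style defaults):

```python
# Sketch only: MyNewModel2Config / MyNewModel2Model come from the generated example
# file above; the config field values here are illustrative assumptions.
import torch

config = MyNewModel2Config(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = MyNewModel2Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# output_attentions / output_hidden_states are no longer explicit forward() arguments:
# they flow through **kwargs (TransformersKwargs), and @check_model_inputs records the
# per-layer tensors declared in _can_record_outputs.
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True, output_hidden_states=True)

print(len(outputs.hidden_states))  # recorded from each MyNewModel2DecoderLayer
print(len(outputs.attentions))     # recorded from each MyNewModel2Attention (eager-style backends return weights)
```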