
Commit e355c0a

Fix missing initializations for models created in 2024 (#38987)
* fix GroundingDino
* fix SuperGlue
* fix GroundingDino
* fix MambaModel
* fix OmDetTurbo
* fix SegGpt
* fix Qwen2Audio
* fix Mamba2
* fix DabDetr
* fix Dac
* fix FalconMamba
* skip timm initialization
* fix Encodec and MusicgenMelody
* fix Musicgen
* skip timm initialization test
* fix OmDetTurbo
* clean the code

Co-authored-by: Cyril Vallez <[email protected]>

* add reviewed changes
* add back timm
* style
* better check for parametrizations

---------

Co-authored-by: Cyril Vallez <[email protected]>
1 parent 1125513 commit e355c0a

21 files changed (+229 −98 lines)

src/transformers/models/dab_detr/modeling_dab_detr.py

Lines changed: 5 additions & 0 deletions
@@ -829,6 +829,9 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
@@ -841,6 +844,8 @@ def _init_weights(self, module):
             prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
             bias_value = -math.log((1 - prior_prob) / prior_prob)
             module.class_embed.bias.data.fill_(bias_value)
+        elif isinstance(module, nn.PReLU):
+            module.reset_parameters()


# Modified from transformers.models.detr.modeling_detr.DetrEncoder with Detr->DabDetr,DETR->ConditionalDETR
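Note: the new branches give DabDetr explicit defaults for nn.LayerNorm and nn.PReLU instead of leaving them to whatever values the modules were created with. A minimal sketch of what the two branches do on bare PyTorch modules (illustrative only, not the model's own dispatch):

```python
from torch import nn

layer_norm = nn.LayerNorm(256)
layer_norm.weight.data.fill_(1.0)   # identity scale
layer_norm.bias.data.zero_()        # no shift

prelu = nn.PReLU()
prelu.reset_parameters()            # restores PyTorch's default constant negative-slope parameter
```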

src/transformers/models/dac/modeling_dac.py

Lines changed: 6 additions & 0 deletions
@@ -480,6 +480,12 @@ def _init_weights(self, module):
         if isinstance(module, nn.Conv1d):
             nn.init.trunc_normal_(module.weight, std=0.02)
             nn.init.constant_(module.bias, 0)
+        elif isinstance(module, Snake1d):
+            module.alpha.data.fill_(1.0)
+        elif isinstance(module, nn.ConvTranspose1d):
+            module.reset_parameters()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=0.02)

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
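Note: besides filling the Snake1d alpha with 1.0, Dac now falls back to PyTorch's own default reinitialization for transposed convolutions and uses the same 0.02 std for embeddings as for Conv1d weights. A small sketch on plain modules (assumed toy sizes, not the Dac architecture):

```python
from torch import nn

conv_t = nn.ConvTranspose1d(in_channels=8, out_channels=8, kernel_size=3)
conv_t.reset_parameters()                       # PyTorch's default kaiming-uniform reinit

embed = nn.Embedding(num_embeddings=1024, embedding_dim=8)
embed.weight.data.normal_(mean=0.0, std=0.02)   # same std as the trunc_normal_ Conv1d init above
```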

src/transformers/models/encodec/modeling_encodec.py

Lines changed: 8 additions & 14 deletions
@@ -235,7 +235,7 @@ class EncodecLSTM(nn.Module):
     LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout.
     """

-    def __init__(self, config, dimension):
+    def __init__(self, config: EncodecConfig, dimension: int):
         super().__init__()
         self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)

@@ -452,22 +452,16 @@ class EncodecPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+        if isinstance(module, nn.GroupNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Conv1d):
             nn.init.kaiming_normal_(module.weight)
             if module.bias is not None:
                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                 nn.init.uniform_(module.bias, a=-k, b=k)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.ConvTranspose1d):
+            module.reset_parameters()
         elif isinstance(module, nn.LSTM):
             for name, param in module.named_parameters():
                 if "weight" in name:
@@ -659,7 +653,7 @@ def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = Non

     def decode(
         self,
-        audio_codes: torch.Tensor,
+        audio_codes: torch.LongTensor,
         audio_scales: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
@@ -708,10 +702,10 @@ def decode(
     @auto_docstring
     def forward(
         self,
-        input_values: torch.Tensor,
-        padding_mask: Optional[torch.Tensor] = None,
+        input_values: torch.FloatTensor,
+        padding_mask: Optional[torch.BoolTensor] = None,
         bandwidth: Optional[float] = None,
-        audio_codes: Optional[torch.Tensor] = None,
+        audio_codes: Optional[torch.LongTensor] = None,
         audio_scales: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
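Note: the Encodec _init_weights keeps its fan-in-based bound for Conv1d biases; k = sqrt(groups / (in_channels * kernel_size)) matches the bound PyTorch itself uses for conv biases. A self-contained sketch with assumed channel sizes:

```python
import math

from torch import nn

conv = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=7)
nn.init.kaiming_normal_(conv.weight)
if conv.bias is not None:
    k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))  # fan-in bound
    nn.init.uniform_(conv.bias, a=-k, b=k)
```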

src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Lines changed: 33 additions & 20 deletions
@@ -445,9 +445,16 @@ class FalconMambaPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
+        std = self.config.initializer_range
         if isinstance(module, FalconMambaMixer):
+            # S4D real initialization. These are not discretized!
+            # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+            A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+            A = A.expand(module.intermediate_size, -1).contiguous()
+            module.A_log.copy_(torch.log(A))
             module.A_log._no_weight_decay = True
             module.D._no_weight_decay = True
+            module.D.data.fill_(1.0)

             dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
             if self.config.time_step_init_scheme == "constant":
@@ -462,33 +469,39 @@ def _init_weights(self, module):
             ).clamp(min=self.config.time_step_floor)
             # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
             inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                module.dt_proj.bias.copy_(inv_dt)
+            module.dt_proj.bias.copy_(inv_dt)
             module.dt_proj.bias._no_reinit = True

+            nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+            if module.conv1d.bias is not None:
+                if not getattr(module.conv1d.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.conv1d.bias)
+            nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+            if self.config.rescale_prenorm_residual:
+                # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+                # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+                # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+                # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+                #
+                # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                p = module.out_proj.weight
+                p /= math.sqrt(self.config.num_hidden_layers)
+
         if isinstance(module, nn.Linear):
+            if not getattr(module.weight, "_no_reinit", False):
+                nn.init.normal_(module.weight, std=std)
             if module.bias is not None:
                 if not getattr(module.bias, "_no_reinit", False):
                     nn.init.zeros_(module.bias)
+        elif isinstance(module, FalconMambaRMSNorm):
+            module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, std=self.config.initializer_range)
-
-        if self.config.rescale_prenorm_residual:
-            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
-            #
-            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-            for name, p in module.named_parameters():
-                if name in ["out_proj.weight"]:
-                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
-                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                    with torch.no_grad():
-                        p /= math.sqrt(self.config.num_hidden_layers)
+            nn.init.normal_(module.weight, std=std)


@dataclass
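Note: the mixer branch now performs the S4D real initialization of A_log and the constant-one fill of D inside _init_weights, so reinitializing a model reproduces them instead of leaving the buffers untouched. A standalone sketch with assumed toy sizes (the real dimensions come from the config):

```python
import torch

intermediate_size, ssm_state_size = 16, 4                     # assumed toy dimensions
A = torch.arange(1, ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(intermediate_size, -1).contiguous()              # every channel gets the sequence 1..N
A_log = torch.log(A)                                          # stored in log space; the mixer recovers A as -exp(A_log)
D = torch.ones(intermediate_size)                             # skip connection starts at 1.0
```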

src/transformers/models/grounding_dino/modeling_grounding_dino.py

Lines changed: 6 additions & 4 deletions
@@ -1414,16 +1414,18 @@ def _init_weights(self, module):
             module.out_vision_proj.bias.data.fill_(0)
             nn.init.xavier_uniform_(module.out_text_proj.weight)
             module.out_text_proj.bias.data.fill_(0)
-        elif isinstance(module, (GroundingDinoEncoderLayer, GroundingDinoDecoderLayer)):
-            for p in module.parameters():
-                if p.dim() > 1:
-                    nn.init.normal_(p, mean=0.0, std=std)
+        elif isinstance(module, GroundingDinoFusionLayer):
+            module.vision_param.data.fill_(1e-4)
+            module.text_param.data.fill_(1e-4)
         elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
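Note: the fusion-layer gates are filled with 1e-4 rather than sweeping every multi-dimensional parameter of the encoder/decoder layers with a normal init. A hedged sketch of the layer-scale idea behind that value (variable names and the forward expression are illustrative, not GroundingDino's actual code):

```python
import torch
from torch import nn

vision_param = nn.Parameter(1e-4 * torch.ones(256))  # mirrors module.vision_param.data.fill_(1e-4)

hidden = torch.randn(2, 256)
fusion_update = torch.randn(2, 256)
out = hidden + vision_param * fusion_update          # cross-modal residual starts nearly switched off
```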

src/transformers/models/mamba/modeling_mamba.py

Lines changed: 33 additions & 20 deletions
@@ -382,9 +382,16 @@ class MambaPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
+        std = self.config.initializer_range
         if isinstance(module, MambaMixer):
+            # S4D real initialization. These are not discretized!
+            # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+            A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+            A = A.expand(module.intermediate_size, -1).contiguous()
+            module.A_log.copy_(torch.log(A))
             module.A_log._no_weight_decay = True
             module.D._no_weight_decay = True
+            module.D.data.fill_(1.0)

             dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
             if self.config.time_step_init_scheme == "constant":
@@ -399,33 +406,39 @@ def _init_weights(self, module):
             ).clamp(min=self.config.time_step_floor)
             # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
             inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                module.dt_proj.bias.copy_(inv_dt)
+            module.dt_proj.bias.copy_(inv_dt)
             module.dt_proj.bias._no_reinit = True

+            nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+            if module.conv1d.bias is not None:
+                if not getattr(module.conv1d.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.conv1d.bias)
+            nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+            if self.config.rescale_prenorm_residual:
+                # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+                # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+                # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+                # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+                #
+                # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                p = module.out_proj.weight
+                p /= math.sqrt(self.config.num_hidden_layers)
+
         if isinstance(module, nn.Linear):
+            if not getattr(module.weight, "_no_reinit", False):
+                nn.init.normal_(module.weight, std=std)
             if module.bias is not None:
                 if not getattr(module.bias, "_no_reinit", False):
                     nn.init.zeros_(module.bias)
+        elif isinstance(module, MambaRMSNorm):
+            module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, std=self.config.initializer_range)
-
-        if self.config.rescale_prenorm_residual:
-            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
-            #
-            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-            for name, p in module.named_parameters():
-                if name in ["out_proj.weight"]:
-                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
-                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                    with torch.no_grad():
-                        p /= math.sqrt(self.config.num_hidden_layers)
+            nn.init.normal_(module.weight, std=std)


@dataclass
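Note: dt_proj.bias stores the inverse softplus of the sampled time steps, so softplus(bias) recovers dt in the forward pass; dropping the explicit torch.no_grad() presumably works because the surrounding init path already runs without gradient tracking. A quick self-contained check of the identity (toy range assumed in place of time_step_min/time_step_max):

```python
import math

import torch
import torch.nn.functional as F

# toy time steps in a plausible (time_step_min, time_step_max) range
dt = torch.exp(torch.rand(8) * (math.log(0.1) - math.log(0.001)) + math.log(0.001))
inv_dt = dt + torch.log(-torch.expm1(-dt))           # inverse of softplus(x) = log(1 + exp(x))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
```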

src/transformers/models/mamba2/modeling_mamba2.py

Lines changed: 32 additions & 20 deletions
@@ -721,9 +721,15 @@ class Mamba2PreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights."""
+        std = self.config.initializer_range
         if isinstance(module, Mamba2Mixer):
+            # S4D real initialization. These are not discretized!
+            # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+            A = torch.arange(1, self.config.num_heads + 1)
+            module.A_log.copy_(torch.log(A))
             module.A_log._no_weight_decay = True
             module.D._no_weight_decay = True
+            module.D.data.fill_(1.0)

             dt = torch.exp(
                 torch.rand(self.config.num_heads)
@@ -733,33 +739,39 @@ def _init_weights(self, module):

             # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
             inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                module.dt_bias.copy_(inv_dt)
+            module.dt_bias.copy_(inv_dt)
             module.dt_bias._no_reinit = True

+            nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+            if module.conv1d.bias is not None:
+                if not getattr(module.conv1d.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.conv1d.bias)
+            nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+            if self.config.rescale_prenorm_residual:
+                # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+                # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+                # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+                # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+                #
+                # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                p = module.out_proj.weight
+                p /= math.sqrt(self.config.num_hidden_layers)
+
         if isinstance(module, nn.Linear):
+            if not getattr(module.weight, "_no_reinit", False):
+                nn.init.normal_(module.weight, std=std)
             if module.bias is not None:
                 if not getattr(module.bias, "_no_reinit", False):
                     nn.init.zeros_(module.bias)
+        elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)):
+            module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, std=self.config.initializer_range)
-
-        if self.config.rescale_prenorm_residual:
-            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
-            #
-            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-            for name, p in module.named_parameters():
-                if name in ["out_proj.weight"]:
-                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
-                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                    with torch.no_grad():
-                        p /= math.sqrt(self.config.num_hidden_layers)
+            nn.init.normal_(module.weight, std=std)


@dataclass
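Note: the rescale_prenorm_residual branch now scales out_proj.weight directly by 1/sqrt(num_hidden_layers) right after its kaiming init, instead of looping over named_parameters(). A standalone sketch of the same depth scaling (toy sizes; the no_grad guard is needed here only because this snippet runs outside the model's init path):

```python
import math

import torch
from torch import nn

num_hidden_layers = 24                                   # assumed; taken from the config in the real code
out_proj = nn.Linear(128, 64, bias=False)
nn.init.kaiming_uniform_(out_proj.weight, a=math.sqrt(5))
with torch.no_grad():
    out_proj.weight /= math.sqrt(num_hidden_layers)      # GPT-2-style residual-path scaling
```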

src/transformers/models/musicgen/modeling_musicgen.py

Lines changed: 4 additions & 1 deletion
@@ -440,10 +440,13 @@ class MusicgenPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         std = self.config.initializer_factor
-        if isinstance(module, (nn.Linear, nn.Conv1d)):
+        if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
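Note: with the new branch, every nn.LayerNorm in Musicgen starts as the identity transform. A hedged sketch of that guarantee using the same kind of module-wise apply traversal the library relies on (toy block, not the Musicgen decoder):

```python
import torch
from torch import nn

def init_layer_norms(module: nn.Module) -> None:
    if isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()

block = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32))
block.apply(init_layer_norms)
assert torch.all(block[1].weight == 1) and torch.all(block[1].bias == 0)
```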
