18
18
19
19
python src/transformers/models/gemma3n/convert_gemma3n_weights.py \
20
20
--variant='gemma3n_e4b' \
21
- --tokenizer_path="$HOME/nano3/checkpoints/tokenizer /gemma-3n-tokenizer.model" \
22
- --checkpoint_path="$HOME/nano3/ checkpoints/g251_orbax /" \
23
- --output_path="$HOME/nano3/ checkpoints/g251_vision_encoder /"
21
+ --tokenizer_path="$HOME/tokenizers /gemma-3n-tokenizer.model" \
22
+ --checkpoint_path="$HOME/checkpoints/gemma-3n-orbax /" \
23
+ --output_path="$HOME/checkpoints/gemma-3n-safetensors /"
24
24
"""
25
25
26
26
import json
@@ -552,8 +552,9 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
552
552
converted_weight = weights
553
553
elif _MOBILE_NET_CONV in path :
554
554
if "Conv_0" in path :
555
- converted_path = "conv_stem.conv.weight"
556
- converted_weight = weights .transpose (3 , 2 , 1 , 0 )
555
+ converted_path = ("conv_stem.conv.weight" , "conv_stem.conv.bias" )
556
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
557
+ converted_weight = (converted_weight , np .zeros (converted_weight .shape [0 ]))
557
558
elif "Normalize_0" in path :
558
559
converted_path = "conv_stem.bn.weight"
559
560
converted_weight = weights
@@ -567,7 +568,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
567
568
converted_weight = weights
568
569
elif "expand_conv" in path :
569
570
converted_path += ".conv_exp.weight"
570
- converted_weight = weights .transpose (3 , 2 , 1 , 0 )
571
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
571
572
else :
572
573
converted_path += ".conv_pwl.weight"
573
574
converted_weight = weights .transpose ()[:, :, None , None ]
@@ -588,7 +589,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
588
589
converted_weight = weights
589
590
elif "key_dwconv" in path :
590
591
converted_path += ".attn.key.down_conv.weight"
591
- converted_weight = weights .transpose ()
592
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
592
593
elif "key_proj" in path :
593
594
converted_path += ".attn.key.proj.weight"
594
595
converted_weight = weights .transpose ()[:, :, None , None ]
@@ -600,7 +601,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
600
601
converted_weight = weights .transpose ()[:, :, None , None ]
601
602
elif "value_dwconv" in path :
602
603
converted_path += ".attn.value.down_conv.weight"
603
- converted_weight = weights .transpose ()
604
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
604
605
elif "value_proj" in path :
605
606
converted_path += ".attn.value.proj.weight"
606
607
converted_weight = weights .transpose ()[:, :, None , None ]
@@ -630,15 +631,18 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
630
631
converted_weight = weights .transpose ()[:, :, None , None ]
631
632
elif "middle_dwconv" in path :
632
633
converted_path += ".dw_mid.conv.weight"
633
- converted_weight = weights .transpose (3 , 2 , 1 , 0 )
634
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
634
635
elif "project" in path :
635
636
converted_path += ".pw_proj.conv.weight"
636
637
converted_weight = weights .transpose ()[:, :, None , None ]
637
638
elif "start_dwconv" in path :
638
639
converted_path += ".dw_start.conv.weight"
639
- converted_weight = weights .transpose (3 , 2 , 1 , 0 )
640
+ converted_weight = weights .transpose (3 , 2 , 0 , 1 )
640
641
641
- return [(converted_path , converted_weight )]
642
+ if isinstance (converted_path , (tuple , list )):
643
+ return zip (converted_path , converted_weight )
644
+ else :
645
+ return [(converted_path , converted_weight )]
642
646
643
647
644
648
def convert (checkpoint_path : str , config : Gemma3nConfig ) -> dict [str , torch .Tensor ]:
0 commit comments