Skip to content

Commit d913b39

Browse files
fix: HWIO to OIHW (#39200)
* fix: HWIO to OIHW
* Bug in attention type
* Conversion script docstring
* style

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent a26f0fa commit d913b39

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

src/transformers/models/gemma3n/configuration_gemma3n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def __init__(
271271

272272
if layer_types is None:
273273
self.layer_types = [
274-
"full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
274+
"full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
275275
]
276276
else:
277277
self.layer_types = layer_types

src/transformers/models/gemma3n/convert_gemma3n_weights.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
1919
python src/transformers/models/gemma3n/convert_gemma3n_weights.py \
2020
--variant='gemma3n_e4b' \
21-
--tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma-3n-tokenizer.model" \
22-
--checkpoint_path="$HOME/nano3/checkpoints/g251_orbax/" \
23-
--output_path="$HOME/nano3/checkpoints/g251_vision_encoder/"
21+
--tokenizer_path="$HOME/tokenizers/gemma-3n-tokenizer.model" \
22+
--checkpoint_path="$HOME/checkpoints/gemma-3n-orbax/" \
23+
--output_path="$HOME/checkpoints/gemma-3n-safetensors/"
2424
"""
2525

2626
import json
@@ -552,8 +552,9 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
552552
converted_weight = weights
553553
elif _MOBILE_NET_CONV in path:
554554
if "Conv_0" in path:
555-
converted_path = "conv_stem.conv.weight"
556-
converted_weight = weights.transpose(3, 2, 1, 0)
555+
converted_path = ("conv_stem.conv.weight", "conv_stem.conv.bias")
556+
converted_weight = weights.transpose(3, 2, 0, 1)
557+
converted_weight = (converted_weight, np.zeros(converted_weight.shape[0]))
557558
elif "Normalize_0" in path:
558559
converted_path = "conv_stem.bn.weight"
559560
converted_weight = weights
@@ -567,7 +568,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
567568
converted_weight = weights
568569
elif "expand_conv" in path:
569570
converted_path += ".conv_exp.weight"
570-
converted_weight = weights.transpose(3, 2, 1, 0)
571+
converted_weight = weights.transpose(3, 2, 0, 1)
571572
else:
572573
converted_path += ".conv_pwl.weight"
573574
converted_weight = weights.transpose()[:, :, None, None]
@@ -588,7 +589,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
588589
converted_weight = weights
589590
elif "key_dwconv" in path:
590591
converted_path += ".attn.key.down_conv.weight"
591-
converted_weight = weights.transpose()
592+
converted_weight = weights.transpose(3, 2, 0, 1)
592593
elif "key_proj" in path:
593594
converted_path += ".attn.key.proj.weight"
594595
converted_weight = weights.transpose()[:, :, None, None]
@@ -600,7 +601,7 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
600601
converted_weight = weights.transpose()[:, :, None, None]
601602
elif "value_dwconv" in path:
602603
converted_path += ".attn.value.down_conv.weight"
603-
converted_weight = weights.transpose()
604+
converted_weight = weights.transpose(3, 2, 0, 1)
604605
elif "value_proj" in path:
605606
converted_path += ".attn.value.proj.weight"
606607
converted_weight = weights.transpose()[:, :, None, None]
@@ -630,15 +631,18 @@ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]
630631
converted_weight = weights.transpose()[:, :, None, None]
631632
elif "middle_dwconv" in path:
632633
converted_path += ".dw_mid.conv.weight"
633-
converted_weight = weights.transpose(3, 2, 1, 0)
634+
converted_weight = weights.transpose(3, 2, 0, 1)
634635
elif "project" in path:
635636
converted_path += ".pw_proj.conv.weight"
636637
converted_weight = weights.transpose()[:, :, None, None]
637638
elif "start_dwconv" in path:
638639
converted_path += ".dw_start.conv.weight"
639-
converted_weight = weights.transpose(3, 2, 1, 0)
640+
converted_weight = weights.transpose(3, 2, 0, 1)
640641

641-
return [(converted_path, converted_weight)]
642+
if isinstance(converted_path, (tuple, list)):
643+
return zip(converted_path, converted_weight)
644+
else:
645+
return [(converted_path, converted_weight)]
642646

643647

644648
def convert(checkpoint_path: str, config: Gemma3nConfig) -> dict[str, torch.Tensor]:

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def __init__(
283283

284284
if layer_types is None:
285285
self.layer_types = [
286-
"full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
286+
"full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
287287
]
288288
else:
289289
self.layer_types = layer_types

0 commit comments

Comments (0)