@@ -604,9 +604,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -619,13 +622,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
@@ -637,11 +650,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
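
Below the diff, a minimal runnable sketch of the shape handling these changes add. The (out_blocks, 1, in_blocks, 1) layout and the block counts are illustrative assumptions, based only on the in-diff comment that modelopt fp8_pb_wo scales can carry two extra singleton dimensions; to_2d is a hypothetical helper, not part of the patched module.

import torch

# Hypothetical fp8_pb_wo block scales as modelopt might export them:
# (out_blocks, 1, in_blocks, 1), i.e. two extra singleton dimensions.
q_scale = torch.rand(16, 1, 8, 1)
k_scale = torch.rand(4, 1, 8, 1)
v_scale = torch.rand(4, 1, 8, 1)


def to_2d(scale: torch.Tensor) -> torch.Tensor:
    # Drop the singleton dims so the scale is plain (out_blocks, in_blocks).
    return scale.squeeze(1).squeeze(-1) if scale.dim() == 4 else scale


# Once squeezed, the per-projection scales concatenate along the output
# dimension, mirroring how the fused QKV weight itself is assembled.
fused_scale = torch.cat([to_2d(q_scale), to_2d(k_scale), to_2d(v_scale)], dim=0)
print(fused_scale.shape)  # torch.Size([24, 8])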