@@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module):
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-
-        if get_sm_version() == 100:
-            expert_ids = set(module.initial_local_expert_ids)
-            if self.need_load_shared_weights(module):
-                expert_ids.update(
-                    module.layer_load_balancer.get_load_expert_ids())
-            for name in list(weights.keys()):
-                if name.endswith("weight_scale_inv"):
-                    if int(name.split(".")[0]) not in expert_ids:
-                        continue
-                    weight_name = name.replace("weight_scale_inv", "weight")
-                    logger.debug(f"Resmoothing {weight_name}")
-                    weight = weights[weight_name][:]
-                    scale = weights[name][:]
-                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
-                        weight, scale)
         super().load_weights(module, weights, weight_loading_mode)
 
-        if get_sm_version() == 100:
-            transfromed_w3_w1_scale = transform_sf_into_required_layout(
-                module.quant_scales[0],
-                mn=module.w3_w1_weight.shape[1],
-                k=module.w3_w1_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w3_w1_weight_scaling_factor = nn.Parameter(
-                transfromed_w3_w1_scale, requires_grad=False)
-            transfromed_w2_scale = transform_sf_into_required_layout(
-                module.quant_scales[1],
-                mn=module.w2_weight.shape[1],
-                k=module.w2_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
-                                                           requires_grad=False)
-            self.setup_quant_scales(module)
-
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
             fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         })
 
 
+class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
+        DeepSeekFP8BlockScalesFusedMoEMethod):
+
+    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
+                     weight_loading_mode: MoEWeightLoadingMode):
+        if get_sm_version() == 100:
+            expert_ids = set(module.initial_local_expert_ids)
+            if self.need_load_shared_weights(module):
+                expert_ids.update(
+                    module.layer_load_balancer.get_load_expert_ids())
+            for name in list(weights.keys()):
+                if name.endswith("weight_scale_inv"):
+                    if int(name.split(".")[0]) not in expert_ids:
+                        continue
+                    weight_name = name.replace("weight_scale_inv", "weight")
+                    logger.debug(f"Resmoothing {weight_name}")
+                    weight = weights[weight_name][:]
+                    scale = weights[name][:]
+                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
+                        weight, scale)
+        super().load_weights(module, weights, weight_loading_mode)
+
+        if get_sm_version() == 100:
+            transfromed_w3_w1_scale = transform_sf_into_required_layout(
+                module.quant_scales[0],
+                mn=module.w3_w1_weight.shape[1],
+                k=module.w3_w1_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w3_w1_weight_scaling_factor = nn.Parameter(
+                transfromed_w3_w1_scale, requires_grad=False)
+            transfromed_w2_scale = transform_sf_into_required_layout(
+                module.quant_scales[1],
+                mn=module.w2_weight.shape[1],
+                k=module.w2_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
+                                                           requires_grad=False)
+            self.setup_quant_scales(module)
+
+
 class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self, module: torch.nn.Module):
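Note: the diff above moves the SM100-only handling (resmoothing block scales to E8M0 before loading, then re-laying out the scaling factors) out of the base DeepSeekFP8BlockScalesFusedMoEMethod and into the new DeepGemm subclass. As a rough illustration of what "resmoothing to E8M0" means conceptually, here is a minimal, self-contained Python sketch. The function name resmooth_to_e8m0_sketch and the assumed semantics (round each per-block scale up to a power of two, since E8M0 has no mantissa bits, then rescale the weights block-wise so weight * scale is preserved) are assumptions for illustration, not the library's actual resmooth_to_fp8_e8m0 implementation.

    import torch

    def resmooth_to_e8m0_sketch(weight: torch.Tensor,
                                scale: torch.Tensor,
                                block: int = 128):
        # weight: [n, k] FP8 (or float) tensor; scale: [n // block, k // block]
        # per-block scales. Round each scale UP to the nearest power of two so
        # the compensated weights shrink rather than grow (no FP8 overflow).
        new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
        ratio = scale / new_scale  # in (0.5, 1.0]
        # Expand the per-block ratio to the full weight shape and compensate,
        # so that new_weight * new_scale ~= weight * scale per block.
        ratio_full = ratio.repeat_interleave(block, dim=0) \
                          .repeat_interleave(block, dim=1)
        new_weight = (weight.float() * ratio_full).to(weight.dtype)
        return new_weight, new_scale

Rounding up (rather than to nearest) keeps the compensation ratio at or below 1, which is why the rescaled FP8 weights cannot overflow their representable range.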