@@ -685,7 +685,7 @@ def backward(ctx, do3):
 
 class FP8MlpFunction(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, x, w1, w2):
+    def forward(ctx, x, w1, w2, recompute_fwd_gate_up):
         # ===== reshape for deep_gemm, since deep_gemm only support 2D =====
         x_orig_shape = x.shape
         x = x.reshape([-1, x_orig_shape[-1]])
@@ -697,6 +697,7 @@ def forward(ctx, x, w1, w2):
         o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
         # ===== save for backward =====
+        o1 = None if recompute_fwd_gate_up else o1
         ctx.save_for_backward(
             o1,
             x_fp8,
@@ -729,9 +730,14 @@ def backward(ctx, do3):
         )
 
         # ===== call func common_fp8_mlp_bwd =====
-        dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
-            do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
-        )
+        if o1 is None:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True
+            )
+        else:
+            dx = FP8LinearFunctionBase.common_fp8_mlp_bwd(
+                do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True
+            )
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]])
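
For orientation, a minimal self-contained sketch (a plain NumPy stand-in, not the Paddle/FP8 code above) of the cache-or-recompute pattern this hunk introduces: if the forward pass dropped the gate/up output `o1` to save activation memory, the backward rebuilds it from the saved input before computing gradients; otherwise it reuses the cached value. All names below (`toy_mlp_forward`, `toy_mlp_backward`) are illustrative assumptions, not part of the patch.

    import numpy as np

    def toy_mlp_forward(x, w1, w2, recompute_fwd_gate_up):
        # Stand-in for the FP8/deep_gemm forward: gate/up projection, activation, down projection.
        o1 = x @ w1
        o3 = np.maximum(o1, 0.0) @ w2
        saved_o1 = None if recompute_fwd_gate_up else o1  # drop o1 to cut activation memory
        return o3, (saved_o1, x, w1, w2)

    def toy_mlp_backward(saved, do3):
        o1, x, w1, w2 = saved
        if o1 is None:                   # recompute path: rebuild o1 from the saved input
            o1 = x @ w1
        do1 = (do3 @ w2.T) * (o1 > 0)    # gradient through the (relu) activation
        return do1 @ w1.T                # dx

    x, w1, w2 = np.random.randn(4, 8), np.random.randn(8, 16), np.random.randn(16, 8)
    for flag in (False, True):
        o3, saved = toy_mlp_forward(x, w1, w2, recompute_fwd_gate_up=flag)
        dx = toy_mlp_backward(saved, np.ones_like(o3))  # same dx either way; the flag trades memory for recompute
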
@@ -749,6 +755,7 @@ def __init__(
         using_post_norm_recompute=False,
         norm_weight=None,
         norm_eps=None,
+        recompute_fwd_gate_up=False,
     ):
         super().__init__()
         self.config = config
@@ -761,6 +768,8 @@ def __init__(
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
         self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
 
+        self.recompute_fwd_gate_up = recompute_fwd_gate_up
+
         self.w1 = self.create_parameter(
             shape=[self.hidden_size, self.intermediate_size * 2],
             dtype="bfloat16",
@@ -780,7 +789,7 @@ def forward(self, x):
         if self.using_post_norm_recompute:
             return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
         else:
-            return FP8MlpFunction.apply(x, self.w1, self.w2)
+            return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up)
 
 
 def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out):