@@ -63,13 +63,32 @@ def _get_dtype(self):
         return torch.float16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
             # For tensor-wise scaling, kernel requires a float32 scale tensor
             if custom_scale:
                 return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
             scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
             return scale.to(torch.float32)
 
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
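For reference, a minimal standalone sketch of the scaling math these helpers implement (the sizes and tensors here are made up for illustration, and it assumes a PyTorch build with float8 dtypes): per-tensor scaling yields a 0-d float32 scale, while per-row scaling yields an [M, 1] scale for the left operand and a [1, N] scale for the right one.

```python
import torch

# Standalone sketch (not the benchmark code): shows the scale shapes the
# helpers above produce. Assumes torch.finfo supports float8_e4m3fn.
M, N, K = 4, 8, 16
a = torch.randn(M, K, dtype=torch.float16)
b = torch.randn(K, N, dtype=torch.float16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

# Per-tensor: a single 0-d float32 scale for the whole matrix.
scale_a_tensor = (fp8_max / a.abs().max()).to(torch.float32)

# Per-row: reduce over the K dimension, keeping the other one.
scale_a_row = (fp8_max / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)
scale_b_row = (fp8_max / b.abs().max(dim=0, keepdim=True).values).to(torch.float32)

assert scale_a_tensor.ndim == 0
assert scale_a_row.shape == (M, 1)  # matches the scale_a comment in the diff
assert scale_b_row.shape == (1, N)  # matches the scale_b comment in the diff
```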
@@ -80,26 +99,33 @@ def args(m, n, k):
             )
 
             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
-                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
             b = b.to(torch.float8_e4m3fn)
 
             return (a, b, scale_a, scale_b)
 
-        if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
+        if (
+            hasattr(self, "external_shapes") and self.external_shapes
+        ):  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
                 if len(shape) == 3:
                     m, n, k = shape
                     yield args(m, n, k)
                 else:
-                    logger.warning(f"Skipping invalid shape: {shape}, expected [M, N, K]")
+                    logger.warning(
+                        f"Skipping invalid shape: {shape}, expected [M, N, K]"
+                    )
         elif self.extra_args.llama:
             for m, n, k, _bias in llama_shapes():
                 yield args(m, n, k)
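For context, here is a hedged sketch of how a downstream kernel might consume the yielded `(a, b, scale_a, scale_b)` tuple. It uses `torch._scaled_mm`, a private API whose signature has changed across PyTorch releases; the benchmark's actual backends may call different kernels, so treat this as an assumption-laden illustration rather than the benchmark's code path.

```python
import torch

# Hypothetical consumer of a yielded (a, b, scale_a, scale_b) tuple.
# torch._scaled_mm is private and version-dependent; this follows the
# newer schema where the scales are required tensor arguments.
if torch.cuda.is_available():
    M, N, K = 16, 32, 64
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # _scaled_mm expects the second operand column-major; building b as
    # (N, K) and transposing yields a column-major (K, N) view.
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
    scale_a = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    scale_b = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
    print(out.shape)  # torch.Size([16, 32])
```

Row-wise scales of shape [M, 1] and [1, N], as produced by `_get_scale_per_row`, are only accepted by sufficiently recent PyTorch builds on hardware with fp8 support.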