 import torch
 from torch import nn
 
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.blockwise_fp8.deep_gemm_utils import (
+    scaled_mm_deep_gemm_128_1_128_1,
+    scaled_mm_deep_gemm_128_1_128_128,
+)
 from torchao.prototype.blockwise_fp8.kernels import (
-    blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
+    triton_quantize_fp8_block,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
 )
 
 
-class BlockwiseQuantLinear(nn.Module):
+class fp8_blockwise_mm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, block_size):
+        assert block_size == 128, "Only support block_size=128"
+
+        # Temporarily reshape x to 2D tensor
+        x_orig_shape = x.shape
+        x = x.reshape(-1, x_orig_shape[-1])
+
+        # Triton kernel from DeepGEMM currently has the fastest activation quantization (1 x block_size)
+        x_fp8, x_scale = fp8_blockwise_act_quant(x, block_size)
+
+        # fbgemm currently has the fastest weight quantization (block_size x block_size)
+        weight_t_fp8, weight_t_scale = triton_quantize_fp8_block(
+            weight,
+            block_m=block_size,
+            block_k=block_size,
+            k_major=True,  # For [M,K] -> [K,M] in column-major
+        )
+
+        # DeepGEMM for blockwise GEMM where activation has (1 x block_size) scaling granularity
+        # and weight has (block_size x block_size) scaling granularity.
+        out = scaled_mm_deep_gemm_128_1_128_128(
+            x_fp8,
+            x_scale,
+            weight_t_fp8,
+            weight_t_scale,
+        )
+        ctx.save_for_backward(x, weight)
+        ctx.block_size = block_size
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, weight = ctx.saved_tensors
+        block_size = ctx.block_size
+
+        # left operand must be row-major
+        grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant(
+            grad_output,
+            block_size,
+        )
+
+        # right operand must be column-major
+        weight_t_fp8, weight_t_scale = triton_quantize_fp8_block(
+            weight,
+            block_m=block_size,
+            block_k=block_size,
+            k_major=False,  # For [M,K] -> [K,M] in row-major
+        )
+        weight_t_fp8 = weight_t_fp8.t().contiguous().t()  # To col-major
+
+        # DeepGEMM for blockwise GEMM where left operand has (1 x block_size) scaling granularity
+        # and right operand has (block_size x block_size) scaling granularity.
+        # grad_x = grad_output @ weight.T
+        grad_x = scaled_mm_deep_gemm_128_1_128_128(
+            grad_output_fp8,
+            weight_t_fp8,
+            1.0 / grad_output_scale,
+            1.0 / weight_t_scale,
+        )
+
+        # left operand must be row-major
+        grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant(
+            grad_output.t().contiguous(),
+            block_size,
+        )
+
+        # right operand must be column-major
+        x_fp8, x_scale = fp8_blockwise_act_quant(
+            x,
+            block_size,
+        )
+        x_fp8 = x_fp8.t().contiguous().t()  # To col-major
+
+        # DeepGEMM for blockwise GEMM where both operands have (1 x block_size) scaling granularity.
+        # grad_weight = grad_output.T @ x
+        grad_weight = scaled_mm_deep_gemm_128_1_128_1(
+            grad_output_t_fp8,
+            x_fp8,
+            1.0 / grad_output_t_scale,
+            1.0 / x_scale,
+        )
+        return grad_x, grad_weight, None  # one gradient per forward input (x, weight, block_size)
+
+
+class Float8BlockwiseLinear(nn.Linear):
     """
     Custom linear layer with support for quantized weights and optional bias.
 
@@ -25,53 +119,60 @@ class BlockwiseQuantLinear(nn.Module):
         dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn.
     """
 
-    dtype = torch.bfloat16
+    supported_dtypes = [
+        torch.bfloat16,
+    ]
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        bias: bool = False,
+        *args,
         block_size: int = 128,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        dtype=torch.bfloat16,
+        **kwargs,
     ):
-        super().__init__()
-        supported_dtypes = [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]
-        assert dtype in supported_dtypes, (
-            f"Unsupported dtype: {dtype}. Supported dtypes: {supported_dtypes}"
-        )
-        scale_in_features = (in_features + block_size - 1) // block_size
-        scale_out_features = (out_features + block_size - 1) // block_size
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
-        self.weight.scale = self.scale = nn.Parameter(
-            torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+        super().__init__(*args, **kwargs)
+
+        assert dtype in self.supported_dtypes, (
+            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
         )
         self.block_size = block_size
-        self.dtype
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            self.register_parameter("bias", None)
+        self.dtype = dtype
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the custom linear layer.
 
         Args:
-            x (torch.Tensor): Input tensor.
+            x (torch.Tensor): input tensor.
 
         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype)
-        y = blockwise_fp8_gemm(
-            x, scale, self.weight, self.weight.scale, self.block_size
-        )
+        return fp8_blockwise_mm.apply(x, self.weight, self.block_size)
+
+    @classmethod
+    def from_float(
+        cls,
+        mod,
+    ):
+        assert mod.bias is None, "unsupported"
+        assert mod.in_features % 128 == 0, "unsupported"
+        assert mod.out_features % 128 == 0, "unsupported"
+        with torch.device("meta"):
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+            )
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+        return new_mod
+
+
+class Float8BlockwiseLinearConfig(AOBaseConfig):
+    pass
+
 
-        if self.bias is not None:
-            y += self.bias
-        return y
+@register_quantize_module_handler(Float8BlockwiseLinearConfig)
+def _deep_gemm_float8_inference_linear_transform(module, config):
+    return Float8BlockwiseLinear.from_float(module)
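
Minimal usage sketch (not part of the diff) showing how the registered config is intended to plug into torchao's `quantize_` API. It assumes a CUDA setup with the DeepGEMM and fbgemm kernels available, that `Float8BlockwiseLinearConfig` is importable from the module this PR touches (path shown below is an assumption), and feature dimensions divisible by 128 with no bias, per the asserts in `from_float`; the model and shapes are illustrative only.

import torch
from torch import nn

from torchao.quantization import quantize_
# Assumed import path for this PR's config; adjust to wherever the module actually lands.
from torchao.prototype.blockwise_fp8.blockwise_linear import Float8BlockwiseLinearConfig

# Toy model: dims must be multiples of 128 and bias must be disabled.
model = nn.Sequential(nn.Linear(1024, 512, bias=False)).cuda().to(torch.bfloat16)

# quantize_ dispatches to the handler registered via register_quantize_module_handler,
# swapping each nn.Linear for a Float8BlockwiseLinear built by from_float.
quantize_(model, Float8BlockwiseLinearConfig())

x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = model(x)         # forward: fp8_blockwise_mm quantizes x and weight, then calls DeepGEMM
y.sum().backward()   # backward: grad_x and grad_weight computed via the two blockwise GEMMs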