17 | 17 | from auto_round.data_type.register import register_dtype |
18 | 18 | from auto_round.data_type.utils import float8_e4m3fn_ste, get_gaudi_fp8_ste_func |
19 | 19 |
20 | | -# @register_dtype("fp8_gaudi3_to_int_sym") |
21 | | -# def progressive_quant_fp8_int4_gaudi3( |
22 | | -# tensor, |
23 | | -# bits=4, |
24 | | -# group_size=-1, |
25 | | -# v=0, |
26 | | -# min_scale=1.0, |
27 | | -# max_scale=1.0, |
28 | | -# q_scale_thresh=1e-5, |
29 | | -# weight_fp8_max_scale=1.0, |
30 | | -# **kwargs |
31 | | -# ): |
32 | | -# """Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128 |
33 | | -# |
34 | | -# This method first quantizes the input tensor into float8 format and then performs |
35 | | -# a secondary quantization to int4 with grouping. |
36 | | -# |
37 | | -# Args: |
38 | | -# tensor (torch.Tensor): Input tensor to quantize. |
39 | | -# bits (int, optional): Bit precision for secondary quantization. Defaults to 4. |
40 | | -# group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping). |
41 | | -# v (float, optional): Optional parameter for variance tuning. Defaults to 0. |
42 | | -# min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0. |
43 | | -# max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0. |
44 | | -# q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5. |
45 | | -# weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0. |
46 | | -# **kwargs: Additional arguments for compatibility. |
47 | | -# |
48 | | -# Returns: |
49 | | -# tuple: |
50 | | -# - Quantized and dequantized tensor (torch.Tensor). |
51 | | -# - Combined scaling factor (torch.Tensor). |
52 | | -# - Placeholder for zp (None). |
53 | | -# """ |
54 | | -# fp8_max = torch.finfo(torch.float8_e4m3fn).max |
55 | | -# tensor_max = ( |
56 | | -# torch.max(torch.abs(tensor)).to(torch.float32) * weight_fp8_max_scale |
57 | | -# ) ## better train a ratio |
58 | | -# scale = tensor_max.to(torch.float32) / fp8_max |
59 | | -# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm |
60 | | -# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) |
61 | | -# fp8_res = tensor / scale_bf16_to_fp8 |
62 | | -# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) |
63 | | -# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() |
64 | | -# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) |
65 | | -# |
66 | | -# # convert to bf16 |
67 | | -# fp8_res_using_16bit = fp8_res.to(tensor.dtype) |
68 | | -# # convert to int4 |
69 | | -# from auto_round.data_type.int import quant_tensor_sym |
70 | | -# |
71 | | -# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym( |
72 | | -# fp8_res_using_16bit, |
73 | | -# bits=bits, |
74 | | -# group_size=group_size, |
75 | | -# v=v, |
76 | | -# min_scale=min_scale, |
77 | | -# max_scale=max_scale, |
78 | | -# scale_dtype=torch.bfloat16, |
79 | | -# q_scale_thresh=q_scale_thresh, |
80 | | -# ) |
81 | | -# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8 |
82 | | -# scale_bf16_to_int4 = scale_fp8_to_int4 * scale_bf16_to_fp8 |
83 | | -# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4 |
84 | | - |
85 | | - |
86 | | -# @register_dtype("fp8_gaudi3_to_int_sym_pc") |
87 | | -# def progressive_quant_fp8_int4_per_channel( |
88 | | -# tensor, |
89 | | -# bits=4, |
90 | | -# group_size=-1, |
91 | | -# v=0, |
92 | | -# min_scale=1.0, |
93 | | -# max_scale=1.0, |
94 | | -# q_scale_thresh=1e-5, |
95 | | -# weight_fp8_max_scale=1.0, |
96 | | -# **kwargs |
97 | | -# ): |
98 | | -# """The per-channel version of progressive quantization from float8 to int4.""" |
99 | | -# # tensor: [out_feats, in_feats] |
100 | | -# # scale_bf16_to_fp8: [out_feats, 1] |
101 | | -# out_feats, in_feats = tensor.shape |
102 | | -# fp8_max = torch.finfo(torch.float8_e4m3fn).max |
103 | | -# dim = 1 |
104 | | -# tensor_max = ( |
105 | | -# torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0].to(torch.float32) |
106 | | -# * weight_fp8_max_scale |
107 | | -# ) ## better train a ratio |
108 | | -# scale = tensor_max.to(torch.float32) / fp8_max |
109 | | -# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm |
110 | | -# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) |
111 | | -# fp8_res = tensor / scale_bf16_to_fp8 |
112 | | -# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) |
113 | | -# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() |
114 | | -# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) |
115 | | -# |
116 | | -# ##convert to bf16 |
117 | | -# fp8_res_using_16bit = fp8_res.to(tensor.dtype) |
118 | | -# ##convert to int4 |
119 | | -# from auto_round.data_type.int import quant_tensor_sym |
120 | | -# |
121 | | -# qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym( |
122 | | -# fp8_res_using_16bit, |
123 | | -# bits=bits, |
124 | | -# group_size=group_size, |
125 | | -# v=v, |
126 | | -# min_scale=min_scale, |
127 | | -# max_scale=max_scale, |
128 | | -# scale_dtype=torch.bfloat16, |
129 | | -# q_scale_thresh=q_scale_thresh, |
130 | | -# ) |
131 | | -# qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8 |
132 | | -# scale_fp8_to_int4_with_group = scale_fp8_to_int4 |
133 | | -# scale_fp8_to_int4_with_group_reshape_back = scale_fp8_to_int4_with_group.reshape( |
134 | | -# out_feats, -1 |
135 | | -# ) |
136 | | -# scale_bf16_to_int4 = scale_fp8_to_int4_with_group_reshape_back * scale_bf16_to_fp8 |
137 | | -# scale_bf16_to_int4_with_group = scale_bf16_to_int4.reshape(-1, 1) |
138 | | -# return ( |
139 | | -# qdq_tensor, |
140 | | -# (scale_bf16_to_int4_with_group, scale_bf16_to_fp8), |
141 | | -# zp_fp8_to_int4, |
142 | | -# ) |
143 | | - |
144 | | - |
145 | | -# @register_dtype("fp8_gaudi3_to_int_sym_v2") |
146 | | -# def progressive_quant_fp8_int4_v2( |
147 | | -# tensor, |
148 | | -# bits=4, |
149 | | -# group_size=-1, |
150 | | -# v=0, |
151 | | -# min_scale=1.0, |
152 | | -# max_scale=1.0, |
153 | | -# q_scale_thresh=1e-5, |
154 | | -# weight_fp8_max_scale=1.0, |
155 | | -# **kwargs |
156 | | -# ): |
157 | | -# """The variant of progressive quantization from float8 to int4. |
158 | | -# |
159 | | -# The variant quantizes the tensor to int4 first and then quantizes the qdq tensor to fp8. |
160 | | -# """ |
161 | | -# # convert to int4 first |
162 | | -# from auto_round.data_type.int import quant_tensor_sym |
163 | | -# |
164 | | -# qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym( |
165 | | -# tensor, |
166 | | -# bits=bits, |
167 | | -# group_size=group_size, |
168 | | -# v=v, |
169 | | -# min_scale=min_scale, |
170 | | -# max_scale=max_scale, |
171 | | -# scale_dtype=torch.bfloat16, |
172 | | -# q_scale_thresh=q_scale_thresh, |
173 | | -# ) |
174 | | -# # FIXME(Yi): some fuse error here |
175 | | -# torch._dynamo.graph_break() |
176 | | -# fp8_max = torch.finfo(torch.float8_e4m3fn).max |
177 | | -# tensor_max = ( |
178 | | -# torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale |
179 | | -# ) ## better train a ratio |
180 | | -# scale = tensor_max.to(torch.float32) / fp8_max |
181 | | -# min_scaling_factor = 1.0 / (fp8_max * 512.0) ##copy from vllm |
182 | | -# scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor) |
183 | | -# fp8_res = qdq_int4_tensor / scale_bf16_to_fp8 |
184 | | -# fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max) |
185 | | -# float8_e4m3fn_ste_gaudi = get_gaudi_fp8_ste_func() |
186 | | -# fp8_res = float8_e4m3fn_ste_gaudi(fp8_res) |
187 | | -# |
188 | | -# # convert to bf16 |
189 | | -# fp8_res_using_16bit = fp8_res.to(tensor.dtype) |
190 | | -# |
191 | | -# qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8 |
192 | | -# |
193 | | -# return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4 |
194 | | - |
195 | 20 |
196 | 21 | @register_dtype("fp8_to_int_sym") |
197 | 22 | def progressive_quant_fp8_int4( |
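For reference, the commented-out Gaudi helpers deleted above all follow the same progressive (two-stage) recipe as the retained `fp8_to_int_sym` path: scale the bf16 weight into the float8 e4m3 range with a coarse scale, requantize the fp8 result to symmetric int4 with per-group scales, and fold the two scales together so the int4 codes map straight back to the bf16 domain. The sketch below is a minimal, standalone illustration of that flow, not the library code: it uses plain PyTorch rounding in place of the Gaudi straight-through-estimator helpers and `quant_tensor_sym`, and the function name, the `1e-5` scale floor, and the assumption that the weight size is divisible by `group_size` are illustrative only.

```python
import torch


def progressive_fp8_int4_sketch(tensor: torch.Tensor, group_size: int = 128):
    """Illustrative two-stage quant-dequant: bf16 -> fp8 (per-tensor) -> int4 (per-group, symmetric)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max              # 448.0 for e4m3fn

    # Stage 1: per-tensor scale into the fp8 dynamic range (scale floor as in the deleted helpers).
    scale_bf16_to_fp8 = (tensor.abs().max().float() / fp8_max).clamp(min=1.0 / (fp8_max * 512.0))
    fp8_res = (tensor / scale_bf16_to_fp8).clamp(-fp8_max, fp8_max)
    fp8_res = fp8_res.to(torch.float8_e4m3fn).to(tensor.dtype)  # snap to the fp8 grid, keep 16-bit for math

    # Stage 2: symmetric int4 with one scale per group of `group_size` weights.
    q_max = 2 ** (4 - 1) - 1                                    # 7 for signed int4
    grouped = fp8_res.reshape(-1, group_size)                   # assumes numel % group_size == 0
    scale_fp8_to_int4 = (grouped.abs().amax(dim=1, keepdim=True) / q_max).clamp(min=1e-5)
    int4_codes = torch.round(grouped / scale_fp8_to_int4).clamp(-q_max - 1, q_max)
    qdq = (int4_codes * scale_fp8_to_int4).reshape(tensor.shape) * scale_bf16_to_fp8

    # Folding the two scales gives the single scale that dequantizes int4 codes back to bf16.
    scale_bf16_to_int4 = scale_fp8_to_int4 * scale_bf16_to_fp8
    return qdq, (scale_bf16_to_int4, scale_bf16_to_fp8), None   # zp is None for the symmetric scheme
```

The deleted per-channel variant differed only in computing `scale_bf16_to_fp8` per output row instead of per tensor, and the `_v2` variant swapped the order, quantizing to int4 first and then passing the quant-dequant result through the fp8 stage.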