|
12 | 12 | matmul_mxf4_bf16_tn, |
13 | 13 | matmul_nvf4_bf16_tn, |
14 | 14 | matmul_mxf8_bf16_tn, |
15 | | - matmul_mxf8_bf16_tt, |
16 | | - matmul_mxf8_bf16_nt, |
17 | 15 | matmul_mxf8_bf16_nn, |
18 | 16 | # Backward quantization |
19 | 17 | backward_t_bf16, |
@@ -192,50 +190,6 @@ def _(x, w, xs, ws, alpha): |
192 | 190 | return x.new_empty(x.shape[0], w.shape[0], dtype=torch.bfloat16) |
193 | 191 |
|
194 | 192 |
|
195 | | -@torch.library.custom_op("fp_quant::matmul_mxf8_bf16_nt_op", mutates_args=()) |
196 | | -def matmul_mxf8_bf16_nt_op( |
197 | | - x: torch.Tensor, |
198 | | - w: torch.Tensor, |
199 | | - xs: torch.Tensor, |
200 | | - ws: torch.Tensor, |
201 | | - alpha: torch.Tensor, |
202 | | -) -> torch.Tensor: |
203 | | - return matmul_mxf8_bf16_nt( |
204 | | - x, |
205 | | - w, |
206 | | - to_blocked_qutlass(xs, use_triton_kernel=True), |
207 | | - to_blocked_qutlass(ws, use_triton_kernel=True).view(torch.float8_e8m0fnu), |
208 | | - alpha, |
209 | | - ) |
210 | | - |
211 | | - |
212 | | -@matmul_mxf8_bf16_nt_op.register_fake |
213 | | -def _(x, w, xs, ws, alpha): |
214 | | - return x.new_empty(x.shape[1], w.shape[1], dtype=torch.bfloat16) |
215 | | - |
216 | | - |
217 | | -@torch.library.custom_op("fp_quant::matmul_mxf8_bf16_tt_op", mutates_args=()) |
218 | | -def matmul_mxf8_bf16_tt_op( |
219 | | - x: torch.Tensor, |
220 | | - w: torch.Tensor, |
221 | | - xs: torch.Tensor, |
222 | | - ws: torch.Tensor, |
223 | | - alpha: torch.Tensor, |
224 | | -) -> torch.Tensor: |
225 | | - return matmul_mxf8_bf16_tt( |
226 | | - x, |
227 | | - w, |
228 | | - to_blocked_qutlass(xs, use_triton_kernel=True), |
229 | | - to_blocked_qutlass(ws, use_triton_kernel=True).view(torch.float8_e8m0fnu), |
230 | | - alpha, |
231 | | - ) |
232 | | - |
233 | | - |
234 | | -@matmul_mxf8_bf16_tt_op.register_fake |
235 | | -def _(x, w, xs, ws, alpha): |
236 | | - return x.new_empty(x.shape[0], w.shape[1], dtype=torch.bfloat16) |
237 | | - |
238 | | - |
239 | 193 | @torch.library.custom_op("fp_quant::matmul_mxf8_bf16_nn_op", mutates_args=()) |
240 | 194 | def matmul_mxf8_bf16_nn_op( |
241 | 195 | x: torch.Tensor, |
|
0 commit comments