|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +from torch import nn |
| 9 | + |
| 10 | +from torchao.core.config import AOBaseConfig |
| 11 | +from torchao.prototype.blockwise_fp8_training.kernels import ( |
| 12 | + blockwise_fp8_gemm_1x128_128x1, |
| 13 | + blockwise_fp8_gemm_1x128_128x128, |
| 14 | + fp8_blockwise_act_quant_lhs, |
| 15 | + fp8_blockwise_act_quant_rhs, |
| 16 | + fp8_blockwise_act_quant_transposed_lhs, |
| 17 | + fp8_blockwise_weight_quant_rhs, |
| 18 | + fp8_blockwise_weight_quant_transposed_rhs, |
| 19 | +) |
| 20 | +from torchao.quantization.transform_module import ( |
| 21 | + register_quantize_module_handler, |
| 22 | +) |
| 23 | +from torchao.utils import is_sm_at_least_90 |
| 24 | + |
| 25 | + |
| 26 | +class fp8_blockwise_mm(torch.autograd.Function): |
| 27 | + @staticmethod |
| 28 | + def forward(ctx, x, weight, block_size): |
| 29 | + assert block_size == 128, "Only support block_size=128" |
| 30 | + |
| 31 | + # Temporarily reshape x to 2D tensor |
| 32 | + x_orig_shape = x.shape |
| 33 | + x = x.reshape(-1, x_orig_shape[-1]) |
| 34 | + |
| 35 | + # Cast inputs to fp8 blockwise using (1, block_size) scaling granularity in row major format. |
| 36 | + x_fp8, x_scale = fp8_blockwise_act_quant_lhs(x, block_size) |
| 37 | + |
| 38 | + # Cast weight to fp8 blockwise using (block_size, block_size) scaling granularity, with transposed dims in column major format. |
| 39 | + weight_t_fp8, weight_t_scale = fp8_blockwise_weight_quant_transposed_rhs( |
| 40 | + weight, |
| 41 | + block_size=block_size, |
| 42 | + ) |
| 43 | + |
| 44 | + # out = input @ weight.T |
| 45 | + out = blockwise_fp8_gemm_1x128_128x128( |
| 46 | + x_fp8, |
| 47 | + 1.0 / x_scale, |
| 48 | + weight_t_fp8, |
| 49 | + 1.0 / weight_t_scale, |
| 50 | + ) |
| 51 | + out = out.reshape(*x_orig_shape[:-1], out.shape[-1]) |
| 52 | + ctx.save_for_backward(x, weight) |
| 53 | + ctx.block_size = block_size |
| 54 | + return out |
| 55 | + |
| 56 | + @staticmethod |
| 57 | + def backward(ctx, grad_output): |
| 58 | + x, weight = ctx.saved_tensors |
| 59 | + block_size = ctx.block_size |
| 60 | + |
| 61 | + # Reshape input to 2D |
| 62 | + x_orig_shape = x.shape |
| 63 | + x = x.reshape(-1, x_orig_shape[-1]) |
| 64 | + |
| 65 | + # Reshape grad_output to 2D |
| 66 | + grad_output_orig_shape = grad_output.shape |
| 67 | + grad_output = grad_output.reshape(-1, grad_output_orig_shape[-1]).contiguous() |
| 68 | + assert grad_output.shape[1] % 128 == 0, "unsupported" |
| 69 | + |
| 70 | + # Cast grad_output to fp8 blockwise 1x128 since it is the grad of the output activation. |
| 71 | + grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant_lhs( |
| 72 | + grad_output, |
| 73 | + block_size, |
| 74 | + ) |
| 75 | + |
| 76 | + # Cast weight to fp8 blockwise to 128x128 in column major format. |
| 77 | + weight_fp8, weight_scale = fp8_blockwise_weight_quant_rhs( |
| 78 | + weight, |
| 79 | + block_size=block_size, |
| 80 | + ) |
| 81 | + |
| 82 | + # grad_x = grad_output @ weight |
| 83 | + grad_x = blockwise_fp8_gemm_1x128_128x128( |
| 84 | + grad_output_fp8, |
| 85 | + 1.0 / grad_output_scale, |
| 86 | + weight_fp8, |
| 87 | + 1.0 / weight_scale, |
| 88 | + ) |
| 89 | + |
| 90 | + # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is |
| 91 | + # the grad of the output activation. |
| 92 | + # Write directly with transposed dims in row major format, as needed for dW calc. |
| 93 | + grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant_transposed_lhs( |
| 94 | + grad_output, |
| 95 | + block_size, |
| 96 | + ) |
| 97 | + |
| 98 | + # Cast x to fp8 blockwise with (block_size x 1) scaling groups, in column major format. |
| 99 | + # RHS should have groupwise scales calculated colwise, so scaling groups do not cross the |
| 100 | + # contracting (K) dim. |
| 101 | + x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size) |
| 102 | + |
| 103 | + # grad_weight = grad_output.T @ x |
| 104 | + grad_weight = blockwise_fp8_gemm_1x128_128x1( |
| 105 | + grad_output_t_fp8, |
| 106 | + 1.0 / grad_output_t_scale, |
| 107 | + x_fp8, |
| 108 | + 1.0 / x_scale, |
| 109 | + ) |
| 110 | + |
| 111 | + # Reshape grad_x to expected potentially 3D+ shape |
| 112 | + grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1]) |
| 113 | + return grad_x, grad_weight, None, None |
| 114 | + |
| 115 | + |
| 116 | +class Float8BlockwiseLinear(nn.Linear): |
| 117 | + """ |
| 118 | + Custom linear layer with support for quantized weights and optional bias. |
| 119 | +
|
| 120 | + Args: |
| 121 | + in_features (int): Number of input features. |
| 122 | + out_features (int): Number of output features. |
| 123 | + bias (bool): Whether to include a bias term. Defaults to False. |
| 124 | + block_size (int): Block size for quantization. Defaults to 128. |
| 125 | + dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn. |
| 126 | + """ |
| 127 | + |
| 128 | + supported_dtypes = [ |
| 129 | + torch.bfloat16, |
| 130 | + ] |
| 131 | + |
| 132 | + def __init__( |
| 133 | + self, |
| 134 | + *args, |
| 135 | + block_size: int = 128, |
| 136 | + dtype=torch.bfloat16, |
| 137 | + **kwargs, |
| 138 | + ): |
| 139 | + super().__init__(*args, **kwargs) |
| 140 | + |
| 141 | + assert dtype in self.supported_dtypes, ( |
| 142 | + f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}" |
| 143 | + ) |
| 144 | + assert is_sm_at_least_90(), "Only support SM90" |
| 145 | + self.block_size = block_size |
| 146 | + self.dtype = dtype |
| 147 | + |
| 148 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 149 | + """ |
| 150 | + Forward pass for the custom linear layer. |
| 151 | +
|
| 152 | + Args: |
| 153 | + x (torch.Tensor): input tensor. |
| 154 | +
|
| 155 | + Returns: |
| 156 | + torch.Tensor: Transformed tensor after linear computation. |
| 157 | + """ |
| 158 | + return fp8_blockwise_mm.apply(x, self.weight, self.block_size) |
| 159 | + |
| 160 | + @classmethod |
| 161 | + def from_float( |
| 162 | + cls, |
| 163 | + mod, |
| 164 | + ): |
| 165 | + assert mod.bias is None, "unsupported" |
| 166 | + assert mod.in_features % 128 == 0, "unsupported" |
| 167 | + assert mod.out_features % 128 == 0, "unsupported" |
| 168 | + with torch.device("meta"): |
| 169 | + new_mod = cls( |
| 170 | + mod.in_features, |
| 171 | + mod.out_features, |
| 172 | + bias=False, |
| 173 | + ) |
| 174 | + new_mod.weight = mod.weight |
| 175 | + new_mod.bias = mod.bias |
| 176 | + return new_mod |
| 177 | + |
| 178 | + |
| 179 | +class Float8BlockwiseLinearConfig(AOBaseConfig): |
| 180 | + pass |
| 181 | + |
| 182 | + |
| 183 | +@register_quantize_module_handler(Float8BlockwiseLinearConfig) |
| 184 | +def _float8_blockwise_transform(module, config): |
| 185 | + return Float8BlockwiseLinear.from_float(module) |
0 commit comments