diff --git a/src/compressed_tensors/compressors/quantized_compressors/__init__.py b/src/compressed_tensors/compressors/quantized_compressors/__init__.py
index 6189d59e0..1c8ec23b0 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/__init__.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -14,6 +14,6 @@
 # flake8: noqa
 
 from .base import *
+from .fp4_quantized import *
 from .naive_quantized import *
-from .nvfp4_quantized import *
 from .pack_quantized import *
diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
similarity index 97%
rename from src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
rename to src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index 4fc28539e..71c7ae76c 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -210,3 +210,12 @@ def unpack_fp4_from_uint8(
 
     # Reshape to final form
     return values.reshape(m, n).to(dtype=dtype)
+
+
+@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
+class MXFP4PackedCompressor(NVFP4PackedCompressor):
+    """
+    Alias for mxfp4 quantized models
+    """
+
+    pass
diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
index f11d7b42b..17d1f3b4d 100644
--- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
+++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -19,7 +19,7 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityStructure
-from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -189,11 +189,11 @@ def sparse24_bitmask_compress(
 
     bytemasks = get_24_bytemasks(tensor=tensor)
 
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # acces raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(FP8_DTYPE)
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
 
@@ -241,7 +241,7 @@ def get_24_bytemasks(tensor):
     multiple of 4.
""" original_dtype = tensor.dtype - if tensor.dtype == FP8_DTYPE: + if tensor.dtype == FP8_E4M3_DATA.dtype: tensor = tensor.view(torch.int8) original_shape = tensor.shape num_elements = tensor.numel() diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 0e08be031..51a65a07e 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -18,7 +18,7 @@ from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.quantization import FP8_E4M3_DATA from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: bytemasks = tensor != 0 row_counts = bytemasks.sum(dim=-1) row_offsets = torch.cumsum(row_counts, 0) - row_counts - if tensor.dtype == FP8_DTYPE: + if tensor.dtype == FP8_E4M3_DATA.dtype: # acces raw bytes of the tensor tensor_view = tensor.view(torch.int8) values = tensor_view[bytemasks] - values = values.view(FP8_DTYPE) + values = values.view(FP8_E4M3_DATA.dtype) else: values = tensor[bytemasks] bitmasks_packed = pack_bitmasks(bytemasks) diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 5024b1d61..73d168f37 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -34,6 +34,7 @@ class CompressionFormat(Enum): marlin_24 = "marlin-24" mixed_precision = "mixed-precision" nvfp4_pack_quantized = "nvfp4-pack-quantized" + mxfp4_pack_quantized = "mxfp4-pack-quantized" @unique diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py index 4f6610de3..5d0c11436 100644 --- a/src/compressed_tensors/config/format.py +++ b/src/compressed_tensors/config/format.py @@ -50,6 +50,8 @@ def _get_quant_compression_format( is_weight_only = weight_args is not None and input_args is None if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value: + if weight_args.group_size == 32: + return CompressionFormat.mxfp4_pack_quantized return CompressionFormat.nvfp4_pack_quantized if is_weight_only: # w4a16 and w8a16 diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 176b2f52d..f8348250c 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -468,6 +468,11 @@ def _quantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + # convert from exponent + scale_exp = scale.to(torch.int32) - 127 + scale = 2.0 ** (scale_exp.to(torch.float)) + scale = scale.to(dtype) + scaled = x / scale if zero_point is not None: @@ -501,6 +506,11 @@ def _dequantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + # convert from exponent + scale_exp = scale.to(torch.int32) - 127 + scale = 2.0 ** (scale_exp.to(torch.float)) + scale = scale.to(dtype) + dequant_value = x_q.to(scale.dtype) if zero_point is not None: @@ -510,5 +520,4 @@ def _dequantize( if dtype is not None: dequant_value = 
-
     return dequant_value
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 50757adc3..7a64a4c90 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -248,7 +248,11 @@ def initialize_qparams(
         scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9a92d0d9..b2c146b8d 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -23,9 +23,9 @@
 
 
 __all__ = [
-    "FP8_DTYPE",
     "FP8_E4M3_DATA",
     "FP4_E2M1_DATA",
+    "BFLOAT16_DATA",
     "FloatArgs",
     "QuantizationType",
     "QuantizationStrategy",
@@ -39,9 +39,9 @@ class FloatArgs:
     exponent: int
     mantissa: int
-    bits: int
-    max: float
-    min: float
+    bits: Optional[int] = None
+    max: Optional[float] = None
+    min: Optional[float] = None
     dtype: Optional[torch.dtype] = None
 
 
@@ -76,9 +76,9 @@ class FP8_E4M3_DATA(FloatArgs):
     min = torch.finfo(torch.float8_e4m3fn).min
     dtype = torch.float8_e4m3fn
 
-
-# TODO: Remove soon in favour of a more descriptive FloatArgs
-FP8_DTYPE = torch.float8_e4m3fn
+class BFLOAT16_DATA(FloatArgs):
+    exponent = 8
+    mantissa = 7
 
 
 class QuantizationType(str, Enum):
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index e6cb79293..41ea936e8 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -193,6 +193,7 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=32,
+        observer="static_minmax",
     )
 )
 
@@ -204,6 +205,7 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=32,
+        observer="static_minmax",
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -211,11 +213,11 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.GROUP,
         dynamic=True,
         symmetric=True,
+        observer=None,
         group_size=32,
     ),
 )
 
-
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index fccd677c0..a4e0d9c3b 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -20,6 +20,7 @@
 from compressed_tensors.quantization.quant_args import (
     FP4_E2M1_DATA,
     FP8_E4M3_DATA,
+    BFLOAT16_DATA,
     FloatArgs,
     QuantizationArgs,
     QuantizationStrategy,
@@ -94,26 +95,49 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
-
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+                assert max_val_pos.dtype == torch.bfloat16
+                max_val_pos = max_val_pos.view(torch.uint16).to(torch.int32)
+                # Find closest power of 2
+                BFLOAT16_VAL_TO_ADD = 1 << (
+                    BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1
+                )
+                BFLOAT16_SIGN_EXPONENT_MASK = (
+                    (1 << (BFLOAT16_DATA.exponent + 1)) - 1
+                ) << BFLOAT16_DATA.mantissa
+                # keep only the sign and exponent bits (drop the mantissa)
+                block_max_uint = torch.bitwise_and(
+                    max_val_pos + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK
+                )
+                block_max = block_max_uint.to(torch.uint16).view(torch.bfloat16)
+
+                # Convert to a biased exponent (bias 127), offset by
+                # FP4 E2M1's max exponent (2)
+                scale_exp = (
+                    127 + torch.floor(torch.log2(block_max)).to(torch.int32) - 2
+                )
+                scale_exp = torch.clamp(
+                    scale_exp,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scale_exp.to(torch.uint8)
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
-        # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
         if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
             scales = torch.where(
                 scales == 0,
                 torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
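For reviewers, here is a minimal standalone sketch (not part of the patch) of the MXFP4 scale round trip this change introduces: `calculate_qparams` stores each 32-element group's scale as a uint8 biased exponent, and `_quantize`/`_dequantize` expand it back to a power-of-two float via `2 ** (exp - 127)`. The helper names `encode_mxfp4_scale` and `decode_mxfp4_scale` are illustrative only, and the encoder uses a plain `log2` instead of the bfloat16 bit-rounding trick used in `calculate_qparams`.

```python
import torch

FP4_E2M1_MAX_EXP = 2  # largest FP4 E2M1 value is 6.0 ~= 2**2


def encode_mxfp4_scale(block_max: torch.Tensor) -> torch.Tensor:
    # store floor(log2(block_max)) - 2 as a uint8 exponent with bias 127,
    # mirroring the uint8 scales initialized for group_size == 32
    scale_exp = (
        127
        + torch.floor(torch.log2(block_max.to(torch.float32))).to(torch.int32)
        - FP4_E2M1_MAX_EXP
    )
    return torch.clamp(scale_exp, min=0, max=255).to(torch.uint8)


def decode_mxfp4_scale(scale_u8: torch.Tensor) -> torch.Tensor:
    # inverse of the encoding above, analogous to the new conversion in
    # _quantize/_dequantize: 2 ** (stored_exponent - 127)
    return 2.0 ** (scale_u8.to(torch.int32) - 127).to(torch.float32)


if __name__ == "__main__":
    weights = torch.randn(4, 32, dtype=torch.bfloat16)  # one MXFP4 group per row
    block_max = weights.abs().amax(dim=-1).to(torch.float32)
    scale_u8 = encode_mxfp4_scale(block_max)  # what would be stored on disk
    scale = decode_mxfp4_scale(scale_u8)  # power-of-two dequant scale
    scaled = weights.to(torch.float32) / scale.unsqueeze(-1)
    # element values then saturate to the FP4 E2M1 range during quantization
    q = torch.clamp(scaled, min=-6.0, max=6.0)
    print(scale_u8.tolist(), q.abs().max().item())
```

The sketch only illustrates the scale encoding and decoding; the actual FP4 element packing is handled by the existing `NVFP4PackedCompressor` path that `MXFP4PackedCompressor` reuses.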