From d1b315f60efc433fda0afb391562def58dcd39f1 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 28 Aug 2025 17:36:55 +0000 Subject: [PATCH 1/6] update --- src/compressed_tensors/quantization/quant_scheme.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index e6cb7929..faca2afa 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -215,7 +215,6 @@ def is_preset_scheme(name: str) -> bool: ), ) - # 8 bit integer weights and 8 bit activations quantization INT8_W8A8 = dict( weights=QuantizationArgs( From 868655f42ed47d421492fa3645d84b116c0f9d73 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 28 Aug 2025 21:42:25 +0000 Subject: [PATCH 2/6] add mxfp4 calibration support --- .../quantization/lifecycle/forward.py | 4 +- .../quantization/lifecycle/initialize.py | 6 +- .../quantization/utils/helpers.py | 61 ++++++++++++++----- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 176b2f52..bf6214c6 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -468,6 +468,7 @@ def _quantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max scaled = x / scale if zero_point is not None: @@ -501,6 +502,8 @@ def _dequantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max + dequant_value = x_q.to(scale.dtype) if zero_point is not None: @@ -510,5 +513,4 @@ def _dequantize( if dtype is not None: dequant_value = dequant_value.to(dtype) - return dequant_value diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 50757adc..7a64a4c9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -248,7 +248,11 @@ def initialize_qparams( scale_dtype = observed_dtype if is_fp4(quantization_args=quantization_args): - scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype + if quantization_args.group_size == 16: + scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype + else: + # group_size 32 + scale_dtype = zp_dtype = torch.uint8 else: # TODO: consider erroring out in the future as if the dtype if not one of these, # there is likely bug diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fccd677c..f9c8eacb 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -64,6 +64,18 @@ def is_fp4(quantization_args: QuantizationArgs): and quantization_args.type == QuantizationType.FLOAT ) +def get_power_of_two(x): + powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device) + + # Expand and compute distances + diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs() + + # Find nearest index + nearest_idx = diff.argmin(dim=-1) + + return powers[nearest_idx] + + def calculate_qparams( min_vals: Tensor, @@ -94,33 +106,50 @@ def calculate_qparams( bit_range = bit_max - bit_min if 
is_fp4(quantization_args=quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype + if quantization_args.group_size == 16: + zp_dtype = FP8_E4M3_DATA.dtype + else: + # group_size 32 + zp_dtype = torch.uint8 else: zp_dtype = quantization_args.pytorch_dtype() if quantization_args.symmetric: max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) - if is_fp4(quantization_args=quantization_args) and global_scale is not None: - # Conditionally scale the generated local scale by a global_scale - scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) - scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min) - scales = scales.to(FP8_E4M3_DATA.dtype) + if is_fp4(quantization_args=quantization_args): + if global_scale is not None: + # Conditionally scale the generated local scale by a global_scale + scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) + scales = torch.clamp( + scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min + ) + scales = scales.to(FP8_E4M3_DATA.dtype) + else: + + scales = torch.iinfo(torch.uint8).max * (max_val_pos) # / FP4_E2M1_DATA.max) + scales = torch.clamp( + scales, + max=torch.iinfo(torch.uint8).max, + min=torch.iinfo(torch.uint8).min, + ) + scales = scales.to(torch.uint8) + scales = get_power_of_two(scales) else: scales = max_val_pos / (float(bit_range) / 2) # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped - if scales.dtype == FP8_E4M3_DATA.dtype: - # torch.clamp not supported for FP8 - # use the next largest fp8 value from 0 - scales = torch.where( - scales == 0, - torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device), - scales, - ) - else: - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + # if scales.dtype == FP8_E4M3_DATA.dtype: + # torch.clamp not supported for FP8 + # use the next largest fp8 value from 0 + # scales = torch.where( + # scales == 0, + # torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device), + # scales, + # ) + # else: + # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: From f72fdf545eae490328322243ef87ac6765e02c02 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 16 Oct 2025 15:58:24 -0400 Subject: [PATCH 3/6] update --- .../quantized_compressors/__init__.py | 2 +- .../{nvfp4_quantized.py => fp4_quantized.py} | 9 ++++ src/compressed_tensors/config/base.py | 1 + src/compressed_tensors/config/format.py | 2 + .../quantization/lifecycle/forward.py | 9 +++- .../quantization/quant_scheme.py | 3 ++ .../quantization/utils/helpers.py | 35 ++++++------- .../quantization/utils/mxfp4.py | 49 +++++++++++++++++++ 8 files changed, 88 insertions(+), 22 deletions(-) rename src/compressed_tensors/compressors/quantized_compressors/{nvfp4_quantized.py => fp4_quantized.py} (97%) create mode 100644 src/compressed_tensors/quantization/utils/mxfp4.py diff --git a/src/compressed_tensors/compressors/quantized_compressors/__init__.py b/src/compressed_tensors/compressors/quantized_compressors/__init__.py index 6189d59e..1c8ec23b 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/__init__.py +++ b/src/compressed_tensors/compressors/quantized_compressors/__init__.py @@ -14,6 +14,6 @@ # flake8: noqa from .base import * +from .fp4_quantized import * from .naive_quantized import * -from .nvfp4_quantized import * from .pack_quantized import * diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py 
b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py similarity index 97% rename from src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py rename to src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index 4fc28539..71c7ae76 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -210,3 +210,12 @@ def unpack_fp4_from_uint8( # Reshape to final form return values.reshape(m, n).to(dtype=dtype) + + +@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value) +class MXFP4PackedCompressor(NVFP4PackedCompressor): + """ + Alias for mxfp4 quantized models + """ + + pass diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index 5024b1d6..73d168f3 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -34,6 +34,7 @@ class CompressionFormat(Enum): marlin_24 = "marlin-24" mixed_precision = "mixed-precision" nvfp4_pack_quantized = "nvfp4-pack-quantized" + mxfp4_pack_quantized = "mxfp4-pack-quantized" @unique diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py index 4f6610de..5d0c1143 100644 --- a/src/compressed_tensors/config/format.py +++ b/src/compressed_tensors/config/format.py @@ -50,6 +50,8 @@ def _get_quant_compression_format( is_weight_only = weight_args is not None and input_args is None if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value: + if weight_args.group_size == 32: + return CompressionFormat.mxfp4_pack_quantized return CompressionFormat.nvfp4_pack_quantized if is_weight_only: # w4a16 and w8a16 diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index bf6214c6..68a2e754 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -468,7 +468,10 @@ def _quantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale - scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max + scale_exp = scale.to(torch.int32) - 127 + scale = 2.0 ** (scale_exp.to(torch.float)) + scale = scale.to(dtype) + scaled = x / scale if zero_point is not None: @@ -502,7 +505,9 @@ def _dequantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale - scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max + scale_exp = scale.to(torch.int32) - 127 + scale = 2.0 ** (scale_exp.to(torch.float)) + scale = scale.to(dtype) dequant_value = x_q.to(scale.dtype) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index faca2afa..41ea936e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -193,6 +193,7 @@ def is_preset_scheme(name: str) -> bool: symmetric=True, dynamic=False, group_size=32, + observer="static_minmax", ) ) @@ -204,6 +205,7 @@ def is_preset_scheme(name: str) -> bool: symmetric=True, dynamic=False, group_size=32, + observer="static_minmax", ), input_activations=QuantizationArgs( num_bits=4, @@ -211,6 +213,7 @@ def is_preset_scheme(name: str) -> bool: strategy=QuantizationStrategy.GROUP, dynamic=True, symmetric=True, + observer=None, group_size=32, ), ) diff --git a/src/compressed_tensors/quantization/utils/helpers.py 
b/src/compressed_tensors/quantization/utils/helpers.py index f9c8eacb..ef048410 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -64,18 +64,6 @@ def is_fp4(quantization_args: QuantizationArgs): and quantization_args.type == QuantizationType.FLOAT ) -def get_power_of_two(x): - powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device) - - # Expand and compute distances - diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs() - - # Find nearest index - nearest_idx = diff.argmin(dim=-1) - - return powers[nearest_idx] - - def calculate_qparams( min_vals: Tensor, @@ -126,16 +114,25 @@ def calculate_qparams( ) scales = scales.to(FP8_E4M3_DATA.dtype) else: - - scales = torch.iinfo(torch.uint8).max * (max_val_pos) # / FP4_E2M1_DATA.max) - scales = torch.clamp( - scales, + + """ + block_max = max_val_pos.view(torch.uint16).to(torch.int32) + BFLOAT16_VAL_TO_ADD = (1 <<(7 - 1 - 1)) + BFLOAT16_SIGN_EXPONENT_MASK = (((1 << (8 + 1)) - 1) << 7) + block_max_uint = torch.bitwise_and(block_max + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK) + block_max_uint = block_max_uint.to(torch.uint16) + block_max = block_max_uint.view(torch.bfloat16) + """ + scale_exp = ( + 127 + torch.floor(torch.log2(max_val_pos)).to(torch.int32) - 2 + ) + # clamp and convert to uint8 + scale_exp = torch.clamp( + scale_exp, max=torch.iinfo(torch.uint8).max, min=torch.iinfo(torch.uint8).min, ) - scales = scales.to(torch.uint8) - scales = get_power_of_two(scales) - + scales = scale_exp.to(torch.uint8) else: scales = max_val_pos / (float(bit_range) / 2) diff --git a/src/compressed_tensors/quantization/utils/mxfp4.py b/src/compressed_tensors/quantization/utils/mxfp4.py new file mode 100644 index 00000000..d13d2383 --- /dev/null +++ b/src/compressed_tensors/quantization/utils/mxfp4.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +hidden_size = 64 * 32 +FLOAT8_E8M0_MAX_EXP = 127 +BFLOAT16_EXP_BITS = 8 +BFLOAT16_MANTISSA_BITS = 7 +FLOAT4_MANTISSA_BITS = 1 + +BFLOAT16_VAL_TO_ADD = 1 << (7 - 1 - 1) +BFLOAT16_SIGN_EXPONENT_MASK = ((1 << (8 + 1)) - 1) << 7 + + +x = torch.rand(1, hidden_size, dtype=torch.bfloat16, device="cuda") +x = x.reshape(*x.shape[:-1], -1, 32) +block_max = torch.max(torch.abs(x), dim=-1).values +breakpoint() +# --- 3. Bit-level normalization (same as before) +block_max_bits = block_max.view(torch.uint16).to(torch.int32) +block_max_bits = torch.bitwise_and( + block_max_bits + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK +) +block_max_bits = block_max_bits.to(torch.uint16) +block_max = block_max_bits.view(torch.bfloat16) + +# --- 4. Compute exponent scale (power-of-two) +scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(torch.int32) - 2 +scale_exp = torch.clamp(scale_exp, 0, 255) # uint8 range + +# --- 5. 
Convert to uint8 and to actual float scale +scales_uint8 = scale_exp.to(torch.uint8) +print(x.shape) +print(block_max.shape) + +breakpoint() From 868501f0e14b50e5c71ad94e11abc8410fee0821 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 16 Oct 2025 16:00:41 -0400 Subject: [PATCH 4/6] remove --- .../quantization/utils/mxfp4.py | 49 ------------------- 1 file changed, 49 deletions(-) delete mode 100644 src/compressed_tensors/quantization/utils/mxfp4.py diff --git a/src/compressed_tensors/quantization/utils/mxfp4.py b/src/compressed_tensors/quantization/utils/mxfp4.py deleted file mode 100644 index d13d2383..00000000 --- a/src/compressed_tensors/quantization/utils/mxfp4.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -hidden_size = 64 * 32 -FLOAT8_E8M0_MAX_EXP = 127 -BFLOAT16_EXP_BITS = 8 -BFLOAT16_MANTISSA_BITS = 7 -FLOAT4_MANTISSA_BITS = 1 - -BFLOAT16_VAL_TO_ADD = 1 << (7 - 1 - 1) -BFLOAT16_SIGN_EXPONENT_MASK = ((1 << (8 + 1)) - 1) << 7 - - -x = torch.rand(1, hidden_size, dtype=torch.bfloat16, device="cuda") -x = x.reshape(*x.shape[:-1], -1, 32) -block_max = torch.max(torch.abs(x), dim=-1).values -breakpoint() -# --- 3. Bit-level normalization (same as before) -block_max_bits = block_max.view(torch.uint16).to(torch.int32) -block_max_bits = torch.bitwise_and( - block_max_bits + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK -) -block_max_bits = block_max_bits.to(torch.uint16) -block_max = block_max_bits.view(torch.bfloat16) - -# --- 4. Compute exponent scale (power-of-two) -scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(torch.int32) - 2 -scale_exp = torch.clamp(scale_exp, 0, 255) # uint8 range - -# --- 5. 
Convert to uint8 and to actual float scale -scales_uint8 = scale_exp.to(torch.uint8) -print(x.shape) -print(block_max.shape) - -breakpoint() From b8c04dc9d816c64c96f8b632c37da6e9dc603ae8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 21 Oct 2025 17:13:40 -0400 Subject: [PATCH 5/6] update --- .../quantization/lifecycle/forward.py | 2 + .../quantization/quant_args.py | 14 +++---- .../quantization/utils/helpers.py | 41 +++++++++---------- 3 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 68a2e754..f8348250 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -468,6 +468,7 @@ def _quantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + # convert from exponent scale_exp = scale.to(torch.int32) - 127 scale = 2.0 ** (scale_exp.to(torch.float)) scale = scale.to(dtype) @@ -505,6 +506,7 @@ def _dequantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + # convert from exponent scale_exp = scale.to(torch.int32) - 127 scale = 2.0 ** (scale_exp.to(torch.float)) scale = scale.to(dtype) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9a92d0d..323f2b6d 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -23,9 +23,9 @@ __all__ = [ - "FP8_DTYPE", "FP8_E4M3_DATA", "FP4_E2M1_DATA", + "BFLOAT16_DATA" "FloatArgs", "QuantizationType", "QuantizationStrategy", @@ -39,9 +39,9 @@ class FloatArgs: exponent: int mantissa: int - bits: int - max: float - min: float + bits: Optional[int] = None + max: Optional[float] = None + min: Optional[float] = None dtype: Optional[torch.dtype] = None @@ -76,9 +76,9 @@ class FP8_E4M3_DATA(FloatArgs): min = torch.finfo(torch.float8_e4m3fn).min dtype = torch.float8_e4m3fn - -# TODO: Remove soon in favour of a more descriptive FloatArgs -FP8_DTYPE = torch.float8_e4m3fn +class BFLOAT16_DATA(FloatArgs): + exponent = 8 + mantissa = 7 class QuantizationType(str, Enum): diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index ef048410..3300c136 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -20,6 +20,7 @@ from compressed_tensors.quantization.quant_args import ( FP4_E2M1_DATA, FP8_E4M3_DATA, + BFLOAT16_DATA, FloatArgs, QuantizationArgs, QuantizationStrategy, @@ -114,19 +115,18 @@ def calculate_qparams( ) scales = scales.to(FP8_E4M3_DATA.dtype) else: - - """ - block_max = max_val_pos.view(torch.uint16).to(torch.int32) - BFLOAT16_VAL_TO_ADD = (1 <<(7 - 1 - 1)) - BFLOAT16_SIGN_EXPONENT_MASK = (((1 << (8 + 1)) - 1) << 7) - block_max_uint = torch.bitwise_and(block_max + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK) - block_max_uint = block_max_uint.to(torch.uint16) - block_max = block_max_uint.view(torch.bfloat16) - """ + max_val_pos = max_val_pos.view(torch.uint16).to(torch.int32) + # Find closest power of 2 + BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1) + BFLOAT16_SIGN_EXPONENT_MASK = ((1 << (BFLOAT16_DATA.exponent + 1)) - 1) << BFLOAT16_DATA.mantissa + # mask to only keep mantissa + block_max_uint = torch.bitwise_and(max_val_pos + BFLOAT16_VAL_TO_ADD, 
BFLOAT16_SIGN_EXPONENT_MASK) + block_max = block_max_uint.to(torch.uint16).view(torch.bfloat16) + + # Convert to to exponent scale_exp = ( - 127 + torch.floor(torch.log2(max_val_pos)).to(torch.int32) - 2 + 127 + torch.floor(torch.log2(block_max)).to(torch.int32) - 2 ) - # clamp and convert to uint8 scale_exp = torch.clamp( scale_exp, max=torch.iinfo(torch.uint8).max, @@ -136,17 +136,14 @@ def calculate_qparams( else: scales = max_val_pos / (float(bit_range) / 2) - # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped - # if scales.dtype == FP8_E4M3_DATA.dtype: - # torch.clamp not supported for FP8 - # use the next largest fp8 value from 0 - # scales = torch.where( - # scales == 0, - # torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device), - # scales, - # ) - # else: - # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + if scales.dtype == FP8_E4M3_DATA.dtype: + scales = torch.where( + scales == 0, + torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device), + scales, + ) + else: + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: From c7c973068088438c331f13f63df95834353f6adf Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 22 Oct 2025 10:20:23 -0400 Subject: [PATCH 6/6] update --- .../compressors/sparse_compressors/sparse_24_bitmask.py | 8 ++++---- .../compressors/sparse_compressors/sparse_bitmask.py | 6 +++--- src/compressed_tensors/quantization/quant_args.py | 2 +- src/compressed_tensors/quantization/utils/helpers.py | 1 + 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index f11d7b42..17d1f3b4 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -19,7 +19,7 @@ from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat, SparsityStructure -from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.quantization import FP8_E4M3_DATA from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -189,11 +189,11 @@ def sparse24_bitmask_compress( bytemasks = get_24_bytemasks(tensor=tensor) - if tensor.dtype == FP8_DTYPE: + if tensor.dtype == FP8_E4M3_DATA.dtype: # acces raw bytes of the tensor tensor_view = tensor.view(torch.int8) values = tensor_view[bytemasks] - values = values.view(FP8_DTYPE) + values = values.view(FP8_E4M3_DATA.dtype) else: values = tensor[bytemasks] @@ -241,7 +241,7 @@ def get_24_bytemasks(tensor): multiple of 4. 
""" original_dtype = tensor.dtype - if tensor.dtype == FP8_DTYPE: + if tensor.dtype == FP8_E4M3_DATA.dtype: tensor = tensor.view(torch.int8) original_shape = tensor.shape num_elements = tensor.numel() diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py index 0e08be03..51a65a07 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py @@ -18,7 +18,7 @@ from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import FP8_DTYPE +from compressed_tensors.quantization import FP8_E4M3_DATA from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks from torch import Tensor @@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]: bytemasks = tensor != 0 row_counts = bytemasks.sum(dim=-1) row_offsets = torch.cumsum(row_counts, 0) - row_counts - if tensor.dtype == FP8_DTYPE: + if tensor.dtype == FP8_E4M3_DATA.dtype: # acces raw bytes of the tensor tensor_view = tensor.view(torch.int8) values = tensor_view[bytemasks] - values = values.view(FP8_DTYPE) + values = values.view(FP8_E4M3_DATA.dtype) else: values = tensor[bytemasks] bitmasks_packed = pack_bitmasks(bytemasks) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 323f2b6d..b2c146b8 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -25,7 +25,7 @@ __all__ = [ "FP8_E4M3_DATA", "FP4_E2M1_DATA", - "BFLOAT16_DATA" + "BFLOAT16_DATA", "FloatArgs", "QuantizationType", "QuantizationStrategy", diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 3300c136..a4e0d9c3 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -115,6 +115,7 @@ def calculate_qparams( ) scales = scales.to(FP8_E4M3_DATA.dtype) else: + assert max_val_pos.dtype == torch.bfloat16 max_val_pos = max_val_pos.view(torch.uint16).to(torch.int32) # Find closest power of 2 BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)