@@ -14,6 +14,6 @@
 # flake8: noqa
 
 from .base import *
+from .fp4_quantized import *
 from .naive_quantized import *
-from .nvfp4_quantized import *
 from .pack_quantized import *
@@ -210,3 +210,12 @@ def unpack_fp4_from_uint8(

     # Reshape to final form
     return values.reshape(m, n).to(dtype=dtype)
+
+
+@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
+class MXFP4PackedCompressor(NVFP4PackedCompressor):
+    """
+    Alias for mxfp4 quantized models
+    """
+
+    pass
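
Note: the new class works purely through the compressor registry, so MXFP4 checkpoints reuse the NVFP4 packing code unchanged. A minimal toy sketch of that alias pattern (illustrative names only, not the library's actual registry):

```python
# Illustrative only: a toy registry mirroring how BaseCompressor.register is used above.
_COMPRESSORS = {}

def register(name: str):
    def decorator(cls):
        _COMPRESSORS[name] = cls
        return cls
    return decorator

@register("nvfp4-pack-quantized")
class ToyNVFP4Packed:
    def compress(self, weight):
        return ("packed", weight)

@register("mxfp4-pack-quantized")
class ToyMXFP4Packed(ToyNVFP4Packed):
    """Alias: identical packing, looked up under the mxfp4 format name."""
    pass

assert isinstance(_COMPRESSORS["mxfp4-pack-quantized"](), ToyNVFP4Packed)
```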
@@ -19,7 +19,7 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityStructure
-from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -189,11 +189,11 @@ def sparse24_bitmask_compress(

     bytemasks = get_24_bytemasks(tensor=tensor)
 
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # access raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(FP8_DTYPE)
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
 
@@ -241,7 +241,7 @@ def get_24_bytemasks(tensor):
     multiple of 4.
     """
     original_dtype = tensor.dtype
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         tensor = tensor.view(torch.int8)
     original_shape = tensor.shape
     num_elements = tensor.numel()
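
Note: the FP8 branch in the hunks above sidesteps boolean-mask indexing on float8 tensors (not supported on all torch builds) by reinterpreting the bytes as int8 and viewing the selected values back as FP8. A small self-contained sketch of that trick, assuming a torch build with `float8_e4m3fn`:

```python
import torch

# Select masked elements from an FP8 tensor by viewing the same bytes as int8,
# applying the boolean mask, then reinterpreting the result as FP8 again.
fp8 = torch.float8_e4m3fn
dense = torch.randn(4, 8).to(fp8)
bytemask = torch.rand(4, 8) > 0.5  # stand-in for the 2:4 bytemask

values = dense.view(torch.int8)[bytemask].view(fp8)
assert values.dtype == fp8
assert values.numel() == int(bytemask.sum())
```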
@@ -18,7 +18,7 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # access raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(FP8_DTYPE)
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
1 change: 1 addition & 0 deletions src/compressed_tensors/config/base.py
@@ -34,6 +34,7 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
     mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
+    mxfp4_pack_quantized = "mxfp4-pack-quantized"
 
 
 @unique
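
Note: with this change applied, the new tag should round-trip through the enum like the existing formats (a quick usage check, assuming the import path shown elsewhere in this diff):

```python
from compressed_tensors.config import CompressionFormat

assert CompressionFormat("mxfp4-pack-quantized") is CompressionFormat.mxfp4_pack_quantized
assert CompressionFormat.mxfp4_pack_quantized.value == "mxfp4-pack-quantized"
```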
2 changes: 2 additions & 0 deletions src/compressed_tensors/config/format.py
@@ -50,6 +50,8 @@ def _get_quant_compression_format(
     is_weight_only = weight_args is not None and input_args is None
 
     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        if weight_args.group_size == 32:
+            return CompressionFormat.mxfp4_pack_quantized
         return CompressionFormat.nvfp4_pack_quantized
 
     if is_weight_only:  # w4a16 and w8a16
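
Note: in plain terms, 4-bit float weights now map to the MX packing when the group size is 32 and keep the NVFP4 packing otherwise (NVFP4 uses groups of 16). A toy restatement of just that rule (`_fp4_format` is illustrative, not a library helper):

```python
def _fp4_format(group_size: int) -> str:
    # Mirrors the branch added above: group_size 32 -> MXFP4, otherwise NVFP4.
    return "mxfp4-pack-quantized" if group_size == 32 else "nvfp4-pack-quantized"

assert _fp4_format(32) == "mxfp4-pack-quantized"
assert _fp4_format(16) == "nvfp4-pack-quantized"
```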
11 changes: 10 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
@@ -468,6 +468,11 @@ def _quantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    # convert from exponent
+    scale_exp = scale.to(torch.int32) - 127
+    scale = 2.0 ** (scale_exp.to(torch.float))
+    scale = scale.to(dtype)
+
     scaled = x / scale
 
     if zero_point is not None:
@@ -501,6 +506,11 @@ def _dequantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    # convert from exponent
+    scale_exp = scale.to(torch.int32) - 127
+    scale = 2.0 ** (scale_exp.to(torch.float))
+    scale = scale.to(dtype)
+
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
@@ -510,5 +520,4 @@

     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
-
     return dequant_value
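
Note: the added lines treat the stored scale as an E8M0-style biased exponent (a uint8 holding exponent + 127) and decode it to a power-of-two multiplier before quantizing or dequantizing. A hedged sketch of just that decode step:

```python
import torch

# Decode E8M0-style scales: a stored byte e represents 2 ** (e - 127),
# mirroring the `scale.to(torch.int32) - 127` lines above.
stored = torch.tensor([127, 129, 125, 130], dtype=torch.uint8)
decoded = 2.0 ** (stored.to(torch.int32) - 127).to(torch.float)
print(decoded)  # tensor([1.0000, 4.0000, 0.2500, 8.0000])
```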
6 changes: 5 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -248,7 +248,11 @@ def initialize_qparams(
         scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future, as if the dtype is not one of these,
         # there is likely a bug
14 changes: 7 additions & 7 deletions src/compressed_tensors/quantization/quant_args.py
@@ -23,9 +23,9 @@


 __all__ = [
-    "FP8_DTYPE",
     "FP8_E4M3_DATA",
     "FP4_E2M1_DATA",
+    "BFLOAT16_DATA",
     "FloatArgs",
     "QuantizationType",
     "QuantizationStrategy",
@@ -39,9 +39,9 @@
 class FloatArgs:
     exponent: int
     mantissa: int
-    bits: int
-    max: float
-    min: float
+    bits: Optional[int] = None
+    max: Optional[float] = None
+    min: Optional[float] = None
     dtype: Optional[torch.dtype] = None
 
 
@@ -76,9 +76,9 @@ class FP8_E4M3_DATA(FloatArgs):
     min = torch.finfo(torch.float8_e4m3fn).min
     dtype = torch.float8_e4m3fn
 
 
-# TODO: Remove soon in favour of a more descriptive FloatArgs
-FP8_DTYPE = torch.float8_e4m3fn
+class BFLOAT16_DATA(FloatArgs):
+    exponent = 8
+    mantissa = 7
 
 
 class QuantizationType(str, Enum):
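
Note: BFLOAT16_DATA only records the bit layout (8 exponent and 7 mantissa bits) used by the power-of-two rounding later in this diff. A small sanity check, assuming the standard 1 sign / 8 exponent / 7 mantissa bfloat16 layout and a torch build with uint16 views:

```python
import torch

x = torch.tensor([1.5], dtype=torch.bfloat16)
bits = x.view(torch.uint16).item()  # raw 16-bit pattern
sign, exponent, mantissa = bits >> 15, (bits >> 7) & 0xFF, bits & 0x7F
assert (sign, exponent, mantissa) == (0, 127, 0x40)  # 1.5 = 1.1b * 2**0, exponent bias 127
```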
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
@@ -193,6 +193,7 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=32,
+        observer="static_minmax",
     )
 )
 
@@ -204,18 +205,19 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=32,
+        observer="static_minmax",
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
         strategy=QuantizationStrategy.GROUP,
         dynamic=True,
         symmetric=True,
-        observer=None,
         group_size=32,
     ),
 )
 
 
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
44 changes: 34 additions & 10 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -20,6 +20,7 @@
 from compressed_tensors.quantization.quant_args import (
     FP4_E2M1_DATA,
     FP8_E4M3_DATA,
+    BFLOAT16_DATA,
     FloatArgs,
     QuantizationArgs,
     QuantizationStrategy,
@@ -94,26 +95,49 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
 
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+                assert max_val_pos.dtype == torch.bfloat16
+                max_val_pos = max_val_pos.view(torch.uint16).to(torch.int32)
+                # Find closest power of 2
+                BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)
+                BFLOAT16_SIGN_EXPONENT_MASK = ((1 << (BFLOAT16_DATA.exponent + 1)) - 1) << BFLOAT16_DATA.mantissa
+                # mask to keep only the sign and exponent bits (drop the mantissa)
+                block_max_uint = torch.bitwise_and(max_val_pos + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK)
+                block_max = block_max_uint.to(torch.uint16).view(torch.bfloat16)
+
+                # Convert to exponent
+                scale_exp = (
+                    127 + torch.floor(torch.log2(block_max)).to(torch.int32) - 2
+                )
+                scale_exp = torch.clamp(
+                    scale_exp,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scale_exp.to(torch.uint8)
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
         if scales.dtype == FP8_E4M3_DATA.dtype:
             # torch.clamp not supported for FP8
             # use the next largest fp8 value from 0
             scales = torch.where(
                 scales == 0,
                 torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
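
Note: to summarize the new MXFP4 branch above: the bfloat16 block max is nudged by a rounding constant, its mantissa bits are masked away so only a power of two remains, and the scale is stored as a uint8 biased exponent; the extra -2 appears to account for FP4 E2M1's maximum magnitude of 6 (just above 2**2), so that block values divided by the decoded scale fit the representable range. A self-contained sketch of that computation, assuming the BFLOAT16_DATA (exponent=8, mantissa=7) and FP4_E2M1_DATA (mantissa=1) constants above and a torch build with uint16 support:

```python
import torch

BF16_EXPONENT, BF16_MANTISSA = 8, 7  # bfloat16 bit layout
FP4_MANTISSA = 1                     # e2m1


def mxfp4_scale_exponent(block_max: torch.Tensor) -> torch.Tensor:
    """Snap a bfloat16 block max to a power of two and return the E8M0 scale byte."""
    assert block_max.dtype == torch.bfloat16
    bits = block_max.view(torch.uint16).to(torch.int32)
    round_add = 1 << (BF16_MANTISSA - FP4_MANTISSA - 1)                # 1 << 5
    sign_exp_mask = ((1 << (BF16_EXPONENT + 1)) - 1) << BF16_MANTISSA  # keep sign + exponent
    pow2 = torch.bitwise_and(bits + round_add, sign_exp_mask)          # drop mantissa bits
    pow2_max = pow2.to(torch.uint16).view(torch.bfloat16)
    # Biased exponent byte: bias 127, shifted by -2 for FP4's dynamic range.
    exp = 127 + torch.floor(torch.log2(pow2_max)).to(torch.int32) - 2
    return torch.clamp(exp, 0, 255).to(torch.uint8)


block_max = torch.tensor([1.0, 3.0, 48.0], dtype=torch.bfloat16)
print(mxfp4_scale_exponent(block_max))  # tensor([125, 126, 130], dtype=torch.uint8)
# Decoded scales 2**(e - 127): 0.25, 0.5, 8.0 — each block max / scale stays within FP4's +/-6.
```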