From be8969019afafa6e91cdac426ee724d4918d9227 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 28 Jul 2025 21:52:26 +0000
Subject: [PATCH 1/5] add compression param; update qdq for batch greater than 1

---
 .../quantized_compressors/nvfp4_quantized.py | 19 +++++++
 .../quantization/lifecycle/forward.py        | 50 +++++++++++++------
 .../quantization/utils/helpers.py            | 39 ++++++++++-----
 3 files changed, 80 insertions(+), 28 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
index 5f348e91..fe27146a 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -60,6 +60,25 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_zero_point",
             "weight_global_scale",
         )
+
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8),
+        }
+        return output
+
 
     def compress_weight(
         self,
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index b82a4195..55278f2d 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -257,13 +257,14 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
+        """
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
-
+        """
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]
 
         # TODO: make validation step for inputs
 
@@ -293,14 +294,25 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        x = torch.reshape(
-            x,
-            (
-                x.shape[0],
-                ceil(x.shape[1] / group_size),
-                group_size,
-            ),
-        )
+        if len(x.shape) > 2:
+            x = torch.reshape(
+                x,
+                (
+                    x.shape[0],
+                    x.shape[1],
+                    ceil(x.shape[-1] / group_size),
+                    group_size,
+                ),
+            )
+        else:
+            x = torch.reshape(
+                x,
+                (
+                    x.shape[0],
+                    ceil(x.shape[-1] / group_size),
+                    group_size,
+                ),
+            )
 
         if do_quantize:
             output = _quantize(
@@ -323,18 +335,24 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        output = torch.reshape(
-            output,
-            (output.shape[0], output.shape[1] * output.shape[2]),
-        )
+        if len(x.shape) > 3:
+            output = torch.reshape(
+                output,
+                (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
+            )
+        else:
+            output = torch.reshape(
+                output,
+                (output.shape[0], output.shape[-1] * output.shape[-2]),
+            )
 
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
+        #if len(n_dims) > 2:
+        #    output = output.unsqueeze(0)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 42a6e19e..b2897415 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -167,7 +167,7 @@ def compute_dynamic_scales_and_zp(
     keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
-        dim = {1, 2}
+        dim = {0, 1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
@@ -175,20 +175,35 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
+        #if len(value.shape) > 2:
+        #    value = value.squeeze(0)
         if len(value.shape) > 2:
-            value = value.squeeze(0)
+            dim = {0, 1, 2}
+        else:
+            dim = {0, 1}
 
-        dim = {0, 1}
-        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
         keep_dims = False
-        value = torch.reshape(
-            value,
-            (
-                value.shape[0],
-                math.ceil(value.shape[1] / args.group_size),
-                args.group_size,
-            ),
-        )
+
+        if len(value.shape) > 2:
+            value = torch.reshape(
+                value,
+                (
+                    value.shape[0],
+                    value.shape[1],
+                    math.ceil(value.shape[-1] / args.group_size),
+                    args.group_size,
+                ),
+            )
+        else:
+            value = torch.reshape(
+                value,
+                (
+                    value.shape[0],
+                    math.ceil(value.shape[-1] / args.group_size),
+                    args.group_size,
+                ),
+            )
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
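
A minimal, self-contained sketch of the rank-dependent grouping this patch introduces (the input shape and group size are illustrative stand-ins, not values taken from the library):

import torch
from math import ceil

group_size = 4
x = torch.randn(1, 4, 8)  # hypothetical (batch, tokens, hidden) activation

# Patch 1 branches on rank: 3-D activations keep their batch dim,
# while 2-D weights reshape exactly as before.
if len(x.shape) > 2:
    x = torch.reshape(
        x, (x.shape[0], x.shape[1], ceil(x.shape[-1] / group_size), group_size)
    )
else:
    x = torch.reshape(x, (x.shape[0], ceil(x.shape[-1] / group_size), group_size))

print(x.shape)  # torch.Size([1, 4, 2, 4]): one trailing slice per quantization group
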
From b29792fd7210205c9b5dc27fd61e775321323e90 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 30 Jul 2025 14:06:29 +0000
Subject: [PATCH 2/5] make generic

---
 .../quantized_compressors/nvfp4_quantized.py |  8 ++--
 .../quantization/lifecycle/forward.py        | 48 +++++--------------
 .../quantization/utils/helpers.py            | 32 ++++---------
 3 files changed, 26 insertions(+), 62 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
index fe27146a..90033649 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -60,7 +60,7 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_zero_point",
             "weight_global_scale",
         )
-    
+
     def compression_param_info(
         self,
         weight_shape: torch.Size,
@@ -75,11 +75,13 @@ def compression_param_info(
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
         output = {
-            "weight_packed": (torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8),
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
         }
         return output
-
 
     def compress_weight(
         self,
         weight: Tensor,
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 55278f2d..65d01f53 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -257,11 +257,7 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        """
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
-        """
+
        output_dtype = dtype if dtype is not None else x.dtype
        output = torch.zeros_like(x).to(output_dtype)
        columns = output.shape[-1]
@@ -294,25 +290,12 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        if len(x.shape) > 2:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    x.shape[1],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
-        else:
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    ceil(x.shape[-1] / group_size),
-                    group_size,
-                ),
-            )
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = tuple(x.shape[:-1]) + (
+            ceil(x.shape[-1] / group_size),
+            group_size,
+        )
+        x = torch.reshape(x, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
@@ -335,25 +318,16 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        if len(x.shape) > 3:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[1], output.shape[-1] * output.shape[-2]),
-            )
-        else:
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[-1] * output.shape[-2]),
-            )
+        original_shaped_dims = tuple(output.shape[:-2]) + (
+            output.shape[-1] * output.shape[-2],
+        )
+        output = torch.reshape(output, original_shaped_dims)
 
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        #if len(n_dims) > 2:
-        #    output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index b2897415..59b2dbe3 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -175,35 +175,23 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        #if len(value.shape) > 2:
-        #    value = value.squeeze(0)
+
         if len(value.shape) > 2:
             dim = {0, 1, 2}
         else:
             dim = {0, 1}
 
-        reduce_dims = tuple(idx for idx in range(len(value.shape) + 1) if idx not in dim)
+        reduce_dims = tuple(
+            idx for idx in range(len(value.shape) + 1) if idx not in dim
+        )
         keep_dims = False
 
-        if len(value.shape) > 2:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    value.shape[1],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
-        else:
-            value = torch.reshape(
-                value,
-                (
-                    value.shape[0],
-                    math.ceil(value.shape[-1] / args.group_size),
-                    args.group_size,
-                ),
-            )
+        reshaped_dims = tuple(value.shape[:-1]) + (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
+        )
+        value = torch.reshape(value, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
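
Patch 2 collapses the rank-specific branches into a shape-algebra identity: keep every leading dimension and split only the last one. A small self-contained check of the identity and its inverse, with shapes chosen arbitrarily for illustration:

import torch
from math import ceil

group_size = 4

for shape in [(16, 8), (1, 4, 8), (2, 4, 8)]:  # 2-D weights and batched activations
    x = torch.randn(shape)
    # Keep all leading dims; split the last into (num_groups, group_size).
    reshaped_dims = tuple(x.shape[:-1]) + (ceil(x.shape[-1] / group_size), group_size)
    grouped = torch.reshape(x, reshaped_dims)
    # Inverse used on the way out: merge the two trailing dims again.
    restored = torch.reshape(
        grouped, tuple(grouped.shape[:-2]) + (grouped.shape[-1] * grouped.shape[-2],)
    )
    assert restored.shape == x.shape
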
From 30ad3059932664ad7c47e0476d46da6e698162dc Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 30 Jul 2025 14:13:15 +0000
Subject: [PATCH 3/5] fix tests

---
 tests/test_quantization/test_utils/test_helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
index 2c6b1224..b9f9754c 100644
--- a/tests/test_quantization/test_utils/test_helpers.py
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -83,7 +83,7 @@ def test_fused_global_scales():
     "shape,group_size,exp_shape",
     [
         # Only batch size =1 is supported for dynamic GROUP quantization
-        ((1, 4, 8), 4, torch.Size([4, 2])),
+        ((1, 4, 8), 4, torch.Size([1, 4, 2])),
     ],
 )
 def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):

From 3548dc59dd088794a4cbfff30d5081b0daba6f3d Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 30 Jul 2025 16:01:56 +0000
Subject: [PATCH 4/5] remove incorrect line change; make generic

---
 src/compressed_tensors/quantization/utils/helpers.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 59b2dbe3..d6ee3486 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -167,7 +167,7 @@ def compute_dynamic_scales_and_zp(
     keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
-        dim = {0, 1, 2}
+        dim = {1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
@@ -176,13 +176,8 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.GROUP,
     ):
 
-        if len(value.shape) > 2:
-            dim = {0, 1, 2}
-        else:
-            dim = {0, 1}
-
         reduce_dims = tuple(
-            idx for idx in range(len(value.shape) + 1) if idx not in dim
+            idx for idx in range(len(value.shape) + 1) if idx not in range(value.dim())
         )
         keep_dims = False
 
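
Patch 3's new expected shape and patch 4's reduce_dims expression fit together: evaluated before the reshape, range(len(value.shape) + 1) minus range(value.dim()) leaves exactly the group axis that the reshape appends. A worked sketch using the test's (1, 4, 8) case (the amax reduction here stands in for, and does not call, the library's observer):

import math
import torch

group_size = 4
value = torch.randn(1, 4, 8)

# Every existing axis is excluded, so only the soon-to-be-appended
# group axis (index value.dim() == 3) is reduced over.
reduce_dims = tuple(
    idx for idx in range(len(value.shape) + 1) if idx not in range(value.dim())
)
print(reduce_dims)  # (3,)

reshaped_dims = tuple(value.shape[:-1]) + (
    math.ceil(value.shape[-1] / group_size),
    group_size,
)
value = torch.reshape(value, reshaped_dims)  # (1, 4, 2, 4)

scale_shape = torch.amax(value.abs(), dim=reduce_dims).shape
print(scale_shape)  # torch.Size([1, 4, 2]), the exp_shape in the fixed test
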
From 1cfd8bb98fcc1f4b13e01f1223e87219d33b9137 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 31 Jul 2025 17:00:56 +0000
Subject: [PATCH 5/5] update

---
 .../quantization/lifecycle/forward.py                | 10 +++-------
 src/compressed_tensors/quantization/utils/helpers.py |  8 +++-----
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 65d01f53..d3c9da40 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -291,11 +291,11 @@ def _process_quantization(
             x = safe_permute(x, perm, dim=1)
 
         # Maintain all dimensions apart from the last dim, which is divided by the group_size
-        reshaped_dims = tuple(x.shape[:-1]) + (
+        reshaped_dims = (
             ceil(x.shape[-1] / group_size),
             group_size,
         )
-        x = torch.reshape(x, reshaped_dims)
+        x = x.unflatten(-1, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
@@ -318,11 +318,7 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        original_shaped_dims = tuple(output.shape[:-2]) + (
-            output.shape[-1] * output.shape[-2],
-        )
-        output = torch.reshape(output, original_shaped_dims)
-
+        output = output.flatten(start_dim=-2)
         output = output.to(output_dtype)
 
         if not is_column_order:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index d6ee3486..5d28cac2 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -176,16 +176,14 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.GROUP,
     ):
 
-        reduce_dims = tuple(
-            idx for idx in range(len(value.shape) + 1) if idx not in range(value.dim())
-        )
+        reduce_dims = -1
         keep_dims = False
 
-        reshaped_dims = tuple(value.shape[:-1]) + (
+        reshaped_dims = (
             math.ceil(value.shape[-1] / args.group_size),
             args.group_size,
         )
-        value = torch.reshape(value, reshaped_dims)
+        value = value.unflatten(-1, reshaped_dims)
 
     else:
         supported_strategies = (
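
The final patch swaps the hand-built shape tuples for torch's dimension helpers, which are rank-agnostic by construction. A minimal round-trip sketch of the pattern (arbitrary shape; assumes group_size divides the last dim so the ceil is a no-op):

import torch
from math import ceil

group_size = 4
x = torch.randn(2, 4, 8)  # any rank works; only the last dim is grouped

# unflatten splits the last dim into (num_groups, group_size)...
grouped = x.unflatten(-1, (ceil(x.shape[-1] / group_size), group_size))
print(grouped.shape)  # torch.Size([2, 4, 2, 4])

# ...and flatten(start_dim=-2) merges the trailing pair back.
restored = grouped.flatten(start_dim=-2)
assert restored.shape == x.shape

With the group axis always last, reduce_dims collapses to -1 and the same code path serves 2-D weights and batched activations alike.
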