
Commit 46d84d8

[Quantization][Decompression] Fix QDQ for dynamic quant; Update NVFP4 Compression Params (#407)
* add compression param; update qdq for batch greater than 1
* make generic
* fix tests
* remove incorrect line change; make generic
* update
1 parent 3d49764 commit 46d84d8

File tree: 4 files changed, +36 -31 lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 21 additions & 0 deletions
@@ -61,6 +61,27 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_global_scale",
         )
 
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
+        }
+        return output
+
     def compress_weight(
         self,
         weight: Tensor,
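
For context on the new compression_param_info entry: FP4 stores two 4-bit codes per uint8 byte, so the packed weight halves the last dimension, which is where the (weight_shape[0], weight_shape[1] // 2) shape and torch.uint8 dtype above come from. A minimal sketch of that packing arithmetic, where pack_fp4_codes is a hypothetical helper and its nibble order is an illustrative assumption, not the compressor's actual kernel:

import torch

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    # codes: integer tensor of 4-bit values in [0, 15]; assumes the
    # last dim is even so each pair of codes shares one byte
    lo = codes[..., 0::2]  # even columns -> low nibble (assumed order)
    hi = codes[..., 1::2]  # odd columns  -> high nibble
    return ((hi << 4) | lo).to(torch.uint8)

codes = torch.randint(0, 16, (64, 128))
packed = pack_fp4_codes(codes)
print(packed.shape, packed.dtype)  # torch.Size([64, 64]) torch.uint8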

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 7 additions & 19 deletions
@@ -257,13 +257,10 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
 
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]
 
         # TODO: make validation step for inputs
 
@@ -293,14 +290,12 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-            x = torch.reshape(
-                x,
-                (
-                    x.shape[0],
-                    ceil(x.shape[1] / group_size),
-                    group_size,
-                ),
+            # Maintain all dimensions apart from the last dim, which is divided by the group_size
+            reshaped_dims = (
+                ceil(x.shape[-1] / group_size),
+                group_size,
             )
+            x = x.unflatten(-1, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
@@ -323,19 +318,12 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-            output = torch.reshape(
-                output,
-                (output.shape[0], output.shape[1] * output.shape[2]),
-            )
-
+            output = output.flatten(start_dim=-2)
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
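
The substance of this change: the old code squeezed away a single leading batch dim and reshaped with explicit shape[0]/shape[1] indices, which only worked for batch size 1. unflatten/flatten touch only the last dim, so any number of leading dims passes through untouched. A standalone toy sketch of the round trip, with a comment standing in for the library's _quantize/_dequantize calls:

import torch
from math import ceil

def groupwise_roundtrip(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # Split only the last dim into (num_groups, group_size); leading
    # dims (batch, tokens, ...) are preserved as-is. Assumes the last
    # dim divides evenly by group_size.
    reshaped_dims = (ceil(x.shape[-1] / group_size), group_size)
    x = x.unflatten(-1, reshaped_dims)
    # ... per-group quantize/dequantize would run here ...
    return x.flatten(start_dim=-2)  # restore the original last dim

x = torch.randn(3, 4, 8)  # batch > 1, which squeeze(0) mishandled
assert groupwise_roundtrip(x, 4).shape == x.shape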

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 7 additions & 11 deletions
@@ -175,20 +175,16 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        if len(value.shape) > 2:
-            value = value.squeeze(0)
 
-        dim = {0, 1}
-        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        reduce_dims = -1
         keep_dims = False
-        value = torch.reshape(
-            value,
-            (
-                value.shape[0],
-                math.ceil(value.shape[1] / args.group_size),
-                args.group_size,
-            ),
+
+        reshaped_dims = (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
         )
+        value = value.unflatten(-1, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
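
With value unflattened to (..., num_groups, group_size), reducing over dim=-1 yields one scale per group while keeping every leading dim, which is exactly what the updated test below expects. A rough sketch of the shape arithmetic, using a plain absmax where the library consults its observer logic:

import math
import torch

value = torch.randn(1, 4, 8)  # (batch, tokens, hidden)
group_size = 4
reshaped_dims = (math.ceil(value.shape[-1] / group_size), group_size)
grouped = value.unflatten(-1, reshaped_dims)  # -> (1, 4, 2, 4)
# One absmax per group; leading dims survive the reduction.
scales = grouped.abs().amax(dim=-1, keepdim=False)
print(scales.shape)  # torch.Size([1, 4, 2])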

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def test_fused_global_scales():
     "shape,group_size,exp_shape",
     [
         # Only batch size =1 is supported for dynamic GROUP quantization
-        ((1, 4, 8), 4, torch.Size([4, 2])),
+        ((1, 4, 8), 4, torch.Size([1, 4, 2])),
     ],
 )
 def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
