Skip to content

Commit f1b8e5a

Browse files
kylesayrsdhuangnm
andcommitted
[Observers] Refactor for better FP4 support, static and memoryless observers (#1903)
* FP4 * Fix bug discovered [here](#1830 (comment)) where dynamic="local" nvfp4 calculations would increment the observer twice as fast as normal * Enable MSE observer to be used with FP4 ```psuedocode mse_quant_error := mean((x - fake_quant(x))**2) global_scale <- min[min_vals, max_vals, global_scale](mse_quant_error(x)) scale, zp <- min[min_vals, max_vals](mse_quant_error(x, global_scale)) ``` * Simplification * Make supporting attention calibration easier by separating out weight/activation/attention reshaping * Improve readability of observer codes by removing many levels of function indirection * Drop support for calibration with non-divisible group sizes. This is not really a loss, since [forward passes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/lifecycle/forward.py#L279) also make this assumption * New observers * `memoryless_minmax` computes min and max values on the fly in a dynamic-quantization style. This observer is useful for PTQ weight quantization * `static_minmax` computes absolute min and max values across all observations. This observer is useful for PTQ activation quantization * `memoryless_mse` computes best qparams w.r.t. MSE loss for each observation. This observer is useful for PTQ weight quantization * Memory improvements * All observers no longer store copies of scales and zero points, reducing the amount of required memory * Newly introduced "memoryless" observers do not store any quantization parameters, which greatly reduces the memory requirements for PTQ weight quantization of very large models | Diagrams | | - | | Before | | <img width="886" height="595" alt="before" src="https://github.com/user-attachments/assets/660d94c2-3ac8-4e05-9e9b-53d21145abac" /> | | After | <img width="1527" height="595" alt="after" src="https://github.com/user-attachments/assets/51a0107e-3fbd-413c-a7a6-03ddc3612169" /> | * Standardize reshaping using `flatten_for_calibration` * This function reshapes all observed values to `(num_observations, *qparams_shape, group_size)` * This function the complexity associated with passing "reduce dims" and trying to handle weights, activations, and attention states all in the same function * In the future, this function could be applied to the quantization forward pass, although there's probably no need to outside of standardization * Implement `get_global_scale` on `Observer` base * This function decouples minmax calculations from regular qparam calculations (avoiding the double increment bug) * This function enables the MSE observer to be used with FP4 global scales * Added additional minmax tests which check exact values of scales. This test passes both on main and this branch, demonstrating that minmax observer behavior remains unchanged * Added additional MSE tests which check exact values of mse losses. This test passes both on main and this branch, demonstrating that MSE observer behavior remains unchanged * Added FP4 MSE test ``` nvfp4-static-minmax | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr| |--------|------:|------|-----:|--------|---|-----:|---|------| |mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6167|± | N/A| ``` ``` nvfp4-minmax | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr| |--------|------:|------|-----:|--------|---|-----:|---|------| |mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6011|± | N/A| ``` --------- Signed-off-by: Kyle Sayers <[email protected]> Signed-off-by: Dan Huang <[email protected]> Co-authored-by: dhuangnm <[email protected]>
1 parent aba933c commit f1b8e5a

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import torch
44
from compressed_tensors.quantization import (
55
DynamicType,
6+
<<<<<<< HEAD
67
QuantizationArgs,
8+
=======
9+
KVCacheScaleType,
10+
QuantizationArgs,
11+
QuantizationScheme,
12+
>>>>>>> d7d1b45b ([Observers] Refactor for better FP4 support, static and memoryless observers (#1903))
713
QuantizationStatus,
814
QuantizationStrategy,
915
)

0 commit comments

Comments
 (0)