2 changes: 1 addition & 1 deletion .github/workflows/test-check-transformers.yaml
@@ -60,7 +60,7 @@ jobs:
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'
- uses: actions/checkout@v4
with:
fetch-depth: 0
6 changes: 5 additions & 1 deletion docs/observers.md
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
observer = Observer.load_from_registry("minmax", quantization_args=args)
observer = Observer.load_from_registry(
"minmax",
base_name="weight",
quantization_args=args,
)

x = torch.randn(64, 512)
scale, zero_point = observer(x)
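Putting the updated documentation snippet together, the documented flow now looks roughly like the following. This is a sketch based only on the lines shown above, using the new `base_name` keyword; it is not an authoritative API reference.

```python
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer

# 4-bit group quantization, matching the docs example above
args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)

# observers are now constructed with the tensor's base name ("weight", "input", "k", ...)
observer = Observer.load_from_registry(
    "minmax",
    base_name="weight",
    quantization_args=args,
)

# observing a tensor returns the calibrated scale and zero point
x = torch.randn(64, 512)
scale, zero_point = observer(x)
```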
4 changes: 2 additions & 2 deletions setup.py
@@ -140,9 +140,9 @@ def localversion_func(version: ScmVersion) -> str:
),
("pillow>=10.4.0,<=11.3.0" if BUILD_TYPE == "release" else "pillow>=10.4.0"),
(
"compressed-tensors==0.12.1"
"compressed-tensors==0.12.2"
if BUILD_TYPE == "release"
else "compressed-tensors>=0.12.2a2"
else "compressed-tensors>=0.12.3a2"
),
],
extras_require={
10 changes: 6 additions & 4 deletions src/llmcompressor/modifiers/quantization/cache.py
@@ -86,13 +86,15 @@ def update(
"""

if len(self.k_observers) <= layer_idx:
k_observer_name = self.quantization_args.observer
k_observer = Observer.load_from_registry(
k_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="k",
args=self.quantization_args,
)
v_observer_name = self.quantization_args.observer
v_observer = Observer.load_from_registry(
v_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="v",
args=self.quantization_args,
)

# NOTE: User may ignore some layers in configuration,
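For the KV cache, the same registry signature is now used per layer, with `base_name` distinguishing key and value observers. A minimal sketch of that pattern, assuming a hypothetical `_build_kv_observers` helper and a `QuantizationArgs` object whose `.observer` field names the registered observer:

```python
from llmcompressor.observers import Observer

def _build_kv_observers(quantization_args):
    # one observer per tensor role; base_name tags whether it calibrates keys or values
    k_observer = Observer.load_from_registry(
        quantization_args.observer,  # e.g. "minmax" or "mse"
        base_name="k",
        args=quantization_args,
    )
    v_observer = Observer.load_from_registry(
        quantization_args.observer,
        base_name="v",
        args=quantization_args,
    )
    return k_observer, v_observer
```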
75 changes: 19 additions & 56 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -5,6 +5,7 @@
from compressed_tensors.quantization import (
DynamicType,
KVCacheScaleType,
QuantizationArgs,
QuantizationScheme,
QuantizationStatus,
QuantizationStrategy,
@@ -19,12 +20,6 @@
from llmcompressor.observers import Observer
from llmcompressor.utils.helpers import getattr_chain

DEFAULT_MAXSHRINK = 0.20
DEFAULT_PATIENCE = 5
DEFAULT_AVERAGING_CONSTANT = 0.01
DEFAULT_GRID = 100.0
DEFAULT_NORM = 2.4

__all__ = [
"initialize_observer",
"update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
:param base_name: str used to name the observer attribute

"""

arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
quantization_scheme = getattr(module, "quantization_scheme", None)
if not quantization_scheme:
# no quantization scheme nothing to do
return

quantization_args = getattr(quantization_scheme, arg_name, None)
# dont need observers for dynamic
if quantization_args is not None and quantization_args.dynamic in (
False,
DynamicType.LOCAL,
):
observer_kwargs = quantization_args.observer_kwargs or {}
if base_name == "weight":
arg_name = "weights"
elif base_name == "output":
arg_name = "output_activations"
else: # input, q, k, v
arg_name = "input_activations"

args: QuantizationArgs = getattr_chain(
module, f"quantization_scheme.{arg_name}", None
)
if args is not None and args.dynamic is not True:
observer = Observer.load_from_registry(
quantization_args.observer,
quantization_args=quantization_args,
averaging_constant=observer_kwargs.get(
"averaging_constant", DEFAULT_AVERAGING_CONSTANT
),
# used by mse observer only, will be ignored by minmax observer
maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
grid=observer_kwargs.get("grid", DEFAULT_GRID),
norm=observer_kwargs.get("norm", DEFAULT_NORM),
args.observer, base_name=base_name, args=args, module=module
)
module.register_module(f"{base_name}_observer", observer)

@@ -100,36 +83,17 @@ def call_observer(
base_name is "weight", then the module's weight tensor will be used
"""
with align_module_device(module):
if base_name == "weight":
value = module.weight
g_idx = getattr(module, "weight_g_idx", None)
elif value is not None:
g_idx = None
else:
raise ValueError(
"Must provide a value to observe if not using weight observer"
)

observer = getattr(module, f"{base_name}_observer")
value = module.weight if base_name == "weight" else value
observer: Observer = getattr(module, f"{base_name}_observer")

if should_calculate_gparam:
global_scale = observer(
value,
should_calculate_gparam=True,
)
global_scale = observer.get_global_scale(value)
update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
else:
global_scale = getattr(module, f"{base_name}_global_scale", None)

if should_calculate_qparams:
updated_scale, updated_zero_point = observer(
value, g_idx=g_idx, global_scale=global_scale
)
# register or update scale & zero_point parameters (supports block shapes)
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
update_offload_parameter(module, scale_name, updated_scale)
update_offload_parameter(module, zp_name, updated_zero_point)
scale, zero_point = observer(value)
update_offload_parameter(module, f"{base_name}_scale", scale)
update_offload_parameter(module, f"{base_name}_zero_point", zero_point)


def update_weight_global_scale(module: Module):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
should_calculate_gparam=True,
should_calculate_qparams=False,
)
module.weight_observer.reset()


def update_weight_zp_scale(module: Module):
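The new `initialize_observer` drops the explicit `observer_kwargs` plumbing and module-level MSE defaults in favor of a base-name lookup followed by a single registry call, which suggests those defaults now live with the observer classes themselves. A minimal sketch of the lookup shown in the diff, using a hypothetical helper name `_arg_name_for`:

```python
def _arg_name_for(base_name: str) -> str:
    # maps an observer base name to the matching field on the quantization scheme
    if base_name == "weight":
        return "weights"
    if base_name == "output":
        return "output_activations"
    # "input", "q", "k", "v" all resolve to input activations
    return "input_activations"

assert _arg_name_for("weight") == "weights"
assert _arg_name_for("k") == "input_activations"
```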
16 changes: 10 additions & 6 deletions src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
@@ -10,6 +10,7 @@
QuantizationStrategy,
fake_quantize,
)
from compressed_tensors.utils import update_offload_parameter
from loguru import logger

from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
@@ -95,8 +96,10 @@ def quantize_weight(

# create observer for calculating quantization parameters
observer = Observer.load_from_registry(
quant_args.observer,
quantization_args=quant_args,
"minmax",
base_name="weight",
args=quant_args,
module=module,
averaging_constant=1.0, # ignore moving average
)

@@ -119,22 +122,23 @@ if actorder == ActivationOrdering.GROUP:
if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, H, perm = _apply_activation_ordering(W, H)
scale, zero_point = observer(W, g_idx=None)
update_offload_parameter(module, "weight_g_idx", g_idx)
scale, zero_point = observer(W)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
scale, zero_point = observer(W, g_idx=None)
scale, zero_point = observer(W)
W, H, perm = _apply_activation_ordering(W, H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

else:
scale, zero_point = observer(W, g_idx=None)
scale, zero_point = observer(W)
else:
scale, zero_point = observer(W, g_idx=None)
scale, zero_point = observer(W)

# sparsity mask
sparsity = tensor_sparsity(W)
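The GPTQ path now calls the observer without a `g_idx` argument and, under group activation ordering, registers `weight_g_idx` through `update_offload_parameter`. As an illustration of what activation ordering generally means in this context (an assumption; this is not the library's `_apply_activation_ordering`), here is a sketch that permutes weight columns by descending Hessian diagonal:

```python
import torch

def apply_activation_ordering(W: torch.Tensor, H: torch.Tensor):
    # Illustrative stand-in (assumed behavior): order columns by how salient
    # the Hessian diagonal says they are, permuting W and H consistently.
    perm = torch.argsort(torch.diag(H), descending=True)
    return W[:, perm], H[perm][:, perm], perm

# example shapes only
W = torch.randn(16, 8)
H = torch.eye(8) * torch.arange(1.0, 9.0)
W_permuted, H_permuted, perm = apply_activation_ordering(W, H)
```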
2 changes: 2 additions & 0 deletions src/llmcompressor/observers/__init__.py
@@ -11,5 +11,7 @@

from .helpers import *
from .base import *
from .moving_base import *
from .static_base import *
from .min_max import *
from .mse import *