6 changes: 5 additions & 1 deletion docs/observers.md
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 
 args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
-observer = Observer.load_from_registry("minmax", quantization_args=args)
+observer = Observer.load_from_registry(
+    "minmax",
+    base_name="weight",
+    quantization_args=args,
+)
 
 x = torch.randn(64, 512)
 scale, zero_point = observer(x)
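For reference, a minimal usage sketch (not part of this diff) of the same registry call for an input-activation observer. The "input" base name and the per-tensor int8 args are assumptions; the call shape and keyword names mirror the documented weight example above:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer

# Hypothetical usage sketch: int8 per-tensor input-activation observer,
# mirroring the documented weight example above.
args = QuantizationArgs(num_bits=8, strategy="tensor")
observer = Observer.load_from_registry(
    "minmax",
    base_name="input",
    quantization_args=args,
)

x = torch.randn(64, 512)         # a batch of input activations
scale, zero_point = observer(x)  # calibrated quantization parameters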
10 changes: 6 additions & 4 deletions src/llmcompressor/modifiers/quantization/cache.py
@@ -86,13 +86,15 @@ def update(
"""

if len(self.k_observers) <= layer_idx:
k_observer_name = self.quantization_args.observer
k_observer = Observer.load_from_registry(
k_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="k",
args=self.quantization_args,
)
v_observer_name = self.quantization_args.observer
v_observer = Observer.load_from_registry(
v_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="v",
args=self.quantization_args,
)

# NOTE: User may ignore some layers in configuration,
75 changes: 19 additions & 56 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -5,6 +5,7 @@
 from compressed_tensors.quantization import (
     DynamicType,
     KVCacheScaleType,
+    QuantizationArgs,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
@@ -19,12 +20,6 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
-DEFAULT_MAXSHRINK = 0.20
-DEFAULT_PATIENCE = 5
-DEFAULT_AVERAGING_CONSTANT = 0.01
-DEFAULT_GRID = 100.0
-DEFAULT_NORM = 2.4
-
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
     :param base_name: str used to name the observer attribute
 
     """
-
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        # no quantization scheme nothing to do
-        return
-
-    quantization_args = getattr(quantization_scheme, arg_name, None)
-    # dont need observers for dynamic
-    if quantization_args is not None and quantization_args.dynamic in (
-        False,
-        DynamicType.LOCAL,
-    ):
-        observer_kwargs = quantization_args.observer_kwargs or {}
+    if base_name == "weight":
+        arg_name = "weights"
+    elif base_name == "output":
+        arg_name = "output_activations"
+    else:  # input, q, k, v
+        arg_name = "input_activations"
+
+    args: QuantizationArgs = getattr_chain(
+        module, f"quantization_scheme.{arg_name}", None
+    )
+    if args is not None and args.dynamic is not True:
         observer = Observer.load_from_registry(
-            quantization_args.observer,
-            quantization_args=quantization_args,
-            averaging_constant=observer_kwargs.get(
-                "averaging_constant", DEFAULT_AVERAGING_CONSTANT
-            ),
-            # used by mse observer only, will be ignored by minmax observer
-            maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
-            patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
-            grid=observer_kwargs.get("grid", DEFAULT_GRID),
-            norm=observer_kwargs.get("norm", DEFAULT_NORM),
+            args.observer, base_name=base_name, args=args, module=module
         )
         module.register_module(f"{base_name}_observer", observer)
 
@@ -100,36 +83,17 @@ def call_observer(
         base_name is "weight", then the module's weight tensor will be used
     """
     with align_module_device(module):
-        if base_name == "weight":
-            value = module.weight
-            g_idx = getattr(module, "weight_g_idx", None)
-        elif value is not None:
-            g_idx = None
-        else:
-            raise ValueError(
-                "Must provide a value to observe if not using weight observer"
-            )
-
-        observer = getattr(module, f"{base_name}_observer")
+        value = module.weight if base_name == "weight" else value
+        observer: Observer = getattr(module, f"{base_name}_observer")
 
         if should_calculate_gparam:
-            global_scale = observer(
-                value,
-                should_calculate_gparam=True,
-            )
+            global_scale = observer.get_global_scale(value)
             update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
-        else:
-            global_scale = getattr(module, f"{base_name}_global_scale", None)
 
         if should_calculate_qparams:
-            updated_scale, updated_zero_point = observer(
-                value, g_idx=g_idx, global_scale=global_scale
-            )
-            # register or update scale & zero_point parameters (supports block shapes)
-            scale_name = f"{base_name}_scale"
-            zp_name = f"{base_name}_zero_point"
-            update_offload_parameter(module, scale_name, updated_scale)
-            update_offload_parameter(module, zp_name, updated_zero_point)
+            scale, zero_point = observer(value)
+            update_offload_parameter(module, f"{base_name}_scale", scale)
+            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
 
 
 def update_weight_global_scale(module: Module):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()
 
 
 def update_weight_zp_scale(module: Module):
@@ -95,8 +95,10 @@ def quantize_weight(

     # create observer for calculating quantization parameters
     observer = Observer.load_from_registry(
-        quant_args.observer,
-        quantization_args=quant_args,
+        "minmax",
+        base_name="weight",
+        args=quant_args,
+        module=module,
         averaging_constant=1.0,  # ignore moving average
     )
 
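For intuition on the `averaging_constant=1.0  # ignore moving average` line above: min-max observers typically blend each new batch's statistics into a running value with an exponential moving average, so a constant of 1.0 keeps only the latest batch. A rough standalone sketch of that update rule, with hypothetical names, not the library's code:

def update_running_min_max(prev_min, prev_max, batch_min, batch_max, c):
    # Exponential moving average of min/max statistics (sketch).
    # With c = 1.0 the running values are replaced by the current batch,
    # i.e. the moving average is effectively disabled.
    if prev_min is None:  # nothing observed yet
        return batch_min, batch_max
    new_min = prev_min * (1 - c) + batch_min * c
    new_max = prev_max * (1 - c) + batch_max * c
    return new_min, new_max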