Skip to content

Commit 23b7cfc

Browse files
committed
call the right functions
Signed-off-by: shanjiaz <[email protected]>
1 parent 85d51b2 commit 23b7cfc

File tree

3 files changed

+42
-25
lines changed

3 files changed

+42
-25
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
)
1111
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
1212
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
13-
from compressed_tensors.utils import align_module_device, update_parameter_data
13+
from compressed_tensors.utils import (
14+
align_module_device,
15+
delete_offload_parameter,
16+
update_offload_parameter,
17+
)
1418
from loguru import logger
1519
from torch.nn import Module
1620

@@ -116,7 +120,7 @@ def call_observer(
116120
value,
117121
should_calculate_gparam=True,
118122
)
119-
update_parameter_data(module, global_scale, f"{base_name}_global_scale")
123+
update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
120124
else:
121125
global_scale = getattr(module, f"{base_name}_global_scale", None)
122126

@@ -127,22 +131,21 @@ def call_observer(
127131
# register or update scale & zero_point parameters (supports block shapes)
128132
scale_name = f"{base_name}_scale"
129133
zp_name = f"{base_name}_zero_point"
130-
if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape:
131-
if hasattr(module, scale_name):
132-
delattr(module, scale_name)
133-
module.register_parameter(
134-
scale_name, torch.nn.Parameter(updated_scale.clone())
135-
)
136-
else:
137-
update_parameter_data(module, updated_scale, scale_name)
138-
if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape:
139-
if hasattr(module, zp_name):
140-
delattr(module, zp_name)
141-
module.register_parameter(
142-
zp_name, torch.nn.Parameter(updated_zero_point.clone())
143-
)
144-
else:
145-
update_parameter_data(module, updated_zero_point, zp_name)
134+
for name, value in [
135+
(scale_name, updated_scale),
136+
(zp_name, updated_zero_point),
137+
]:
138+
if (
139+
not hasattr(module, name)
140+
or getattr(module, name).shape != value.shape
141+
):
142+
if hasattr(module, name):
143+
delete_offload_parameter(module, name)
144+
module.register_offload_parameter(
145+
name, torch.nn.Parameter(value.clone(), requires_grad=False)
146+
)
147+
else:
148+
update_offload_parameter(module, name, value)
146149

147150

148151
def update_weight_global_scale(module: Module):
@@ -273,8 +276,8 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
273276
kv_cache = getattr(module, "kv_cache")
274277
k_scale = kv_cache.k_scales[module.layer_idx]
275278
v_scale = kv_cache.v_scales[module.layer_idx]
276-
update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
277-
update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
279+
update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
280+
update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
278281

279282

280283
def initialize_quantized_kv_cache(module: Module):

src/llmcompressor/observers/base.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,29 @@ def get_qparams(
193193
)
194194

195195
elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
196-
# Block-wise quantization: one scale/zero_point per block of shape [block_rows, block_cols]
196+
# Block-wise quantization: one scale/zero_point per block of shape
197+
# [block_rows, block_cols]
197198
rows, cols = observed.shape[:2]
198199
bs = self.quantization_args.block_structure
199-
if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)):
200-
raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].")
200+
if not (
201+
isinstance(bs, (list, tuple))
202+
and len(bs) == 2
203+
and all(isinstance(x, int) for x in bs)
204+
):
205+
raise ValueError(
206+
f"Invalid block_structure '{bs}'. "
207+
f"Must be a list of two ints [rows, cols]."
208+
)
201209
block_rows, block_cols = bs
202210
num_br = int(ceil(rows / block_rows))
203211
num_bc = int(ceil(cols / block_cols))
204212
# allocate per-block scale and zero_point
205-
self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
206-
self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device)
213+
self._scale = torch.empty(
214+
(num_br, num_bc), dtype=observed.dtype, device=observed.device
215+
)
216+
self._zero_point = torch.empty(
217+
(num_br, num_bc), dtype=observed.dtype, device=observed.device
218+
)
207219
# compute qparams for each block
208220
for i in range(num_br):
209221
r0 = i * block_rows

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def q_config_kwargs(config_0, config_1):
3434
)
3535
)
3636

37+
3738
@pytest.fixture
3839
def block_q_config_kwargs():
3940
return dict(
@@ -53,6 +54,7 @@ def block_q_config_kwargs():
5354
)
5455
)
5556

57+
5658
def test_block_strategy_parsing(block_q_config_kwargs):
5759
modifier = GPTQModifier(**block_q_config_kwargs)
5860
resolved = modifier.resolve_quantization_config()

0 commit comments

Comments
 (0)