Skip to content

Commit 72c87aa

Browse files
authored
Use align_module_device util (#1298)
## Purpose ##
* Standardization and clarity

## Changes ##
* Replace all uses of `_hf_hook.pre_forward` with `align_module_device`

## Testing ##
* `grep -r '_hf_hook.pre_forward' src/`

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 22b4877 commit 72c87aa

File tree

3 files changed

+39
-59
lines changed

3 files changed

+39
-59
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
+from compressed_tensors.utils import align_module_device, update_parameter_data
 from loguru import logger
 from torch.nn import Module
 
@@ -72,27 +72,23 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
7272
:param value: torch.Tensor to be passed to the observer for activations. If
7373
base_name is "weight", then the module's weight tensor will be used
7474
"""
75-
offloaded = is_module_offloaded(module)
76-
if offloaded:
77-
module._hf_hook.pre_forward(module)
78-
79-
if base_name == "weight":
80-
value = module.weight
81-
g_idx = getattr(module, "weight_g_idx", None)
82-
elif value is not None:
83-
g_idx = None
84-
else:
85-
raise ValueError("Must provide a value to observe if not using weight observer")
86-
87-
observer = getattr(module, f"{base_name}_observer")
88-
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
75+
with align_module_device(module):
76+
if base_name == "weight":
77+
value = module.weight
78+
g_idx = getattr(module, "weight_g_idx", None)
79+
elif value is not None:
80+
g_idx = None
81+
else:
82+
raise ValueError(
83+
"Must provide a value to observe if not using weight observer"
84+
)
8985

90-
# update scale and zero point
91-
update_parameter_data(module, updated_scale, f"{base_name}_scale")
92-
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
86+
observer = getattr(module, f"{base_name}_observer")
87+
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
9388

94-
if offloaded:
95-
module._hf_hook.post_forward(module, None)
89+
# update scale and zero point
90+
update_parameter_data(module, updated_scale, f"{base_name}_scale")
91+
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
9692

9793

9894
def update_weight_zp_scale(module: Module):

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.utils.offload import is_module_offloaded
+from compressed_tensors.utils import align_module_device
 from loguru import logger
 from pydantic import Field
 from torch.nn import Module
@@ -290,22 +290,16 @@ def _apply_smoothing(self, model: Module):
290290

291291
@torch.no_grad()
292292
def smooth(module):
293-
offloaded = is_module_offloaded(module)
294-
if offloaded:
295-
module._hf_hook.pre_forward(module)
296-
297-
if module in balance_layers:
298-
module.weight.mul_(scales.view(1, -1))
299-
elif module == smooth_layer:
300-
if module.weight.ndim == 1:
301-
module.weight.div_(scales)
302-
else:
303-
module.weight.div_(scales.view(-1, 1))
304-
if hasattr(module, "bias") and module.bias is not None:
305-
module.bias.div_(scales)
306-
307-
if offloaded:
308-
module._hf_hook.post_forward(module, None)
293+
with align_module_device(module):
294+
if module in balance_layers:
295+
module.weight.mul_(scales.view(1, -1))
296+
elif module == smooth_layer:
297+
if module.weight.ndim == 1:
298+
module.weight.div_(scales)
299+
else:
300+
module.weight.div_(scales.view(-1, 1))
301+
if hasattr(module, "bias") and module.bias is not None:
302+
module.bias.div_(scales)
309303

310304
parent = get_fsdp_parent(mapping.smooth_name, model)
311305
if parent is not None:
@@ -330,15 +324,9 @@ def _calculate_smoothing_scales(
330324
# get the channel-wise dynamic range for each layer to be balanced
331325
weight_scales = []
332326
for layer in balance_layers:
333-
offloaded = is_module_offloaded(layer)
334-
if offloaded:
335-
layer._hf_hook.pre_forward(layer)
336-
337-
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
338-
weight_scales.append(scale)
339-
340-
if offloaded:
341-
layer._hf_hook.post_forward(layer, None)
327+
with align_module_device(layer):
328+
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
329+
weight_scales.append(scale)
342330

343331
weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
344332

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
 import torch
 from accelerate import infer_auto_device_map, init_empty_weights
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import is_module_offloaded
 from compressed_tensors.quantization.utils import iter_named_leaf_modules, module_type
+from compressed_tensors.utils import align_module_device
 from torch.nn.modules import Linear
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
@@ -298,18 +298,14 @@ def is_sparse_compression_target(
298298
:return: whether or not the module is a target for sparsity compression,
299299
i.e True if it is sparse and follows the sparsity structure, else False
300300
"""
301-
offloaded = is_module_offloaded(module)
302-
if offloaded:
303-
module._hf_hook.pre_forward(module)
304-
305-
result = (
306-
hasattr(module, "weight")
307-
and tensor_sparsity(module.weight) >= sparsity_threshold
308-
and tensor_follows_mask_structure(tensor=module.weight, mask=sparsity_structure)
309-
)
310-
311-
if offloaded:
312-
module._hf_hook.post_forward(module, None)
301+
with align_module_device(module):
302+
result = (
303+
hasattr(module, "weight")
304+
and tensor_sparsity(module.weight) >= sparsity_threshold
305+
and tensor_follows_mask_structure(
306+
tensor=module.weight, mask=sparsity_structure
307+
)
308+
)
313309

314310
return result
315311

0 commit comments

Comments
 (0)