1 change: 1 addition & 0 deletions src/llmcompressor/core/events/event.py
@@ -44,6 +44,7 @@ class EventType(Enum):
BATCH_START = "batch_start"
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"
SEQUENTIAL_EPOCH_END = "sequential_epoch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
11 changes: 11 additions & 0 deletions src/llmcompressor/core/session_functions.py
@@ -212,5 +212,16 @@ def batch_end(cls, **kwargs) -> ModifiedState:
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
"""
Invoke a sequential epoch end event for the active session. This event should be
called after one sequential layer has been calibrated/trained for one epoch

This is called after a sequential layer has been calibrated with one batch, see
`src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
"""
return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)


callbacks = LifecycleCallbacks
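
For context, a minimal sketch of how the new lifecycle event is meant to flow: a sequential pipeline calibrates one layer at a time and fires the event once per layer epoch, and modifiers react in `on_event`. The loop below is illustrative only (the helper names and loop structure are assumptions, not the actual pipeline); see `src/llmcompressor/pipelines/sequential/pipeline.py` for the real usage.

from llmcompressor.core.session_functions import callbacks

def calibrate_layerwise(subgraphs, batches, run_subgraph):  # hypothetical helpers
    for subgraph in subgraphs:
        for batch in batches:
            run_subgraph(subgraph, batch)  # modifier hooks accumulate calibration stats
        # one full pass over the calibration data for this layer is one "sequential epoch"
        callbacks.sequential_epoch_end()   # modifiers compress what they have calibrated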
21 changes: 17 additions & 4 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -10,7 +10,7 @@
from loguru import logger
from pydantic import PrivateAttr

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
@@ -90,6 +90,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the hessian of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
canonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

@@ -108,10 +116,13 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self.compress_modules()

def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
Expand Down Expand Up @@ -154,6 +165,8 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.compress_modules() # compress any remaining modules

self.remove_hooks()
self._hessians = dict()
self._num_samples = dict()
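
The `calibrate_module` hook above accumulates an input Hessian per module, which `compress_modules` then consumes on each `SEQUENTIAL_EPOCH_END`. As a rough illustration of that bookkeeping, here is a simplified standalone sketch (not the `sgpt_sparsify` implementation, which also handles dtype, offloading, and input reshaping):

import torch

def update_hessian(H: torch.Tensor, num_seen: int, inp: torch.Tensor):
    # Running estimate of H = (2 / n) * X @ X.T over all calibration samples seen so far
    x = inp.reshape(-1, inp.shape[-1]).t().float()  # (hidden_dim, num_tokens)
    num_new = x.shape[1]
    total = num_seen + num_new
    H *= num_seen / total               # down-weight the previous contribution
    H += (2.0 / total) * (x @ x.t())    # add this batch's contribution
    return H, total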
2 changes: 0 additions & 2 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -170,7 +170,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

@@ -186,7 +185,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

22 changes: 17 additions & 5 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -9,7 +9,7 @@
from loguru import logger
from pydantic import PrivateAttr

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import (
@@ -74,6 +74,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the row scalars of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
canonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

@@ -91,12 +99,14 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self.compress_modules()

def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""

for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
@@ -122,6 +132,8 @@ def on_sequential_batch_end(self):
del self._num_samples[module]

def on_finalize(self, state: State, **kwargs) -> bool:
self.compress_modules() # compress any remaining modules

self.remove_hooks()
self._row_scalars = dict()
self._num_samples = dict()
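
The row scalars accumulated by `calibrate_module` above feed the Wanda pruning criterion, which scores each weight by its magnitude times the l2 norm of the corresponding input feature. A simplified sketch of that scoring (illustrative only; the actual scoring and mask selection live in the `wanda_sparsify` module imported above):

import torch

def wanda_importance(weight: torch.Tensor, row_scalars: torch.Tensor) -> torch.Tensor:
    # weight:      (out_features, in_features)
    # row_scalars: (in_features,) accumulated sum of squared input activations per channel
    # Wanda score: |W_ij| * ||X_j||_2 ; the lowest-scoring weights in each row are pruned
    return weight.abs() * row_scalars.sqrt().reshape(1, -1)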
26 changes: 14 additions & 12 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -13,7 +13,7 @@
from loguru import logger
from pydantic import Field, PrivateAttr, field_validator

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization
from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
@@ -236,7 +236,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

@@ -257,7 +256,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

@@ -281,6 +279,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:

:param state: session state storing input model and calibration data
"""
self.compress_modules() # compress any remaining modules

if self._quantization_modifier:
self._quantization_modifier.finalize(state, **kwargs)

@@ -298,13 +298,12 @@ def calibrate_module(
_output: torch.Tensor,
):
"""
Quantize a module's weight according to the GPTQ algorithm

:param name: name of module being quantized
:param module: module being quantized
:param args: input arguments for module forward pass
Calibration hook used to accumulate the hessian of the input to the module

:return: total loss from applying weight quantization to this module
:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
canonical input
:param _output: uncompressed module output, unused
"""
# Assume that first argument is the input
inp = args[0]
@@ -326,10 +325,13 @@
self._num_samples[module],
)

def on_sequential_batch_end(self):
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self.compress_modules()

def compress_modules(self):
"""
Quantize modules.
TODO: implement with event callback
Quantize modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
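
`compress_modules` quantizes each calibrated module using its accumulated Hessian. For orientation, a heavily simplified sketch of the core GPTQ loop: quantize one column, then push the quantization error onto the not-yet-quantized columns using the inverse Hessian (illustrative only; the production code in `gptq_quantize` adds dampening, blocking, activation ordering, and group-wise scales):

import torch

def gptq_sketch(W: torch.Tensor, H: torch.Tensor, quantize_col) -> torch.Tensor:
    # W: (rows, cols) weight matrix; H: (cols, cols) accumulated input Hessian;
    # quantize_col: callable that rounds one weight column onto the target grid.
    Hinv = torch.linalg.inv(H)  # the real implementation uses a dampened Cholesky factorization
    for j in range(W.shape[1]):
        q = quantize_col(W[:, j])
        err = (W[:, j] - q) / Hinv[j, j]  # normalized quantization error
        W[:, j] = q
        if j + 1 < W.shape[1]:
            # compensate columns not yet quantized; the real code also updates the
            # inverse as columns are fixed, which this sketch omits
            W[:, j + 1:] -= err.unsqueeze(1) * Hinv[j, j + 1:].unsqueeze(0)
    return W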
70 changes: 37 additions & 33 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,11 +2,12 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from compressed_tensors.utils.offload import align_module_device
from loguru import logger
from pydantic import ConfigDict
from torch.nn import Module

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.smoothquant.utils import (
get_layer_mappings_from_architecture,
@@ -105,7 +106,7 @@ class SmoothQuantModifier(Modifier):
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None

resolved_mappings_: Optional[List] = None
resolved_mappings_: Optional[List[SmoothQuantMapping]] = None
scales_: Optional[Dict] = None

def on_initialize(self, state: State, **kwargs) -> bool:
@@ -139,13 +140,23 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_event(self, state: State, event: Event, **kwargs):
"""
Apply smoothing to modules which have been calibrated with samples
"""
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self._apply_smoothing(state.model)

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data

:param state: unused
:return: True
"""
self.remove_hooks()
self._apply_smoothing(state.model)

if self.scales_ is not None:
self.scales_.clear()
if self.resolved_mappings_ is not None:
@@ -166,7 +177,7 @@ def _infer_mappings_from_model(
)

@handle_mapping_resolution_errors
def _resolve_mappings(self, model: Module) -> List:
def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
@@ -259,9 +270,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
self.calibration_function,
)

# remove the hooks now that we are done calibrating
self.remove_hooks()

@torch.no_grad()
def _apply_smoothing(self, model: Module):
"""
@@ -273,8 +281,11 @@ def _apply_smoothing(self, model: Module):

This modifies the weights of the model in-place.
"""
logger.info("Smoothing activation scales...")
for mapping in self.resolved_mappings_:
if mapping.smooth_name not in self.scales_:
continue
logger.info(f"Smoothing with {mapping.smooth_name}")

activation_scales = ( # get dynamic range for each activation channel
self.scales_[mapping.smooth_name].max_channel_vals
- self.scales_[mapping.smooth_name].min_channel_vals
Expand All @@ -289,22 +300,16 @@ def _apply_smoothing(self, model: Module):

@torch.no_grad()
def smooth(module):
offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

if offloaded:
module._hf_hook.post_forward(module, None)
with align_module_device(module):
if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

parent = get_fsdp_parent(mapping.smooth_name, model)
if parent is not None:
@@ -315,6 +320,9 @@ def smooth(module):
smooth(layer)
smooth(smooth_layer)

# clear calibration data
del self.scales_[mapping.smooth_name]

def _calculate_smoothing_scales(
self, balance_layers: List[Module], activation_scales: torch.Tensor
) -> List[float]:
@@ -329,15 +337,9 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

if offloaded:
layer._hf_hook.post_forward(layer, None)
with align_module_device(layer):
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

Expand All @@ -350,3 +352,5 @@ def _calculate_smoothing_scales(
scales = torch.where(weight_scales > 0.0, scales, activation_scales)

return scales

model_config = ConfigDict(arbitrary_types_allowed=True)
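
For reference, `_calculate_smoothing_scales` implements the SmoothQuant migration strength: per input channel j, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), and the smooth layer's weights are divided by s while the balance layers' weights are multiplied by s. A minimal standalone sketch of that formula (illustrative; the function and argument names below are not the modifier's own):

import torch

def smoothing_scales(activation_scales: torch.Tensor,
                     weight_scales: torch.Tensor,
                     alpha: float = 0.5) -> torch.Tensor:
    # activation_scales: per-channel dynamic range of the activations, ~max|X_j|
    # weight_scales:     per-channel max weight magnitude across the balance layers, max|W_j|
    scales = activation_scales.pow(alpha) / weight_scales.pow(1.0 - alpha)
    # fall back to the activation scale where the weight range is zero, mirroring the
    # torch.where guard in the code above
    return torch.where(weight_scales > 0.0, scales, activation_scales)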
4 changes: 0 additions & 4 deletions src/llmcompressor/pipelines/basic/pipeline.py
@@ -39,7 +39,3 @@ def run_pipeline(
batch = apply_pad_mask_to_batch(batch)
batch = tensors_to_device(batch, model_device)
model(**batch)

# TODO: replace with a lifecycle event
if callback_modifier:
callback_modifier.on_sequential_batch_end()