Skip to content

Commit dc83c02

Browse files
authored
Merge branch 'main' into issue-1927-type-hints
2 parents 6c00316 + 100cc26 commit dc83c02

File tree

3 files changed

+134
-136
lines changed

3 files changed

+134
-136
lines changed

src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import abstractmethod
33
from collections import defaultdict
44
from functools import partial
5-
from typing import Any, Dict, List, Optional, Tuple, Union
5+
from typing import Any
66

77
import numpy
88
import torch
@@ -27,24 +27,24 @@ class SparsityModifierBase(Modifier):
2727
"""
2828

2929
# modifier arguments
30-
sparsity: Optional[Union[float, List[float]]]
31-
sparsity_profile: Optional[str] = None
30+
sparsity: float | list[float] | None
31+
sparsity_profile: str | None = None
3232
mask_structure: str = "0:0"
33-
owl_m: Optional[int] = None
34-
owl_lmbda: Optional[float] = None
33+
owl_m: int | None = None
34+
owl_lmbda: float | None = None
3535

3636
# data pipeline arguments
37-
sequential_update: Optional[bool] = False # deprecated
38-
sequential_targets: Union[str, List[str], None] = None
39-
targets: Union[str, List[str]] = ["Linear"]
40-
ignore: List[str] = Field(default_factory=list)
37+
sequential_update: bool | None = False # deprecated
38+
sequential_targets: str | list[str] | None = None
39+
targets: str | list[str] = ["Linear"]
40+
ignore: list[str] = Field(default_factory=list)
4141

4242
# private variables
43-
_prune_n: Optional[int] = PrivateAttr(default=None)
44-
_prune_m: Optional[int] = PrivateAttr(default=None)
45-
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
46-
_target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
47-
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
43+
_prune_n: int | None = PrivateAttr(default=None)
44+
_prune_m: int | None = PrivateAttr(default=None)
45+
_module_names: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
46+
_target_layers: dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
47+
_module_sparsities: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
4848

4949
@field_validator("sequential_update", mode="before")
5050
def validate_sequential_update(cls, value: bool) -> bool:
@@ -58,7 +58,7 @@ def validate_sequential_update(cls, value: bool) -> bool:
5858
return True
5959

6060
@field_validator("sparsity_profile", mode="before")
61-
def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
61+
def validate_sparsity_profile(cls, value: str | None) -> bool:
6262
if value is None:
6363
return value
6464

@@ -94,7 +94,7 @@ def validate_model_after(model: "SparsityModifierBase") -> "SparsityModifierBase
9494
def calibrate_module(
9595
self,
9696
module: torch.nn.Module,
97-
args: Tuple[torch.Tensor, ...],
97+
args: tuple[torch.Tensor, ...],
9898
_output: torch.Tensor,
9999
):
100100
raise NotImplementedError()
@@ -143,12 +143,13 @@ def on_start(self, state: State, event: Event, **kwargs):
143143

144144
# register hooks
145145
for index, (layer_name, layer) in enumerate(self._target_layers.items()):
146-
if isinstance(self.sparsity, dict):
147-
layer_sparsity = self.sparsity[layer_name]
148-
elif isinstance(self.sparsity, list):
149-
layer_sparsity = self.sparsity[index]
150-
else:
151-
layer_sparsity = self.sparsity
146+
match self.sparsity:
147+
case dict():
148+
layer_sparsity = self.sparsity[layer_name]
149+
case list():
150+
layer_sparsity = self.sparsity[index]
151+
case _:
152+
layer_sparsity = self.sparsity
152153

153154
for name, module in get_prunable_layers(layer).items():
154155
name = f"{layer_name}.{name}"
@@ -191,21 +192,21 @@ def on_end(self, state: State, event: Event, **kwargs):
191192
self.ended_ = True
192193
self.remove_hooks()
193194

194-
def _infer_sequential_targets(
195-
self, model: torch.nn.Module
196-
) -> Union[str, List[str]]:
197-
if self.sequential_targets is None:
198-
return get_no_split_params(model)
199-
if isinstance(self.sequential_targets, str):
200-
return [self.sequential_targets]
201-
return self.sequential_targets
195+
def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
196+
match self.sequential_targets:
197+
case None:
198+
return get_no_split_params(model)
199+
case str():
200+
return [self.sequential_targets]
201+
case _:
202+
return self.sequential_targets
202203

203204
def _infer_owl_layer_sparsity(
204205
self,
205206
model: torch.nn.Module,
206-
layers: Dict[str, torch.nn.Module],
207+
layers: dict[str, torch.nn.Module],
207208
dataloader: torch.utils.data.DataLoader,
208-
) -> Dict[str, float]:
209+
) -> dict[str, float]:
209210
activations = self._get_activations(model, dataloader)
210211

211212
groups = {}
@@ -248,12 +249,12 @@ def _infer_owl_layer_sparsity(
248249
logger.info(f"Sparsity for {k}: {sparsities[k]}")
249250
return sparsities
250251

251-
def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
252+
def _get_activations(self, model, dataloader, nsamples=128) -> dict[str, int]:
252253
from llmcompressor.pipelines.basic import run_calibration
253254

254255
acts = defaultdict(int)
255256

256-
def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
257+
def save_acts(_module, input: tuple[Any, ...] | torch.Tensor, name: str):
257258
nonlocal acts
258259
if isinstance(input, tuple):
259260
input = input[0]
@@ -270,6 +271,6 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
270271

271272
return acts
272273

273-
def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
274+
def _split_mask_structure(self, mask_structure: str) -> tuple[int, int]:
274275
n, m = mask_structure.split(":")
275276
return int(n), int(m)

src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import math
22
from copy import copy
3-
from typing import Dict, Optional, Tuple, Union
43

54
import torch
65
import transformers
@@ -23,7 +22,7 @@
2322

2423

2524
def make_empty_hessian(
26-
module: torch.nn.Module, device: Optional[torch.device] = None
25+
module: torch.nn.Module, device: torch.device | None = None
2726
) -> torch.Tensor:
2827
weight = module.weight
2928
num_columns = weight.shape[1]
@@ -34,30 +33,30 @@ def make_empty_hessian(
3433
def accumulate_hessian(
3534
inp: torch.Tensor,
3635
module: torch.nn.Module,
37-
H: Optional[torch.Tensor],
36+
H: torch.Tensor | None,
3837
num_samples: int,
39-
) -> Tuple[torch.Tensor, int]:
38+
) -> tuple[torch.Tensor, int]:
4039
inp = inp.to(device=H.device)
4140
if len(inp.shape) == 2:
4241
inp = inp.unsqueeze(0)
4342

4443
num_added = inp.shape[0]
4544

46-
if isinstance(module, (torch.nn.Linear, transformers.Conv1D)):
47-
if len(inp.shape) == 3:
48-
inp = inp.reshape((-1, inp.shape[-1]))
49-
inp = inp.t()
50-
51-
if isinstance(module, torch.nn.Conv2d):
52-
unfold = torch.nn.Unfold(
53-
module.kernel_size,
54-
dilation=module.dilation,
55-
padding=module.padding,
56-
stride=module.stride,
57-
)
58-
inp = unfold(inp)
59-
inp = inp.permute([1, 0, 2])
60-
inp = inp.flatten(1)
45+
match module:
46+
case torch.nn.Linear() | transformers.Conv1D():
47+
if len(inp.shape) == 3:
48+
inp = inp.reshape((-1, inp.shape[-1]))
49+
inp = inp.t()
50+
case torch.nn.Conv2d():
51+
unfold = torch.nn.Unfold(
52+
module.kernel_size,
53+
dilation=module.dilation,
54+
padding=module.padding,
55+
stride=module.stride,
56+
)
57+
inp = unfold(inp)
58+
inp = inp.permute([1, 0, 2])
59+
inp = inp.flatten(1)
6160

6261
H *= num_samples / (num_samples + num_added)
6362
num_samples += num_added
@@ -72,10 +71,10 @@ def accumulate_hessian(
7271
def quantize_weight(
7372
module: torch.nn.Module,
7473
quant_args: QuantizationArgs,
75-
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
74+
hessians_dict: dict[torch.nn.Module, torch.Tensor],
7675
blocksize: int = 128,
7776
percdamp: float = 0.01,
78-
) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
77+
) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
7978
"""
8079
Quantize a module weight according to the GPTQ algorithm
8180
@@ -103,10 +102,11 @@ def quantize_weight(
103102
)
104103

105104
# standardize shape and dtype
106-
if isinstance(module, torch.nn.Conv2d):
107-
W = W.flatten(1)
108-
elif isinstance(module, transformers.Conv1D):
109-
W.transpose_(0, 1)
105+
match module:
106+
case torch.nn.Conv2d():
107+
W = W.flatten(1)
108+
case transformers.Conv1D():
109+
W.transpose_(0, 1)
110110
W = W.to(dtype=GPTQ_PRECISION)
111111
num_rows = W.shape[0]
112112
num_columns = W.shape[1]
@@ -284,7 +284,7 @@ def quantize_weight(
284284

285285
def _apply_activation_ordering(
286286
W: torch.Tensor, H: torch.Tensor
287-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
287+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
288288
"""
289289
Permute weight and hessian in order of greatest outupt activations
290290

0 commit comments

Comments
 (0)