from abc import abstractmethod
from collections import defaultdict
from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

import numpy
import torch
@@ -27,24 +27,24 @@ class SparsityModifierBase(Modifier):
    """

    # modifier arguments
-    sparsity: Optional[Union[float, List[float]]]
-    sparsity_profile: Optional[str] = None
+    sparsity: float | list[float] | None
+    sparsity_profile: str | None = None
    mask_structure: str = "0:0"
-    owl_m: Optional[int] = None
-    owl_lmbda: Optional[float] = None
+    owl_m: int | None = None
+    owl_lmbda: float | None = None

    # data pipeline arguments
-    sequential_update: Optional[bool] = False  # deprecated
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str]] = ["Linear"]
-    ignore: List[str] = Field(default_factory=list)
+    sequential_update: bool | None = False  # deprecated
+    sequential_targets: str | list[str] | None = None
+    targets: str | list[str] = ["Linear"]
+    ignore: list[str] = Field(default_factory=list)

    # private variables
-    _prune_n: Optional[int] = PrivateAttr(default=None)
-    _prune_m: Optional[int] = PrivateAttr(default=None)
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
-    _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
-    _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _prune_n: int | None = PrivateAttr(default=None)
+    _prune_m: int | None = PrivateAttr(default=None)
+    _module_names: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _target_layers: dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
+    _module_sparsities: dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

    @field_validator("sequential_update", mode="before")
    def validate_sequential_update(cls, value: bool) -> bool:
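
The typing changes in the hunk above are purely syntactic: PEP 604 unions (float | None) and PEP 585 builtin generics (list[str], dict[...]) denote the same types as the Optional/Union/List/Dict spellings they replace, so the Pydantic field declarations resolve identically. A minimal standalone sketch of the equivalence, assuming Python 3.10+ (not part of the diff):

    from typing import Optional, Union

    # On Python >= 3.10 the two spellings compare equal at runtime
    assert (float | list[float] | None) == Optional[Union[float, list[float]]]
    assert (int | None) == Optional[int]
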
@@ -58,7 +58,7 @@ def validate_sequential_update(cls, value: bool) -> bool:
        return True

    @field_validator("sparsity_profile", mode="before")
-    def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
+    def validate_sparsity_profile(cls, value: str | None) -> bool:
        if value is None:
            return value

@@ -94,7 +94,7 @@ def validate_model_after(model: "SparsityModifierBase") -> "SparsityModifierBase
    def calibrate_module(
        self,
        module: torch.nn.Module,
-        args: Tuple[torch.Tensor, ...],
+        args: tuple[torch.Tensor, ...],
        _output: torch.Tensor,
    ):
        raise NotImplementedError()

@@ -143,12 +143,13 @@ def on_start(self, state: State, event: Event, **kwargs):

        # register hooks
        for index, (layer_name, layer) in enumerate(self._target_layers.items()):
-            if isinstance(self.sparsity, dict):
-                layer_sparsity = self.sparsity[layer_name]
-            elif isinstance(self.sparsity, list):
-                layer_sparsity = self.sparsity[index]
-            else:
-                layer_sparsity = self.sparsity
+            match self.sparsity:
+                case dict():
+                    layer_sparsity = self.sparsity[layer_name]
+                case list():
+                    layer_sparsity = self.sparsity[index]
+                case _:
+                    layer_sparsity = self.sparsity

            for name, module in get_prunable_layers(layer).items():
                name = f"{layer_name}.{name}"
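
The match rewrite above uses class patterns, which select a branch the same way the removed isinstance() chain did: case dict() succeeds when the subject is a dict (or subclass), case list() when it is a list, and case _ is the catch-all. A standalone sketch of the dispatch, with illustrative values:

    def pick_sparsity(sparsity, layer_name: str, index: int):
        match sparsity:
            case dict():   # per-layer mapping keyed by layer name
                return sparsity[layer_name]
            case list():   # one value per layer, selected by position
                return sparsity[index]
            case _:        # a single global value
                return sparsity

    assert pick_sparsity({"model.layers.0": 0.5}, "model.layers.0", 0) == 0.5
    assert pick_sparsity([0.3, 0.7], "model.layers.1", 1) == 0.7
    assert pick_sparsity(0.5, "model.layers.2", 2) == 0.5
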
@@ -191,21 +192,21 @@ def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True
        self.remove_hooks()

-    def _infer_sequential_targets(
-        self, model: torch.nn.Module
-    ) -> Union[str, List[str]]:
-        if self.sequential_targets is None:
-            return get_no_split_params(model)
-        if isinstance(self.sequential_targets, str):
-            return [self.sequential_targets]
-        return self.sequential_targets
+    def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
+        match self.sequential_targets:
+            case None:
+                return get_no_split_params(model)
+            case str():
+                return [self.sequential_targets]
+            case _:
+                return self.sequential_targets

    def _infer_owl_layer_sparsity(
        self,
        model: torch.nn.Module,
-        layers: Dict[str, torch.nn.Module],
+        layers: dict[str, torch.nn.Module],
        dataloader: torch.utils.data.DataLoader,
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
        activations = self._get_activations(model, dataloader)

        groups = {}
@@ -248,12 +249,12 @@ def _infer_owl_layer_sparsity(
            logger.info(f"Sparsity for {k}: {sparsities[k]}")
        return sparsities

-    def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
+    def _get_activations(self, model, dataloader, nsamples=128) -> dict[str, int]:
        from llmcompressor.pipelines.basic import run_calibration

        acts = defaultdict(int)

-        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
+        def save_acts(_module, input: tuple[Any, ...] | torch.Tensor, name: str):
            nonlocal acts
            if isinstance(input, tuple):
                input = input[0]
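
save_acts accepts either the raw args tuple that torch forward pre-hooks receive or a bare tensor, normalizing to the first positional argument. The registration lines fall outside this hunk; a standalone sketch of how such a hook is typically bound with functools.partial (which this file imports), using illustrative names:

    import torch
    from functools import partial

    def save_acts(_module, input, name: str):
        # Forward pre-hooks receive the positional args as a tuple;
        # the first element is the activation tensor.
        if isinstance(input, tuple):
            input = input[0]
        print(name, input.shape)

    linear = torch.nn.Linear(4, 4)
    handle = linear.register_forward_pre_hook(partial(save_acts, name="linear"))
    linear(torch.randn(2, 4))  # prints: linear torch.Size([2, 4])
    handle.remove()
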
@@ -270,6 +271,6 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):

        return acts

-    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
+    def _split_mask_structure(self, mask_structure: str) -> tuple[int, int]:
        n, m = mask_structure.split(":")
        return int(n), int(m)
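
mask_structure encodes N:M semi-structured sparsity as a plain "n:m" string (e.g. "2:4" prunes 2 of every group of 4 weights; the default "0:0" conventionally means unstructured), so the parse is a straight split. A quick standalone check:

    def split_mask_structure(mask_structure: str) -> tuple[int, int]:
        n, m = mask_structure.split(":")
        return int(n), int(m)

    assert split_mask_structure("2:4") == (2, 4)
    assert split_mask_structure("0:0") == (0, 0)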