@@ -2,7 +2,11 @@
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from compressed_tensors.utils import align_module_device, update_offload_parameter
+from compressed_tensors.utils import (
+    align_module_device,
+    get_execution_device,
+    update_offload_parameter,
+)
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module
@@ -11,7 +15,6 @@
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
-from llmcompressor.pytorch.utils import tensor_forward_with_input_args
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
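Note on the import change: the removed helper `tensor_forward_with_input_args` is replaced downstream by moving inputs onto the module's execution device and calling `forward` directly (see the `_forward_input_with_kwargs` hunk below). A minimal sketch of that pattern, using a toy module and a simplified stand-in for `get_execution_device` (the real compressed_tensors utility also resolves devices for accelerate-offloaded modules):

```python
import torch
from torch.nn import Linear


def get_execution_device(module: torch.nn.Module) -> torch.device:
    # Simplified stand-in: the compressed_tensors utility additionally
    # handles modules whose weights are offloaded via accelerate hooks.
    return next(module.parameters()).device


module = Linear(8, 8)
inputs = torch.randn(2, 8)

# The pattern this commit adopts: align the inputs with the module's
# execution device, then call the module directly.
inputs = inputs.to(get_execution_device(module))
output = module(inputs)
```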
@@ -217,7 +220,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
         self._resolved_mappings = resolved_mappings
         return
 
-    def _setup_scale_hooks(self):
+    def _setup_scale_hooks(self) -> None:
         """
         Attach a forward hook to each activation we want to smooth. This allows us to
         calculate the dynamic range during calibration
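For context on what `_setup_scale_hooks` wires up: one forward hook per target activation, recording its dynamic range on every calibration pass. A standalone sketch of that mechanism, using max-abs per output channel as a stand-in for the modifier's actual statistic:

```python
import torch

collected = {}


def create_hook_fn(name):
    def hook_fn(module, inp, out):
        # Record a per-channel dynamic-range statistic for this forward pass
        hidden = out[0] if isinstance(out, tuple) else out
        collected.setdefault(name, []).append(hidden.abs().amax(dim=0))

    return hook_fn


layer = torch.nn.Linear(4, 4)
handle = layer.register_forward_hook(create_hook_fn("layer0"))
layer(torch.randn(2, 4))
handle.remove()
```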
@@ -243,7 +246,7 @@ def hook_fn(module, inp, out):
             self.register_hook(layer, create_hook_fn(name), "forward")
 
     @torch.no_grad()
-    def _calibrate(self, model: Module, calibration_dataloader: List):
+    def _calibrate(self, model: Module, calibration_dataloader: List) -> None:
         """
         Catch the output dynamic ranges of each layer that will be smoothed by running
         forward passes with calibration_dataloader
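`_calibrate` itself delegates to `run_calibration_forward`; schematically, calibration is just no-grad forward passes that let the hooks above collect statistics as a side effect. A toy version:

```python
import torch


@torch.no_grad()
def calibrate(model: torch.nn.Module, dataloader) -> None:
    # Forward passes only; the registered hooks record activation
    # ranges as a side effect, and the outputs are discarded.
    model.eval()
    for batch in dataloader:
        model(batch)
```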
@@ -264,7 +267,7 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
             calibration_dataloader,
         )
 
-    def _concat_collected_activations(self):
+    def _concat_collected_activations(self) -> None:
         """
         Concatenate the collected activation values from each forward pass into a single
         tensor for each layer
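The concatenation step is simple in isolation: each hook appends one tensor per forward pass, and afterwards each per-layer list collapses into a single tensor. Illustratively:

```python
import torch

# Two calibration batches worth of per-channel statistics for one layer
scales = {"layer0": [torch.randn(2, 8), torch.randn(2, 8)]}

for name in scales:
    # One tensor per layer, stacked along the batch dimension
    scales[name] = torch.cat(scales[name], dim=0)  # shape (4, 8)
```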
@@ -277,7 +280,7 @@ def _concat_collected_activations(self):
             self._scales[name] = torch.cat(self._scales[name], dim=0)
 
     @torch.no_grad()
-    def _apply_smoothing(self, model: Module):
+    def _apply_smoothing(self, model: Module) -> None:
         """
         Calculate the best scaling factors for each layer to smooth activations and
         apply the scaling factors to the weights of the next layer to offset the
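The core idea behind `_apply_smoothing`, independent of how the best factors are searched for: per-channel scales are folded out of one layer's output and into the next layer's weights, leaving the composed function numerically unchanged while flattening activation outliers. A sketch of that folding for a pair of linear layers (the scale search itself is outside this hunk):

```python
import torch


@torch.no_grad()
def smooth_pair(prev: torch.nn.Linear, nxt: torch.nn.Linear, scales: torch.Tensor) -> None:
    # Divide prev's output channels by the scales and multiply nxt's
    # input channels by the same scales: nxt(prev(x)) is unchanged,
    # but the activations flowing between the two layers shrink.
    prev.weight.div_(scales.view(-1, 1))
    if prev.bias is not None:
        prev.bias.div_(scales)
    nxt.weight.mul_(scales.view(1, -1))


prev, nxt = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
x = torch.randn(2, 8)
before = nxt(prev(x))
smooth_pair(prev, nxt, torch.rand(8) + 0.5)
after = nxt(prev(x))
assert torch.allclose(before, after, atol=1e-5)
```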
@@ -484,7 +487,7 @@ def _compute_loss(
         fp16_output: torch.Tensor,
         int_w_output: torch.Tensor,
         device: torch.device,
-    ):
+    ) -> torch.Tensor:
         loss = 0.0
         fp16_output_flat = fp16_output.view(-1)
         int_w_output_flat = int_w_output.view(-1)
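`_compute_loss` compares the module output at original precision against the output produced with quantized weights; the flattened views above support accumulating the squared error in chunks on `device`. Stripped to its core, the quantity is a mean squared error, as in this sketch:

```python
import torch


def compute_loss(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> torch.Tensor:
    # Mean squared error between the full-precision output and the
    # quantized-weight output; lower means a better smoothing scale.
    diff = fp16_output.view(-1) - int_w_output.view(-1)
    return (diff.float() ** 2).mean()
```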
@@ -579,7 +582,7 @@ def _forward_input_with_kwargs(
         module: Module,
         inputs: torch.Tensor,
         input_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Forward pass with input arguments
 
@@ -590,43 +593,44 @@ def _forward_input_with_kwargs(
         """
         kwargs = input_kwargs or self._module_kwargs
         kwargs = _sanitize_kwargs(kwargs, module)
-        return tensor_forward_with_input_args(
-            module=module,
-            inputs=inputs,
-            input_kwargs=kwargs,
-        )[0]
+
+        inputs = inputs.to(get_execution_device(module))
+
+        return module(inputs, **kwargs)[0]
 
 
-def _sanitize_kwargs(inputs_kwargs, module):
+def _sanitize_kwargs(input_kwargs: Dict[str, Any], module: Module) -> Dict[str, Any]:
     """
-    Remove the arguments that are not supported in the module's
-    forward pass to avoid breaking behaviour between different versions
-    of transformers.
+    Sanitize input keyword arguments to match the module's forward method signature,
+    excluding `use_cache`, which is not desired to be passed into the module.
 
     Args:
         inputs_kwargs (`dict`):
             The input dictionary to pass to the model layer
         module (`torch.nn.Module`):
             Target module to quantize.
     """
+
     params = inspect.signature(module.forward).parameters
-    sanitized_kwargs = {}
-    for k, v in inputs_kwargs.items():
-        if k in params and k != "use_cache":
-            sanitized_kwargs[k] = v
-    # In case forward pass has optional dependencies that don't default to None.
+
+    # Filter out any kwargs not in module.forward signature
+    sanitized_kwargs = {k: v for k, v in input_kwargs.items() if k in params}
+
+    # Edge Case: forward pass has optional dependencies that don't default to None.
     # This is the case for `LlamaAttention.forward` which has input
     # `attention_mask: Optional[torch.Tensor],` (with no `= None` default)
     # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246
     for k, v in params.items():
         if (
             k not in sanitized_kwargs
-            and k != "use_cache"
             and v.default is inspect.Parameter.empty
             and str(v.annotation).startswith("typing.Optional")
         ):
             sanitized_kwargs[k] = None
 
+    # Exclude `use_cache` entirely
+    sanitized_kwargs.pop("use_cache", None)
+
     return sanitized_kwargs
 