 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):

     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)

@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
             warnings.warn(
                 "A strategy including activation quantization was detected. "
                 "AWQ was originally intended for weight-only quantization. "
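
For context on what this validator accepts, here is an illustrative sketch (not part of this commit) of a hypothetical W4A8-FP8 scheme expressed with compressed-tensors' QuantizationArgs; it shows why num_bits_set becomes {8}, num_bits_type becomes {FLOAT}, and _activation_bits ends up recorded as 8:

from compressed_tensors.quantization import QuantizationArgs, QuantizationType

# Hypothetical config group: 4-bit integer weights, 8-bit float (fp8) activations
weights = QuantizationArgs(num_bits=4, type=QuantizationType.INT, symmetric=True)
input_activations = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)

# Mirrors the set comprehensions in validate_model_after above
num_bits_set = {input_activations.num_bits}   # {8}  -> not weight-only
num_bits_type = {input_activations.type}      # {QuantizationType.FLOAT}
assert next(iter(num_bits_type)) == QuantizationType.FLOAT
activation_bits = next(iter(num_bits_set))    # 8, which later enables the fp8 branch
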
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
+                )

             # W * X
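
The new branch only simulates fp8 activation precision: it round-trips the already pseudo-quantized, rescaled weight through float8 so the grid search over scales also sees that rounding error. A standalone sketch of the round trip (illustrative only; assumes a PyTorch build that provides torch.float8_e4m3fn, i.e. 2.1+):

import torch

w = torch.randn(128, 128, dtype=torch.float16)

# Cast to fp8 (e4m3) and back: the tensor stays fp16, but fp8 rounding error is baked in
w_fp8_sim = w.to(torch.float8_e4m3fn).to(torch.float16)

# This difference is the extra error the AWQ scale search now accounts for
print((w - w_fp8_sim).abs().max())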