34 changes: 16 additions & 18 deletions torchao/prototype/smoothquant/api.py
@@ -10,14 +10,12 @@
import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.linear_activation_scale import (
to_weight_tensor_with_linear_activation_scale_metadata,
)
from torchao.quantization.quant_api import (
_QUANTIZE_CONFIG_HANDLER,
_linear_extra_repr,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
)
from torchao.utils import DummyModule
@@ -62,7 +60,7 @@ def _smooth_quant_transform(
config: SmoothQuantConfig,
) -> torch.nn.Module:
step = config.step
base_config = config.base_config
observed_linear = None

if step == SmoothQuantStep.PREPARE:
observer = SmoothQuantObserver(
@@ -71,7 +69,7 @@
)
return SmoothQuantObservedLinear.from_float(module, observer)

if step == SmoothQuantStep.PREPARE_FOR_LOADING:
elif step == SmoothQuantStep.PREPARE_FOR_LOADING:
# loading from pre-quantized checkpoint
observer = SmoothQuantObserver(
weight=module.weight,
@@ -97,9 +95,19 @@

# Compute smoothed weight parameters
smoothing_factor = observed_linear.obs.calculate_qparams()
weight = observed_linear.weight * smoothing_factor
smoothing_factor = torch.clamp(smoothing_factor, min=1e-6)

base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
dummy_mod = DummyModule(observed_linear.weight * smoothing_factor)
quant_mod = base_config_handler(dummy_mod, config.base_config)
qw = quant_mod.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` / `smoothing_factor` during runtime for speed, we'll save the
# reciprocal of the `smoothing_factor`
qw.act_pre_scale = (1.0 / smoothing_factor).to(qw.dtype)

# Create new linear layer
with torch.device("meta"):
linear = torch.nn.Linear(
observed_linear.in_features,
@@ -110,16 +118,6 @@
)
linear.bias = observed_linear.bias

# Quantize weights
base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
dummy_mod = DummyModule(weight)
quant_mod = base_config_handler(dummy_mod, base_config)
qw = quant_mod.weight

# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
qw, smoothing_factor.to(qw.dtype)
)
linear.weight = torch.nn.Parameter(qw, requires_grad=False)
linear.extra_repr = types.MethodType(_linear_extra_repr, linear)

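Note on the reciprocal trick above: scaling the weight by `smoothing_factor` and the activation by its reciprocal leaves the linear output unchanged up to quantization error, which is what makes it safe to store `act_pre_scale = 1 / smoothing_factor` on the quantized weight. A minimal sketch with plain torch tensors (shapes and names here are illustrative, not taken from this PR):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)              # activations: (batch, in_features)
w = torch.randn(16, 8)             # weight: (out_features, in_features)
s = torch.rand(8).clamp(min=1e-6)  # per-channel smoothing factor, clamped as above

smoothed_w = w * s                 # what the base config quantizes
act_pre_scale = 1.0 / s            # what gets stored on the quantized weight

baseline = torch.nn.functional.linear(x, w)
smoothed = torch.nn.functional.linear(x * act_pre_scale, smoothed_w)
assert torch.allclose(baseline, smoothed, atol=1e-4)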
35 changes: 30 additions & 5 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -33,6 +33,7 @@ class LinearActivationQuantizedTensor(TorchAOBaseTensor):
a quantized tensor, this is used to quantize input
`quant_kwargs` (Dict[str, Any]): Additional keyword arguments for the quantization function.
Restriction: Must not contain tensor values.
`act_pre_scale` (Optional[torch.Tensor]): Pre-scaling factor for activation quantization
"""

quant_kwargs: Dict[str, Any]
@@ -42,6 +43,7 @@ def __new__(
original_weight_tensor: torch.Tensor,
input_quant_func: Callable,
quant_kwargs: Dict[str, Any],
act_pre_scale: Optional[torch.Tensor] = None,
):
kwargs = {}
dtype = original_weight_tensor.dtype
@@ -56,31 +58,47 @@ def __init__(
original_weight_tensor: torch.Tensor,
input_quant_func: Callable[[torch.Tensor], torch.Tensor],
quant_kwargs: Dict[str, Any],
act_pre_scale: Optional[torch.Tensor] = None,
):
self.original_weight_tensor = original_weight_tensor
self.input_quant_func = input_quant_func
self.quant_kwargs = quant_kwargs
self.act_pre_scale = act_pre_scale

def __repr__(self):
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}, act_pre_scale={self.act_pre_scale})"

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]
return ["original_weight_tensor"], [
self.input_quant_func,
self.quant_kwargs,
self.act_pre_scale,
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
input_quant_func, quant_kwargs = tensor_attributes
return cls(original_weight_tensor, input_quant_func, quant_kwargs)
input_quant_func, quant_kwargs, act_pre_scale = tensor_attributes
return cls(
original_weight_tensor, input_quant_func, quant_kwargs, act_pre_scale
)

@staticmethod
def _quantized_linear_op(
input_tensor: torch.Tensor, weight_tensor: torch.Tensor, bias: torch.Tensor
):
if input_tensor.numel() == 0:
return input_tensor

# Apply pre-scaling if present
if (
hasattr(weight_tensor, "act_pre_scale")
Contributor comment: this should always be true?

and weight_tensor.act_pre_scale is not None
):
input_tensor = input_tensor * weight_tensor.act_pre_scale

input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
quant_kwargs = weight_tensor.quant_kwargs
@@ -95,16 +113,18 @@ def from_float(
input_float: torch.Tensor,
input_quant_func: Callable,
quant_kwargs: Optional[Dict[str, Any]] = None,
act_pre_scale: Optional[torch.Tensor] = None,
):
if quant_kwargs is None:
quant_kwargs = {}
return cls(input_float, input_quant_func, quant_kwargs)
return cls(input_float, input_quant_func, quant_kwargs, act_pre_scale)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.original_weight_tensor),
self.input_quant_func,
self.quant_kwargs,
self.act_pre_scale,
)

def to(self, *args, **kwargs):
@@ -113,6 +133,7 @@ def to(self, *args, **kwargs):
self.original_weight_tensor.to(**kwargs),
self.input_quant_func,
self.quant_kwargs,
self.act_pre_scale.to(**kwargs) if self.act_pre_scale is not None else None,
)


@@ -238,6 +259,7 @@ def _(func, types, args, kwargs):
func(args[0].original_weight_tensor, *args[1:]),
args[0].input_quant_func,
args[0].quant_kwargs,
args[0].act_pre_scale,
),
)

@@ -252,6 +274,7 @@ def _(func, types, args, kwargs):
func(args[0].original_weight_tensor, *args[1:]),
args[0].input_quant_func,
args[0].quant_kwargs,
args[0].act_pre_scale,
),
)

@@ -266,6 +289,7 @@ def _(func, types, args, kwargs):
func(args[0].original_weight_tensor, *args[1:]),
args[0].input_quant_func,
args[0].quant_kwargs,
args[0].act_pre_scale,
),
)

@@ -281,6 +305,7 @@ def _(func, types, args, kwargs):
func(args[0].original_weight_tensor, *args[1:]),
args[0].input_quant_func,
args[0].quant_kwargs,
args[0].act_pre_scale,
),
)

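For reference, the runtime effect of the new `act_pre_scale` branch in `_quantized_linear_op` is just an elementwise multiply on the activation before `input_quant_func` runs. A rough, self-contained sketch of that control flow (the helper name and the identity "quantization" below are placeholders, not part of the torchao API):

from typing import Callable, Optional

import torch


def linear_with_pre_scale(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    input_quant_func: Callable[[torch.Tensor], torch.Tensor],
    act_pre_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Mirrors the new branch: pre-scale the activation, then quantize it, then matmul.
    if act_pre_scale is not None:
        input_tensor = input_tensor * act_pre_scale
    q_input = input_quant_func(input_tensor)
    return torch.nn.functional.linear(q_input, weight, bias)


# Exercise the code path with an identity quant function and a constant pre-scale.
x = torch.randn(2, 8)
w = torch.randn(4, 8)
out = linear_with_pre_scale(x, w, None, lambda t: t, act_pre_scale=torch.full((8,), 0.5))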