
Conversation

namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Sep 16, 2025

Summary

Similar to AWQ (#2753), SmoothQuant now uses a direct protocol implementation instead of the external wrapper `to_weight_tensor_with_linear_activation_scale_metadata`.

This PR was inspired by @jerryzh168's suggestion in #2728 (comment).

  • Add an `act_pre_scale` attribute to `LinearActivationQuantizedTensor` (see the sketch below)
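
A minimal sketch of how the new attribute might be consumed in the dispatch path (illustrative only; `original_weight_tensor` and `input_quant_func` refer to existing `LinearActivationQuantizedTensor` fields, and the rest is an assumption, not the exact implementation in this PR):

```python
# Illustrative sketch, not the actual torchao dispatch code.
import torch

def _linear_with_act_pre_scale(input_tensor, weight_tensor, bias=None):
    # SmoothQuant folds its smoothing factor into the activation
    # before the activation is quantized.
    act_pre_scale = getattr(weight_tensor, "act_pre_scale", None)
    if act_pre_scale is not None:
        input_tensor = input_tensor * act_pre_scale
    # Quantize the (pre-scaled) activation, then run the linear op against
    # the wrapped original weight, mirroring the existing
    # LinearActivationQuantizedTensor flow.
    quantized_input = weight_tensor.input_quant_func(input_tensor)
    return torch.nn.functional.linear(
        quantized_input, weight_tensor.original_weight_tensor, bias
    )
```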

Test Plan

```bash
python torchao/prototype/smoothquant/example.py
python test/prototype/test_smoothquant.py
python test/integration/test_integration.py
```

Update SmoothQuant to use SupportsActivationPreScaling protocol instead of wrapper

Similar to AWQ (pytorch#2753), SmoothQuant
now uses a direct protocol implementation instead of
the `to_weight_tensor_with_linear_activation_scale_metadata` wrapper.

Key changes:
- Add `act_pre_scale` attribute to `LinearActivationQuantizedTensor`
- Apply pre-scaling in all dispatch methods
- Remove external wrapper dependency

Test Plan:
```bash
python torchao/prototype/smoothquant/example.py
python test/prototype/test_smoothquant.py
python test/integration/test_integration.py
```

pytorch-bot bot commented Sep 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3012

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label on Sep 16, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@namgyu-youn
Contributor Author

namgyu-youn commented Sep 16, 2025

Result of `test/integration/test_integration.py`:

```
----------------------------------------------------------------------
Ran 358 tests in 477.933s

OK (skipped=205)
```

@namgyu-youn namgyu-youn changed the title from "Update SmoothQuant to use SupportsActivationPreScaling protocol instead of external wrapper" to "Update SmoothQuant to use subtensor instead of external wrapper" on Sep 16, 2025
@jerryzh168
Contributor

@namgyu-youn thanks, I think you probably want to migrate the Int8 tensor first, can you help with that?

#2752 the plain_layout for int8_act_int8_weight

```python
def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_int8(weight_tensor)
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, PlainLayout)
    )


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    #
    # 2. rescale the output
    #
    # in cases with large matrices, y_dot_int32 can grow sufficiently
    # large that y_dot_int32 * a float16 scale is greater than the maximum
    # value of a float 16, (which results in a value of inf even if multiplying
    # by the other scale would bring it within the expected range)
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    x_scales_dtype = x_scales.dtype
    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    )
    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    )
    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y
```
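
For intuition, here is a tiny numeric sketch of the rescale above, written with plain torch ops rather than torchao's `int_scaled_matmul` (shapes, scales, and the float32 intermediate are illustrative assumptions):

```python
# Illustrative only: emulate the int8-act / int8-weight matmul + rescale.
import torch

x_int8 = torch.randint(-128, 128, (4, 16), dtype=torch.int8)  # quantized activations
w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)  # quantized weights
x_scale = torch.rand(4, 1, dtype=torch.float16)               # per-row activation scales
w_scale = torch.rand(8, dtype=torch.float16)                  # per-channel weight scales

# The real kernel accumulates dot(X_i, W_j) in int32; float32 stands in here
# and is also the intermediate dtype that avoids the fp16 overflow noted above.
y_dot = x_int8.to(torch.float32) @ w_int8.to(torch.float32).t()
y = y_dot * x_scale.to(torch.float32) * w_scale.to(torch.float32)
y = y.to(torch.float16)  # downcast only at the very end
```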

@namgyu-youn
Contributor Author

> @namgyu-youn thanks, I think you probably want to migrate the Int8 tensor first, can you help with that?
>
> #2752 the plain_layout for int8_act_int8_weight
>
> *(quoted `_linear_int8_act_int8_weight_check` / `_linear_int8_act_int8_weight_impl` snippet, same as above)*

I am definitely interested, since affine transformations are still hard for me. I will look into it; thanks for the suggestion.


Review comment on this diff hunk:

```python
# Apply pre-scaling if present
if (
    hasattr(weight_tensor, "act_pre_scale")
```
this should always be true?

@jerryzh168
Contributor

jerryzh168 commented Sep 18, 2025

Oh, actually this is not ready for review; we need to update the int8 tensor instead. Converting this to draft now.

@jerryzh168 jerryzh168 marked this pull request as draft September 18, 2025 16:29