
Commit 75fc571

[BE] Convert quantization internal methods private (#2568)
1 parent f5b5567 commit 75fc571

File tree

10 files changed (+56, -61 lines)
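Downstream code that imported these helpers under their old public names only needs an underscore-prefixed import after this change. A minimal compatibility sketch (the try/except fallback is illustrative, not part of the diff; both names are taken from the hunks below):

```python
# Illustrative only: keep a caller working across torchao versions before and
# after this rename.
try:
    # New (private) name introduced by this commit
    from torchao.quantization.utils import (
        _quantize_activation_per_token_absmax as quantize_activation_per_token_absmax,
    )
except ImportError:
    # Old public name, for torchao versions predating this commit
    from torchao.quantization.utils import quantize_activation_per_token_absmax
```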


test/integration/test_integration.py

Lines changed: 6 additions & 6 deletions
@@ -68,11 +68,11 @@
     LoggingTensorMode,
     _apply_logging_hook,
     _fqn_to_op_to_shape_to_count,
+    _quant_int8_dynamic_per_token_linear,
+    _quantize_activation_per_token_absmax,
     compute_error,
     dequantize_per_channel,
     dynamically_quantize_per_channel,
-    quant_int8_dynamic_per_token_linear,
-    quantize_activation_per_token_absmax,
 )
 from torchao.quantization.utils import (
     compute_error as SQNR,
@@ -557,7 +557,7 @@ def test_dynamic_quant_per_channel_numerics_cuda(self):

     def _test_quantize_per_token_impl(self, device, dtype):
         x = torch.randn(3, 3, 3, device=device, dtype=dtype)
-        xq, scales = quantize_activation_per_token_absmax(x)
+        xq, scales = _quantize_activation_per_token_absmax(x)
         block_size = (1, 1, 3)
         x_dq = dequantize_affine(
             xq, block_size, scales, None, torch.int8, output_dtype=x.dtype
@@ -581,7 +581,7 @@ def _test_per_token_linear_impl(self, device, dtype):
         # Note: need to make the weight contiguous because we are
         # testing in eager mode and cuBlas will not give correct results
         # for a transposed weight
-        y = quant_int8_dynamic_per_token_linear(
+        y = _quant_int8_dynamic_per_token_linear(
             x, wq.t().contiguous(), w_scales, None, dtype
         )
         y_ref = torch.matmul(x, w.t())
@@ -1679,9 +1679,9 @@ def forward(self, x):
         assert not isinstance(mod.mha.out_proj.weight, AutoQuantizableLinearWeight)
         assert isinstance(mod.lin.weight, AutoQuantizableLinearWeight)
         mod(*input)
-        from torchao.quantization.autoquant import AUTOQUANT_CACHE
+        from torchao.quantization.autoquant import _AUTOQUANT_CACHE

-        assert len(AUTOQUANT_CACHE) > 0
+        assert len(_AUTOQUANT_CACHE) > 0

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")

test/quantization/test_quant_primitives.py

Lines changed: 6 additions & 6 deletions
@@ -23,10 +23,10 @@

 # TODO: remove test for utils?
 from torchao.quantization.utils import (
+    _quantize_activation_per_token_absmax,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
-    quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
@@ -352,7 +352,7 @@ def test_choose_qparams_tensor_sym(self):
     )
     def test_quantize_activation_per_token_abs_max(self):
         input = torch.randn(10, 10)
-        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+        quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)

         mapping_type = MappingType.SYMMETRIC
         block_size = list(input.shape)
@@ -386,22 +386,22 @@ def test_quantize_activation_per_token_abs_max(self):
     def test_quantize_activation_per_token_abs_max_zero_input(self):
         input = torch.zeros(10, 10)
         # make sure it still works
-        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+        quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)

     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
     def test_quantize_activation_per_token_abs_max_dtype(self):
         input = torch.zeros(10, 10, dtype=torch.bfloat16)
-        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+        quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
         self.assertTrue(scale_ref.dtype, torch.bfloat16)

         input = torch.zeros(10, 10, dtype=torch.float32)
-        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+        quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
         self.assertTrue(scale_ref.dtype, torch.float32)

         input = torch.zeros(10, 10, dtype=torch.float16)
-        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
+        quantized_ref, scale_ref = _quantize_activation_per_token_absmax(input)
         self.assertTrue(scale_ref.dtype, torch.float32)

     @unittest.skipIf(

torchao/_models/llama/model.py

Lines changed: 3 additions & 3 deletions
@@ -192,7 +192,7 @@ def update(self, input_pos, k_val, v_val):
         return k_out, v_out


-from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.quantization.utils import _quantize_activation_per_token_absmax


 class AffineQuantizedKVCache(nn.Module):
@@ -218,13 +218,13 @@ def __init__(

     def update(self, input_pos, k_val, v_val):
         # quantize current k_val and store it in the cache
-        q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
+        q_k_val, k_scale = _quantize_activation_per_token_absmax(k_val)
         self.k_cache[:, :, input_pos] = q_k_val
         self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
         k_out = self.k_cache * self.k_cache_scale
         k_out[:, :, input_pos] = k_val

-        q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
+        q_v_val, v_scale = _quantize_activation_per_token_absmax(v_val)
         self.v_cache[:, :, input_pos] = q_v_val
         self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
         v_out = self.v_cache * self.v_cache_scale
torchao/prototype/quantization/autoquant_v2.py

Lines changed: 6 additions & 6 deletions
@@ -45,7 +45,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
-from torchao.quantization.utils import quantize_activation_per_token_absmax
+from torchao.quantization.utils import _quantize_activation_per_token_absmax
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
@@ -110,7 +110,7 @@ def _graph_equals(g1, g2):

 aten = torch.ops.aten

-AUTOQUANT_CACHE = {}
+_AUTOQUANT_CACHE = {}

 # This is a flag to control whether we do some rewrite for graph
 # to account for different batch sizes, it's a temporary solution for llama model
@@ -119,15 +119,15 @@ def _graph_equals(g1, g2):


 def check_cache(gm, cls, shapes_and_dtype):
-    for gm_, cls_, shapes_and_dtype_ in AUTOQUANT_CACHE.keys():
+    for gm_, cls_, shapes_and_dtype_ in _AUTOQUANT_CACHE.keys():
         graph_equals = _graph_equals(gm_.graph, gm.graph)
         if graph_equals and cls_ is cls and shapes_and_dtype_ == shapes_and_dtype:
-            return AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)]
+            return _AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)]
     return None


 def update_cache(gm, cls, shapes_and_dtype, res):
-    AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res
+    _AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res


 # adjust each input's bsz to target_bsz
@@ -638,7 +638,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         # SAM best is between .8 and 1, SDXL also performs best in this range
         INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
-        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+        x_vals_int8, x_scales = _quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
         )
         quantized_matmul = (

torchao/quantization/README.md

Lines changed: 5 additions & 5 deletions
@@ -101,21 +101,21 @@ When used as in the example above, when the `autoquant` api is called alongside

 When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.

-Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
+Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization._AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.

 ```python
 import pickle
 import torchao.quantization

 # After the first forward pass (when quantization was done)
-from torchao.quantization.autoquant import AUTOQUANT_CACHE
+from torchao.quantization.autoquant import _AUTOQUANT_CACHE
 with open("quantization-cache.pkl", "wb") as f:
-    pickle.dump(AUTOQUANT_CACHE, f)
+    pickle.dump(_AUTOQUANT_CACHE, f)

 # On load
-from torchao.quantization.autoquant import AUTOQUANT_CACHE
+from torchao.quantization.autoquant import _AUTOQUANT_CACHE
 with open("quantization-cache.pkl", "rb") as f:
-    AUTOQUANT_CACHE.update(pickle.load(f))
+    _AUTOQUANT_CACHE.update(pickle.load(f))
 ```

 ## Quantization Techniques
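For context on the README snippet above: the cache is only populated after `autoquant` has benchmarked at least one forward pass. A minimal sketch of such a run, assuming a CUDA build of torchao; the toy model, shapes, and `max-autotune` mode are illustrative, not taken from this commit:

```python
import torch
import torchao

# Hypothetical toy model and input; any model with nn.Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Wrap with autoquant, then run one forward pass so benchmarking happens and
# _AUTOQUANT_CACHE is filled, after which it can be pickled as shown above.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(example_input)
```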

torchao/quantization/autoquant.py

Lines changed: 14 additions & 14 deletions
@@ -27,8 +27,8 @@
     ZeroPointDomain,
 )
 from torchao.quantization.utils import (
+    _quantize_activation_per_token_absmax,
     compute_error,
-    quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
@@ -63,15 +63,15 @@

 aten = torch.ops.aten

-AUTOQUANT_CACHE = {}
+_AUTOQUANT_CACHE = {}


-def check_cache(cls, shapes_and_dtype):
-    return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)
+def _check_cache(cls, shapes_and_dtype):
+    return _AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)


-def update_cache(cls, shapes_and_dtype, res):
-    AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res
+def _update_cache(cls, shapes_and_dtype, res):
+    _AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res


 # TODO: Document the methods
@@ -145,12 +145,12 @@ def log_shape(act_mat, w_autoquant, bias):
             shapes_and_dtype, 0
         )
         for q_cls in w_autoquant.qtensor_class_list:
-            if check_cache(q_cls, shapes_and_dtype) is None:
-                update_cache(q_cls, shapes_and_dtype, None)
+            if _check_cache(q_cls, shapes_and_dtype) is None:
+                _update_cache(q_cls, shapes_and_dtype, None)

     def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
         act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
-        if check_cache(q_cls, shapes_and_dtype) is None:
+        if _check_cache(q_cls, shapes_and_dtype) is None:
             with torch.no_grad():
                 act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
                 bias = (
@@ -183,7 +183,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
                     f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}"
                 )
                 res = torch.inf
-        update_cache(q_cls, shapes_and_dtype, res)
+        _update_cache(q_cls, shapes_and_dtype, res)

     @torch.no_grad()
     def to_quantized(self, error_on_unseen, **kwargs):
@@ -223,13 +223,13 @@ def count_shapes(self, do_print=True):
                 total_seen = 0
                 shape_count = count_shapes(self, do_print=False)
                 for shapes_and_dtype, times_seen in self.logged_data.items():
-                    if check_cache(q_cls, shapes_and_dtype) is None:
+                    if _check_cache(q_cls, shapes_and_dtype) is None:
                         # only print shapes once
                         if print_shape_once:
                             print_shape_once = False
                             count_shapes(self, do_print=True)

-                        time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
+                        time_for_best_shape = _check_cache(best_cls, shapes_and_dtype)
                         time_for_best_shape = (
                             torch.inf
                             if time_for_best_shape is None
@@ -238,7 +238,7 @@ def count_shapes(self, do_print=True):
                         self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
                         ran_new_benchmarks = True
                         torch._dynamo.reset()
-                    cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
+                    cur_time += _check_cache(q_cls, shapes_and_dtype) * times_seen
                     total_seen += times_seen
                 cur_time = cur_time / total_seen
                 # print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done
@@ -498,7 +498,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         # SAM best is between .8 and 1, SDXL also performs best in this range
         INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
-        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+        x_vals_int8, x_scales = _quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
         )
         quantized_matmul = (
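As the hunks above show, the renamed helpers key the cache by `(cls,) + shapes_and_dtype`, where `shapes_and_dtype` is an (activation shape, weight shape, bias shape, activation dtype) tuple. A small illustration of that layout; the dummy class, shapes, and timing value are made up and not part of the commit:

```python
import torch

from torchao.quantization.autoquant import _AUTOQUANT_CACHE, _check_cache, _update_cache


class DummyAQClass:  # hypothetical stand-in for an autoquant tensor subclass
    pass


# Made-up (activation shape, weight shape, bias shape, activation dtype) tuple.
shapes_and_dtype = ((16, 1024), (4096, 1024), None, torch.bfloat16)

_update_cache(DummyAQClass, shapes_and_dtype, 0.123)          # store a benchmark time
assert _check_cache(DummyAQClass, shapes_and_dtype) == 0.123  # same key retrieves it
assert (DummyAQClass,) + shapes_and_dtype in _AUTOQUANT_CACHE  # key is (cls,) + shapes_and_dtype
```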

torchao/quantization/dynamic_quant.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 import torch.nn as nn

 from .utils import (
+    _quant_int8_dynamic_per_token_linear,
     dynamically_quantize_per_channel,
-    quant_int8_dynamic_per_token_linear,
 )

 __all__ = ["DynamicallyPerAxisQuantizedLinear"]
@@ -44,7 +44,7 @@ def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:

         """

-        Y = quant_int8_dynamic_per_token_linear(
+        Y = _quant_int8_dynamic_per_token_linear(
             X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype
         )
         return Y

torchao/quantization/smoothquant.py

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@
 import torch.nn.functional as F

 from .utils import (
+    _quant_int8_dynamic_per_token_linear,
     dynamically_quantize_per_channel,
-    quant_int8_dynamic_per_token_linear,
 )

 __all__ = [
@@ -152,7 +152,7 @@ def forward(self, X, *args, **kwargs):
         W_int_repr_t = (
             self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t()
         )
-        Y = quant_int8_dynamic_per_token_linear(
+        Y = _quant_int8_dynamic_per_token_linear(
             X, W_int_repr_t, self.W_scales, self.bias, X.dtype
         )
         return Y

torchao/quantization/subclass.py

Lines changed: 2 additions & 2 deletions
@@ -9,10 +9,10 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing

 from torchao.quantization.utils import (
+    _quant_int8_dynamic_per_token_linear,
     dequantize_per_channel,
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
-    quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
 )
 from torchao.utils import (
@@ -244,7 +244,7 @@ def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):

     @staticmethod
     def _quantized_op(act_mat, w_qtensor, bias):
-        return quant_int8_dynamic_per_token_linear(
+        return _quant_int8_dynamic_per_token_linear(
             act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype
         )
