197 changes: 138 additions & 59 deletions test/prototype/test_parq.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import tempfile
import unittest
from typing import Optional

@@ -27,9 +28,13 @@
UnifQuantizer,
UnifTorchaoQuantizer,
)
from torchao.prototype.parq.quant.config_torchao import TRANSFORMERS_AVAIL, _is_hf_model
from torchao.prototype.parq.quant.config_torchao import (
TRANSFORMERS_AVAIL,
_attach_hf_quantization_config,
_is_hf_model,
)
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
from torchao.quantization.quant_api import (
Int4WeightOnlyConfig,
@@ -50,6 +55,84 @@
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class M(nn.Module):
_tied_weights_keys: list[str] = []

def __init__(
self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
):
nn.Module.__init__(self)
self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
self.linear1 = nn.Linear(m, n, bias=bias)
self.linear2 = nn.Linear(n, k, bias=bias)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

if embedding and tied_weights:
assert self.embed_tokens.weight.shape == self.linear2.weight.shape
self.tie_weights()
self._tied_weights_keys.append("linear2.weight")

def tie_weights(self):
self.linear2.weight = self.embed_tokens.weight

def example_inputs(self, device=None):
if isinstance(self.embed_tokens, nn.Identity):
inputs = torch.randn(1, self.linear1.in_features, device=device)
else:
k = self.embed_tokens.num_embeddings
inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
return inputs

def forward(self, x):
x = self.embed_tokens(x)
x = self.relu(self.linear1(x))
x = self.sigmoid(self.linear2(x))
return x


if TRANSFORMERS_AVAIL:
from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig

class MConfig(PretrainedConfig):
def __init__(
self,
m=256,
n=128,
k=16,
bias=False,
embedding=True,
tied_weights=False,
**kwargs,
):
super().__init__(**kwargs)
self.m = m
self.n = n
self.k = k
self.bias = bias
self.embedding = embedding
self.tied_weights = tied_weights

class PreTrainedM(M, PreTrainedModel):
jerryzh168 (Contributor) commented on Sep 22, 2025:

Oh, I meant just using some specific model defined in transformers and going through the public APIs. Just making sure: would the tests work for existing models in transformers?

Author (Contributor) replied:

Yes, existing transformers models also inherit from PreTrainedModel, so AutoModelForCausalLM.from_pretrained(..., quantization_config=quantization_config) can be tested in the same way.
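For reference, a minimal sketch of that flow. The model id, bit width, and group size below are illustrative placeholders, not part of this PR, and it assumes a transformers release whose TorchAoConfig accepts a torchao AOBaseConfig instance as quant_type:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig

# Hypothetical usage: quantize an existing transformers model through the
# public from_pretrained API; all values here are placeholder assumptions.
quantization_config = TorchAoConfig(
    quant_type=IntxWeightOnlyConfig(
        weight_dtype=torch.int4, granularity=PerGroup(32)
    )
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id
    quantization_config=quantization_config,
)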

base_model_prefix = "base"
config_class = MConfig

def __init__(self, config: MConfig):
PreTrainedModel.__init__(self, config)
M.__init__(
self,
m=config.m,
n=config.n,
k=config.k,
bias=config.bias,
embedding=config.embedding,
tied_weights=config.tied_weights,
)

def get_input_embeddings(self) -> nn.Module:
return self.embed_tokens


def split_param_groups(model) -> tuple[list, list, list]:
params_quant, params_embed, params_no_quant = [], [], []

@@ -191,49 +274,9 @@ def apply_activation_quantization(
pass


class M(nn.Module):
_tied_weights_keys: list[str] = []

def __init__(
self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
):
super().__init__()
self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
self.linear1 = nn.Linear(m, n, bias=bias)
self.linear2 = nn.Linear(n, k, bias=bias)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()

if embedding and tied_weights:
assert self.embedding.weight.shape == self.linear2.weight.shape
self.linear2.weight = self.embedding.weight
self._tied_weights_keys.append("linear2.weight")

def reset_parameters(self):
for module in (self.linear1, self.linear2):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)

def example_inputs(self, device=None):
if isinstance(self.embedding, nn.Identity):
inputs = torch.randn(1, self.linear1.in_features, device=device)
else:
k = self.embedding.num_embeddings
inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
return inputs

def forward(self, x):
x = self.embedding(x)
x = self.relu(self.linear1(x))
x = self.sigmoid(self.linear2(x))
return x


class TestPARQuantization(common_utils.TestCase):
def setUp(self):
torch.manual_seed(123)
self.model = M(bias=True).to(_DEVICE)

@common_utils.parametrize("b", [0, 1, 2, 4])
@common_utils.parametrize("unif_quant", [True, False])
@@ -242,13 +285,13 @@ def setUp(self):
def test_parq_train_loop(
self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
):
self.model.reset_parameters()
model = M(bias=True).to(_DEVICE)
if unif_quant:
quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
else:
quantizer = LSBQuantizer()
param_groups = build_param_groups(
self.model, b, quantizer=quantizer if per_group_quantizer else None
model, b, quantizer=quantizer if per_group_quantizer else None
)
base_optimizer = torch.optim.AdamW(param_groups)

@@ -257,12 +300,12 @@ def test_parq_train_loop(
)
optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
for _ in range(3):
x = self.model.example_inputs(device=_DEVICE)
out = self.model(x)
x = model.example_inputs(device=_DEVICE)
out = model(x)
out.sum().backward()
optimizer.step()

for child in self.model.children():
for child in model.children():
if isinstance(child, nn.Linear):
self.assertEqual(
child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -281,7 +324,6 @@ def setUp(self):
@common_utils.parametrize("group_size", [32, 256])
def test_int4_weight_only(self, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
model.reset_parameters()

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
config = Int4WeightOnlyConfig(group_size=group_size)
@@ -299,7 +341,6 @@ def test_int4_weight_only(self, group_size: int = 32):
@common_utils.parametrize("group_size", [32, 512])
def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
quantize_(
@@ -319,7 +360,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
)
def test_int4_weight_only_e2e(self, group_size: int = 32):
model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
model.reset_parameters()

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
config = Int4WeightOnlyConfig(group_size=group_size)
@@ -339,7 +379,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
@common_utils.parametrize("b", [2, 3, 4, 8])
def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512, embedding=False).to(_DEVICE)
model.reset_parameters()

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
config = IntxWeightOnlyConfig(
@@ -366,7 +405,6 @@ def setUp(self):
@common_utils.parametrize("group_size", [32, 256])
def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

quantizer_ref = UnifQuantizer()
quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -389,7 +427,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
@common_utils.parametrize("group_size", [32, 512])
def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -411,7 +448,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
@common_utils.parametrize("b", [2, 3])
def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512, embedding=False).to(_DEVICE)
model.reset_parameters()

quantizer = StretchedUnifTorchaoQuantizer(b)

@@ -456,14 +492,16 @@ def test_intx_weight_only_tied_embed_linear(
optimizer.torchao_convert(model)
check_torchao_tensor_subclass(self, model)
self.assertTrue(
torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
)


class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
def setUp(self):
torch.manual_seed(123)

@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
@common_utils.parametrize("b", [2, 3, 4, 8])
@common_utils.parametrize(
"model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -475,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e(
model_dtype: torch.dtype = torch.float32,
group_size: int = 32,
):
model = M(embedding=False, bias=True).to(_DEVICE, dtype=model_dtype)
config = MConfig(embedding=False, bias=True)
model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype)
x = model.example_inputs(device=_DEVICE).to(model_dtype)

# reference model using native quantization
@@ -506,9 +545,6 @@ def test_int8_dynamic_activation_intx_e2e(

attach_hf_config = False
if TRANSFORMERS_AVAIL:
from transformers import PretrainedConfig

model.config = PretrainedConfig() # pretend this is a HF model
attach_hf_config = _is_hf_model(model)
self.assertTrue(attach_hf_config)

@@ -530,6 +566,49 @@ def test_int8_dynamic_activation_intx_e2e(
self.assertTrue(isinstance(torchao_config, config.__class__))


class TestTorchAoConfigIntegration(common_utils.TestCase):
@unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported")
@unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
def test_tied_weights_quantization(self, b: int = 4):
config = MConfig(m=128, n=128, tied_weights=True)
model = PreTrainedM(config).to(_DEVICE)

quantizer = StretchedUnifTorchaoQuantizer(b)
linear_config = StretchedIntxWeightConfig(
b=b,
quant_min=quantizer.quant_min,
quant_max=quantizer.quant_max,
granularity=PerAxis(0),
)
embed_config = IntxWeightOnlyConfig(
weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32)
)
module_to_config = {"_default": linear_config}
configs = [embed_config]
filter_fns = [lambda m: isinstance(m, nn.Embedding)]
_attach_hf_quantization_config(model, filter_fns, configs, module_to_config)

quantization_config = getattr(model.config, "quantization_config", None)
self.assertTrue(isinstance(quantization_config, TorchAoConfig))
self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])

# Let HF apply quantize_ given quantization_config
del model.config.quantization_config
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=False)
model = PreTrainedM.from_pretrained(
tmp_dir, quantization_config=quantization_config
)

check_torchao_tensor_subclass(self, model.linear1)
check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)

self.assertTrue(
model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
)


common_utils.instantiate_parametrized_tests(TestPARQuantization)
common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)
7 changes: 3 additions & 4 deletions torchao/prototype/parq/quant/config_torchao.py
@@ -187,17 +187,16 @@ def _attach_hf_quantization_config(
if module_to_config is None:
module_to_config = {}

seen_data_ptrs = set()
tied_weights_keys = set(getattr(model, "_tied_weights_keys", []))
modules_to_not_convert = []
for name, module in model.named_modules():
if not hasattr(module, "weight"):
continue

data_ptr = module.weight.data_ptr()
if data_ptr in seen_data_ptrs: # do not re-quantize tied weight
# Do not quantize tied weights or normalization layers
if f"{name}.weight" in tied_weights_keys or "norm" in name:
modules_to_not_convert.append(name)
continue
seen_data_ptrs.add(data_ptr)

for i, filter_fn in enumerate(filter_fns):
if filter_fn(module):