Update coreml codebook #2648

Merged
merged 4 commits on Aug 4, 2025
4 changes: 1 addition & 3 deletions test/prototype/test_codebook_coreml.py
@@ -14,7 +14,6 @@
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_no_cuda
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least


@@ -36,7 +35,7 @@ def test_choose_qparams_codebook(self):
self.block_size,
)
group_size = self.block_size[-1]
self.assertEqual(codebook.shape, (256 // group_size, 2**self.nbits, 1))
self.assertEqual(codebook.shape, (1, 256 // group_size, 2**self.nbits, 1))
self.assertEqual(wq.shape, (100, 256))

self.assertFalse(torch.isnan(codebook).any())
@@ -76,7 +75,6 @@ def test_quantize_api(self):
)
assert type(m[0].weight) == CodebookQuantizedTensor

@skip_if_no_cuda()
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.")
def test_export(self):
m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)
3 changes: 1 addition & 2 deletions torchao/prototype/quantization/codebook_coreml/api.py
@@ -42,13 +42,12 @@ def _codebook_weight_only_transform(
raise ImportError("Requires coremltools >= 8.3.0")

dtype = config.dtype
block_size = config.block_size
weight = module.weight

quantized_weight = CodebookQuantizedTensor.from_float(
weight,
dtype,
block_size,
config.block_size,
)
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module
155 changes: 98 additions & 57 deletions torchao/prototype/quantization/codebook_coreml/codebook_ops.py
@@ -34,11 +34,12 @@ def choose_qparams_and_quantize_codebook_coreml(
Args:
input_tensor (torch.Tensor): The input tensor to be quantized.
code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
block_size (List[int]): the size for how many elements of last dimension of input_tensor
belong to the same group and should share the same lookup table. let's say original
shape is (N, K), and block_size of (N, group_size) or (-1, group_size),
then the slice of (N, group_size) elements should use the same lookup
table, and there will be (K // group_size) lookup tables
block_size (List[int]): how many elements along each dimension share the same lookup table
(len(block_size) == input_tensor.dim()).
Each dimension of input_tensor must be divisible by the corresponding entry of block_size.
Lookup tables are indexed by (d_i // block_size[i]) for each dimension i.
For example, if the input tensor has shape (N, K) and block_size is (N, group_size), there is
one lookup table per group_size columns, i.e. (K // group_size) lookup tables in total
force_kmeans1d (bool): Use kmeans1d regardless of number of weights
cluster_dim (int): this means the size of the vector for vector lookup table quantization
e.g. when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
@@ -48,43 +49,45 @@

Returns:
Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8)
The LUT has shape (g0, ..., g(N-1), 2**nbits, vec_dim), where:
* The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i])
* Dimension N + 1 indexes over the 2 ** nbits codebook entries
* Dimension N + 2 holds the lookup values (size 1 for scalar quantization)
"""
assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
assert len(block_size) == input_tensor.ndim
nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"

assert len(block_size) == input_tensor.dim()
block_size = block_size.copy()
for i in range(input_tensor.ndim - 1):
assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], (
f"{block_size} not supported"
for i in range(len(block_size)):
if block_size[i] == -1:
block_size[i] = input_tensor.shape[i]
assert block_size[i] >= 1 and input_tensor.shape[i] % block_size[i] == 0, (
"block_size[i] must divide input_tensor.shape[i]"
)

group_size = block_size[-1]
if group_size == -1:
group_size = input_tensor.shape[-1]

assert input_tensor.shape[-1] % group_size == 0
assert input_tensor.ndim == 2
assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
assert block_size[0] == input_tensor.shape[0], (
"Currently only support per-grouped channel granularity"
)
assert cluster_dim == 1, (
f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
)

num_lut = input_tensor.shape[1] // block_size[1]
group_size = block_size[1]

# for converting to numpy
input_tensor = input_tensor.detach()
# (N, K)
original_shape = input_tensor.shape
# (K // group_size)
num_lut = input_tensor.shape[1] // group_size

# reshape to (N, K // group_size, group_size)
input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
from coremltools.models.neural_network.quantization_utils import (
_get_kmeans_lookup_table_and_weight,
)

nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
if nbits > 8:
print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size")
nbits = 8

res_lut = []
# each res_w[:, i, :] will use the same lookup table
# res_w: (N, K // group_size, group_size)
@@ -102,6 +105,13 @@
# res_lut: (K // group_size, 2 ** nbits)
res_lut = torch.stack(res_lut, dim=0)

# The final LUT has input_tensor.dim() + 2 dimensions:
# the first input_tensor.dim() dimensions index over the tables,
# the next dimension indexes over the 2 ** nbits entries,
# and the last dimension holds the lookup values (size 1 for scalar)
# res_lut: (1, K // group_size, 2 ** nbits, 1)
res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)

# reshape back to (N, K)
res_w = res_w.reshape(*original_shape)

@@ -112,7 +122,7 @@
def dequantize_codebook(
codes: torch.Tensor,
codebook: torch.Tensor,
code_dtype: torch.dtype,
nbits: int,
block_size: List[int],
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
@@ -121,13 +131,14 @@

Args:
codes (torch.Tensor): Indices of codebook entries for each element
shape (N, K) for scalar quantization
codebook (torch.Tensor): Codebook tensor used for quantization,
shape (K // group_size, 2 ** nbits) where K is the dim 1 shape of input
code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8]
Note that codes are stored in torch.uint8; this is just additional information for the dequantize op
block_size (List[int]): a slice of elements with shape block_size will share the same lookup table
only support (-1, ..., group_size) right now (all preceding dimensions has to match input)
General shape: (d0, d1, d2, ..., dN)
Simple example shape: (N, K)
codebook (torch.Tensor): Codebook tensor used for quantization
General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar look up values
Simple example shape: (1, K // group_size, 2 ** nbits, 1) for scalar look up values, with one table per group_size columns
nbits (int): number of bits for the quantization
block_size (List[int]): a slice of elements with shape block_size will share the same lookup table.
If block_size[i] == -1, then the entire dimension is used.
output_dtype (torch.dtype): dtype for the output tensor.

Returns:
@@ -140,37 +151,67 @@
torch.bfloat16,
], f"Unsupported output dtype: {output_dtype}"

assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"

assert len(block_size) == codes.ndim
assert len(block_size) == codes.dim()
block_size = block_size.copy()
for i in range(codes.ndim - 1):
assert block_size[i] == -1 or block_size[i] == codes.shape[i], (
f"{block_size} not supported"
for i in range(len(block_size)):
if block_size[i] == -1:
block_size[i] = codes.shape[i]
assert block_size[i] >= 1 and codes.shape[i] % block_size[i] == 0, (
"block_size[i] must divide codes.shape[i]"
)

group_size = block_size[-1]
if group_size == -1:
group_size = codes.shape[-1]
assert codebook.dim() == codes.dim() + 2
codebook_shape = codebook.shape
vec_dim = codebook_shape[-1]
quant_levels = 2**nbits

assert codes.shape[-1] % group_size == 0
K = codes.shape[-1]
num_lut = K // group_size
# (N, K)
original_shape = codes.shape
# Check that last two dimensions of codebook are [quant_levels, vec_dim]
assert codebook_shape[-2] == quant_levels, "Codebook shape mismatch with nbits"

# reshape to (N, num_lut, group_size)
codes = codes.reshape(codes.shape[0], num_lut, group_size)
dequant = torch.zeros_like(codes, dtype=output_dtype)
# Compute shape of lookup group indices from codes shape and block size
code_shape = codes.shape
ndim = codes.ndim
assert len(block_size) == ndim, "block_size must match dimensionality of codes"

# do lookup for each lookup table
# dequant shape: (N, num_lut, group_size)
# codebook shape: (num_lut, 2 ** nbits)
# codes shape: (N, num_lut, group_size)
for i in range(num_lut):
# dequant[:, i, :]: (N, group_size)
# using squeeze to remove the trailing dim 1s after the lookup
dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
# Compute which codebook slice to use for each element
group_indices = []
for i in range(ndim):
assert block_size[i] >= 1 and code_shape[i] % block_size[i] == 0, (
f"dimension {code_shape[i]} not divisible by block size {block_size[i]}"
)

dequant = dequant.reshape(*original_shape)
return dequant.to(output_dtype)
# Index of block
idx = (
torch.arange(code_shape[i], device=codes.device) // block_size[i]
) # shape (di,)

# Reshape idx to broadcast along all other dims
shape = [1] * ndim
shape[i] = code_shape[i]
idx = idx.view(*shape) # shape (1, ..., 1, di, 1, ..., 1)
idx = idx.expand(code_shape) # shape (d0, ..., dN)
group_indices.append(idx)

# Stack the broadcasted group indices
# group_index_tensor at (i0, i1, ..., iN) gives the group indices (g0, ..., gN)
# for the element at (i0, i1, ..., iN) in the original code
# If code.shape = (d1, d2, d3), then group_index_tensor.shape = (d1, d2, d3, 3)
group_index_tensor = torch.stack(
group_indices, dim=-1
) # shape (d0, d1, ..., dN, ndim)

# Flatten everything to index efficiently
flat_codes = codes.reshape(-1) # shape (numel,)
flat_groups = group_index_tensor.reshape(-1, ndim) # (numel, ndim)

# Compute dequantized values via indexing
# index into codebook with (*group_index, code_index, :)
gathered = codebook[(*flat_groups.T, flat_codes)] # shape (numel, vec_dim)
dequant = gathered.reshape(*code_shape, vec_dim)

if vec_dim == 1:
dequant = dequant.squeeze(-1)

return dequant.to(dtype=output_dtype)
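
To make the new shapes concrete, here is a hedged round-trip sketch of the two ops changed above. It assumes coremltools >= 8.3.0 is installed (for the k-means lookup-table fitting) and a PyTorch build that exposes the sub-byte torch.uint4 dtype.

import torch

from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
    choose_qparams_and_quantize_codebook_coreml,
    dequantize_codebook,
)

weight = torch.randn(100, 256, dtype=torch.float32)
block_size = [-1, 64]      # -1 expands to the full first dimension; one LUT per 64 columns
code_dtype = torch.uint4   # 4-bit codes, so each table has 2 ** 4 = 16 entries

codebook, codes = choose_qparams_and_quantize_codebook_coreml(
    weight, code_dtype, block_size
)
# codebook: (1, 256 // 64, 2 ** 4, 1) = (1, 4, 16, 1); codes: (100, 256)
print(codebook.shape, codes.shape)

# dequantize_codebook now takes nbits (an int) instead of the code dtype.
dequant = dequantize_codebook(codes, codebook, nbits=4, block_size=block_size)
print(dequant.shape)  # (100, 256), in float32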
@@ -12,6 +12,9 @@
choose_qparams_and_quantize_codebook_coreml,
dequantize_codebook,
)
from torchao.quantization.quant_primitives import (
_DTYPE_TO_BIT_WIDTH,
)
from torchao.utils import TorchAOBaseTensor

aten = torch.ops.aten
@@ -95,7 +98,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
return dequantize_codebook(
codes,
self.codebook,
self.code_dtype,
_DTYPE_TO_BIT_WIDTH[self.code_dtype],
self.block_size,
output_dtype=output_dtype,
)
@@ -174,6 +177,17 @@ def _(func, types, args, kwargs):
return func(input_tensor, weight_tensor, bias)


@implements([torch.nn.functional.embedding, aten.embedding.default])
def _(func, types, args, kwargs):
assert len(args) == 2
indices, weight_tensor = (
args[0],
args[1],
)
weight_tensor = weight_tensor.dequantize()
return func(indices, weight_tensor, **kwargs)
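
A hedged sketch of how the new embedding override might be exercised. The config class name, import path, and quantize_ filter usage are assumptions carried over from the rest of this prototype; the override itself only requires the embedding weight to be a CodebookQuantizedTensor.

import torch

from torchao.quantization import quantize_

# Assumed import path and class name, as in the api.py sketch earlier in this diff.
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig

m = torch.nn.Sequential(torch.nn.Embedding(1000, 64))
quantize_(
    m,
    CodebookWeightOnlyConfig(dtype=torch.uint4, block_size=[-1, 16]),
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Embedding),
)
# F.embedding on the quantized weight dispatches to the override above, which
# dequantizes the weight and falls back to the regular embedding op.
out = m(torch.tensor([0, 1, 2]))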


@implements([aten.detach.default, aten.alias.default])
jerryzh168 (Contributor) commented on Aug 1, 2025:

nit: this can be a separate PR. You can try defining CodebookTensor.tensor_data_names and CodebookTensor.tensor_attribute_names, removing these things now, and seeing if it still works

#2597 and #2598 added more utils for TorchAOBaseTensor

Reply from the PR author (Contributor):

This I will save for a future PR

def _(func, types, args, kwargs):
return return_and_correct_aliasing(