From 6f227ce42df265bad64ee09680d3c3924ce3368f Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Thu, 31 Jul 2025 14:37:53 -0700
Subject: [PATCH 1/4] Update CoreML codebook APIs

---
 test/prototype/test_codebook_coreml.py        |   4 +-
 .../quantization/codebook_coreml/api.py       |   3 +-
 .../codebook_coreml/codebook_ops.py           | 139 +++++++++++-------
 .../codebook_quantized_tensor.py              |  16 +-
 4 files changed, 101 insertions(+), 61 deletions(-)

diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py
index 0c16de8969..e72edbecbc 100644
--- a/test/prototype/test_codebook_coreml.py
+++ b/test/prototype/test_codebook_coreml.py
@@ -33,10 +33,10 @@ def test_choose_qparams_codebook(self):
         codebook, wq = choose_qparams_and_quantize_codebook_coreml(
             self.input,
             self.code_dtype,
-            self.block_size,
+            [self.input.shape[0], 4],
         )
         group_size = self.block_size[-1]
-        self.assertEqual(codebook.shape, (256 // group_size, 2**self.nbits, 1))
+        self.assertEqual(codebook.shape, (1, 256 // group_size, 2**self.nbits, 1))
         self.assertEqual(wq.shape, (100, 256))
         self.assertFalse(torch.isnan(codebook).any())

diff --git a/torchao/prototype/quantization/codebook_coreml/api.py b/torchao/prototype/quantization/codebook_coreml/api.py
index f2e1c78210..36fa0d299f 100644
--- a/torchao/prototype/quantization/codebook_coreml/api.py
+++ b/torchao/prototype/quantization/codebook_coreml/api.py
@@ -42,13 +42,12 @@ def _codebook_weight_only_transform(
         raise ImportError("Requires coremltools >= 8.3.0")

     dtype = config.dtype
-    block_size = config.block_size
     weight = module.weight

     quantized_weight = CodebookQuantizedTensor.from_float(
         weight,
         dtype,
-        block_size,
+        config.block_size,
     )
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
     return module

diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
index 3ecb4852aa..a12ac1998c 100644
--- a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
@@ -34,11 +34,12 @@
     Args:
         input_tensor (torch.Tensor): The input tensor to be quantized.
        code_dtype (torch.dtype): The dtype for the codes. [torch.uint1, ..., torch.uint8]
-        block_size (List[int]): the size for how many elements of last dimension of input_tensor
-            belong to the same group and should share the same lookup table. let's say original
-            shape is (N, K), and block_size of (N, group_size) or (-1, group_size),
-            then the slice of (N, group_size) elements should use the same lookup
-            table, and there will be (K // group_size) lookup tables
+        block_size (List[int]): block sizes for how many elements in each dimension share
+            the same lookup table (len(block_size) == input_tensor.dim())
+            Each dimension of input_tensor must be divisible by the corresponding element of block_size
+            Lookup tables form a grid with (input_tensor.shape[i] // block_size[i]) tables along each dimension i
+            For example, if the input tensor has shape (N, K) and block_size is (N, group_size),
+            there is one lookup table per group_size columns, i.e., (K // group_size) lookup tables in total
         force_kmeans1d (bool): Use kmeans1d regardless of number of weights
         cluster_dim (int): this means the size of the vector for vector lookup table quantization e.g.
            when cluster_dim is 4, instead of quantizing each scalar value one by one, we quantize
@@ -48,31 +49,38 @@
     Returns:
         Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized
         Tensor (codes, torch.uint8)
+        The LUT table has dimension input_tensor.dim() + 2, where:
+        * The first input_tensor.dim() dimensions index over the different tables (input_tensor.shape[i] // block_size[i] in each dimension)
+        * The input_tensor.dim() + 1 dimension indexes over the nbit indices (2 ** nbits)
+        * The input_tensor.dim() + 2 dimension indexes over the look up values (shape = 1 for scalar)
     """
     assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
-    assert len(block_size) == input_tensor.ndim
+    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"
+
+    assert len(block_size) == input_tensor.dim()
     block_size = block_size.copy()
-    for i in range(input_tensor.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == input_tensor.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = input_tensor.shape[i]
+        assert block_size[i] >= 1 and input_tensor.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide input_tensor.shape[i]"
         )
-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = input_tensor.shape[-1]
-
-    assert input_tensor.shape[-1] % group_size == 0
-    assert input_tensor.ndim == 2
+    assert input_tensor.dim() == 2, "Currently only rank 2 tensors are supported"
+    assert block_size[0] == input_tensor.shape[0], (
+        "Currently only per-grouped channel granularity is supported"
+    )
     assert cluster_dim == 1, (
         f"only cluster_dim == 1 is supported right now, got {cluster_dim}"
     )
+    num_lut = input_tensor.shape[1] // block_size[1]
+    group_size = block_size[1]
+
     # for converting to numpy
     input_tensor = input_tensor.detach()
-    # (N, K)
     original_shape = input_tensor.shape
-    # (K // group_size)
-    num_lut = input_tensor.shape[1] // group_size

     # reshape to (N, K // group_size, group_size)
     input_tensor = input_tensor.reshape(input_tensor.shape[0], num_lut, group_size)
@@ -80,11 +88,6 @@
         _get_kmeans_lookup_table_and_weight,
     )

-    nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
-    if nbits > 8:
-        print(f"Requested nbits: {nbits}, rewriting to 8 bits to reduce the size")
-        nbits = 8
-
     res_lut = []
     # each res_w[:, i, :] will use the same lookup table
     # res_w: (N, K // group_size, group_size)
@@ -102,6 +105,13 @@
     # res_lut: (K // group_size, 2 ** nbits)
     res_lut = torch.stack(res_lut, dim=0)

+    # The final LUT should have dimension equal to input_tensor.dim() + 2
+    # The first input_tensor.dim() dimensions index over the tables,
+    # input_tensor.dim() + 1 indexes over the nbit indices
+    # input_tensor.dim() + 2 are the lookup values (shape = 1 for scalar)
+    # res_lut: (1, K // group_size, 2 ** nbits, 1)
+    res_lut = res_lut.reshape(1, num_lut, 2**nbits, 1)
+
     # reshape back to (N, K)
     res_w = res_w.reshape(*original_shape)

@@ -112,7 +122,7 @@
 def dequantize_codebook(
     codes: torch.Tensor,
     codebook: torch.Tensor,
-    code_dtype: torch.dtype,
+    nbits: int,
     block_size: List[int],
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -121,13 +131,13 @@

     Args:
         codes (torch.Tensor): Indices of codebook entries for each element
-            shape (N, K) for scalar quantization
-        codebook (torch.Tensor): Codebook tensor used for quantization,
-            shape (K // group_size, 2 ** nbits) where K is the dim 1 shape of input
-        code_dtype (torch.dtype): The logical dtype for the codes, [torch.uint1, ..., torch.uint8]
-            Note that codes is stored in torch.uint8, this is just addtional information for dequantize op
+            General shape: (d0, d1, d2, ..., dN)
+            Simple example shape: (N, K)
+        codebook (torch.Tensor): Codebook tensor used for quantization
+            General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar lookup values
+            Simple example shape: (1, K // group_size, 2 ** nbits, 1) for scalar lookup values, with 1 table per group_size columns
+        nbits (int): number of bits used for the quantization
         block_size (List[int]): a slice of elements with shape block_size will share the same lookup table
-            only support (-1, ..., group_size) right now (all preceding dimensions has to match input)
         output_dtype (torch.dtype): dtype for the output tensor.

     Returns:
@@ -140,37 +150,54 @@
         torch.bfloat16,
     ], f"Unsupported output dtype: {output_dtype}"

-    assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
+    assert nbits >= 1 and nbits <= 8, f"nbits must be in [1, 8], got {nbits}"

-    assert len(block_size) == codes.ndim
+    assert len(block_size) == codes.dim()
     block_size = block_size.copy()
-    for i in range(codes.ndim - 1):
-        assert block_size[i] == -1 or block_size[i] == codes.shape[i], (
-            f"{block_size} not supported"
+    for i in range(len(block_size)):
+        if block_size[i] == -1:
+            block_size[i] = codes.shape[i]
+        assert block_size[i] >= 1 and codes.shape[i] % block_size[i] == 0, (
+            "block_size[i] must divide codes.shape[i]"
         )
-    group_size = block_size[-1]
-    if group_size == -1:
-        group_size = codes.shape[-1]
+    assert codebook.dim() == codes.dim() + 2
+    codebook_shape = codebook.shape
+    vec_dim = codebook_shape[-1]
+    quant_levels = 2**nbits

-    assert codes.shape[-1] % group_size == 0
-    K = codes.shape[-1]
-    num_lut = K // group_size
-    # (N, K)
-    original_shape = codes.shape
+    # Check that last two dimensions of codebook are [quant_levels, vec_dim]
+    assert codebook_shape[-2] == quant_levels, "Codebook shape mismatch with nbits"

-    # reshape to (N, num_lut, group_size)
-    codes = codes.reshape(codes.shape[0], num_lut, group_size)
-    dequant = torch.zeros_like(codes, dtype=output_dtype)
+    # Compute shape of lookup group indices from codes shape and block size
+    code_shape = codes.shape
+    ndim = len(code_shape)
+    assert len(block_size) == ndim, "block_size must match dimensionality of codes"

-    # do lookup for each lookup table
-    # dequant shape: (N, num_lut, group_size)
-    # codebook shape: (num_lut, 2 ** nbits)
-    # codes shape: (N, num_lut, group_size)
-    for i in range(num_lut):
-        # dequant[:, i, :]: (N, group_size)
-        # using squeeze to remove the training dim 1s after the lookup
-        dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
+    # Compute which codebook slice to use for each element
+    group_indices = []
+    for dim, bsz in zip(code_shape, block_size):
+        assert bsz >= 1 and dim % bsz == 0, (
+            f"dimension {dim} not divisible by block size {bsz}"
+        )
+    for i, bsz in enumerate(block_size):
+        indices = torch.arange(code_shape[i], device=codes.device) // bsz
+        group_indices.append(indices)
+
+    # Broadcast group_indices to shape of codes
+    mesh = torch.meshgrid(*group_indices, indexing="ij")
+    group_index_tensor = torch.stack(mesh, dim=-1)  # shape (..., N), where N = ndim
+
+    # Flatten everything to index efficiently
+    flat_codes = codes.reshape(-1)
+    flat_groups = group_index_tensor.reshape(-1, ndim)  # (..., ndim)
+
+    # Compute dequantized values via indexing
+    # index into codebook with (*group_index, code_index, :)
+    gathered = codebook[(*flat_groups.T, flat_codes)]  # shape (numel, vec_dim)
+    dequant = gathered.reshape(*code_shape, vec_dim)
+
+    if vec_dim == 1:
+        dequant = dequant.squeeze(-1)

-    dequant = dequant.reshape(*original_shape)
-    return dequant.to(output_dtype)
+    return dequant.to(dtype=output_dtype)

diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py
index 4c8be29f20..7283a23918 100644
--- a/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py
+++ b/torchao/prototype/quantization/codebook_coreml/codebook_quantized_tensor.py
@@ -12,6 +12,9 @@
     choose_qparams_and_quantize_codebook_coreml,
     dequantize_codebook,
 )
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_BIT_WIDTH,
+)
 from torchao.utils import TorchAOBaseTensor

 aten = torch.ops.aten
@@ -95,7 +98,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         return dequantize_codebook(
             codes,
             self.codebook,
-            self.code_dtype,
+            _DTYPE_TO_BIT_WIDTH[self.code_dtype],
             self.block_size,
             output_dtype=output_dtype,
         )
@@ -174,6 +177,17 @@ def _(func, types, args, kwargs):
     return func(input_tensor, weight_tensor, bias)


+@implements([torch.nn.functional.embedding, aten.embedding.default])
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    indices, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    weight_tensor = weight_tensor.dequantize()
+    return func(indices, weight_tensor, **kwargs)
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(

From 5e1189eb73ceaf2750f7888dc18e2ec4a4743a09 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Thu, 31 Jul 2025 14:43:54 -0700
Subject: [PATCH 2/4] up

---
 test/prototype/test_codebook_coreml.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py
index e72edbecbc..33f7648e70 100644
--- a/test/prototype/test_codebook_coreml.py
+++ b/test/prototype/test_codebook_coreml.py
@@ -33,7 +33,7 @@ def test_choose_qparams_codebook(self):
         codebook, wq = choose_qparams_and_quantize_codebook_coreml(
             self.input,
             self.code_dtype,
-            [self.input.shape[0], 4],
+            self.block_size,
         )
         group_size = self.block_size[-1]
         self.assertEqual(codebook.shape, (1, 256 // group_size, 2**self.nbits, 1))
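Usage sketch (illustrative, not part of the series): the reshaped LUT layout that
PATCH 1 introduces, as exercised by the test above. Assumes the patched torchao
tree, a torch build that exposes torch.uint4, and coremltools installed for the
k-means step.

    import torch
    from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
        choose_qparams_and_quantize_codebook_coreml,
    )

    weight = torch.randn(100, 256)

    # One lookup table per 4 columns, shared by all rows (-1 expands to the dim size)
    codebook, codes = choose_qparams_and_quantize_codebook_coreml(
        weight, torch.uint4, [-1, 4]
    )

    # The LUT has rank weight.dim() + 2: a (1, 256 // 4) grid of tables,
    # 2**4 entries per table, and a trailing size-1 axis for scalar lookup values
    assert codebook.shape == (1, 64, 2**4, 1)
    assert codes.shape == (100, 256)  # codes are stored as torch.uint8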
From 126fb3a2100bc5feae30c7b313bf493e40f2ccff Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Sun, 3 Aug 2025 23:12:21 -0700
Subject: [PATCH 3/4] up

---
 .../codebook_coreml/codebook_ops.py           | 48 ++++++++++++-------
 1 file changed, 31 insertions(+), 17 deletions(-)

diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
index a12ac1998c..b9297084fe 100644
--- a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
@@ -49,10 +49,10 @@
     Returns:
         Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized
         Tensor (codes, torch.uint8)
-        The LUT table has dimension input_tensor.dim() + 2, where:
-        * The first input_tensor.dim() dimensions index over the different tables (input_tensor.shape[i] // block_size[i] in each dimension)
-        * The input_tensor.dim() + 1 dimension indexes over the nbit indices (2 ** nbits)
-        * The input_tensor.dim() + 2 dimension indexes over the look up values (shape = 1 for scalar)
+        The LUT table has dimension (g0, ..., g(N-1), 2**nbits, vec_dim), where:
+        * The first N dimensions index over the different tables (gi = input_tensor.shape[i] // block_size[i] in each dimension)
+        * The N + 1 dimension indexes over the nbit indices (2 ** nbits)
+        * The N + 2 dimension indexes over the lookup values (vec_dim = 1 for scalar)
     """
     assert code_dtype in list(_SUB_BYTE_UINT_BOUNDS.keys()) + [torch.uint8]
     nbits = _DTYPE_TO_BIT_WIDTH[code_dtype]
@@ -137,7 +137,8 @@
             General shape: (d0 // block_size[0], ..., dN // block_size[N], 2**nbits, vec_dim), where vec_dim = 1 for scalar lookup values
             Simple example shape: (1, K // group_size, 2 ** nbits, 1) for scalar lookup values, with 1 table per group_size columns
         nbits (int): number of bits used for the quantization
-        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table
+        block_size (List[int]): a slice of elements with shape block_size will share the same lookup table.
+            If block_size[i] == -1, then the entire dimension is used.
         output_dtype (torch.dtype): dtype for the output tensor.

     Returns:
@@ -171,26 +172,39 @@

     # Compute shape of lookup group indices from codes shape and block size
     code_shape = codes.shape
-    ndim = len(code_shape)
+    ndim = code_shape.ndim
     assert len(block_size) == ndim, "block_size must match dimensionality of codes"

     # Compute which codebook slice to use for each element
     group_indices = []
-    for dim, bsz in zip(code_shape, block_size):
-        assert bsz >= 1 and dim % bsz == 0, (
-            f"dimension {dim} not divisible by block size {bsz}"
+    for i in range(ndim):
+        assert block_size[i] >= 1 and code_shape[i] % block_size[i] == 0, (
+            f"dimension {code_shape[i]} not divisible by block size {block_size[i]}"
         )
-    for i, bsz in enumerate(block_size):
-        indices = torch.arange(code_shape[i], device=codes.device) // bsz
-        group_indices.append(indices)
-
-    # Broadcast group_indices to shape of codes
-    mesh = torch.meshgrid(*group_indices, indexing="ij")
-    group_index_tensor = torch.stack(mesh, dim=-1)  # shape (..., N), where N = ndim
+
+        # Index of block
+        idx = (
+            torch.arange(code_shape[i], device=codes.device) // block_size[i]
+        )  # shape (di,)
+
+        # Reshape idx to broadcast along all other dims
+        shape = [1] * ndim
+        shape[i] = code_shape[i]
+        idx = idx.view(*shape)  # shape (1, ..., 1, di, 1, ..., 1)
+        idx = idx.expand(code_shape)  # shape (d0, ..., dN)
+        group_indices.append(idx)
+
+    # Stack the broadcasted group indices
+    # group_index_tensor at (i0, i1, ..., iN) gives the group indices (g0, ..., gN)
+    # for the element at (i0, i1, ..., iN) in the original codes tensor
+    # If codes.shape = (d1, d2, d3), then group_index_tensor.shape = (d1, d2, d3, 3)
+    group_index_tensor = torch.stack(
+        group_indices, dim=-1
+    )  # shape (d0, d1, ..., dN, ndim)

     # Flatten everything to index efficiently
-    flat_codes = codes.reshape(-1)
-    flat_groups = group_index_tensor.reshape(-1, ndim)  # (..., ndim)
+    flat_codes = codes.reshape(-1)  # shape (numel,)
+    flat_groups = group_index_tensor.reshape(-1, ndim)  # (numel, ndim)

     # Compute dequantized values via indexing
     # index into codebook with (*group_index, code_index, :)
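Standalone sketch (illustrative, not part of the series) of the block-index
construction PATCH 3 switches to: arange // block_size per dimension, broadcast
with view/expand, then stack; the tensor values here are made up.

    import torch

    codes = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
    block_size = [4, 2]  # one table per 2 columns, shared across all 4 rows

    group_indices = []
    for i, bsz in enumerate(block_size):
        idx = torch.arange(codes.shape[i]) // bsz  # block index along dim i
        shape = [1] * codes.ndim
        shape[i] = codes.shape[i]
        group_indices.append(idx.view(*shape).expand(codes.shape))

    # (4, 8, 2): the last axis holds (g0, g1) for each element of codes
    group_index_tensor = torch.stack(group_indices, dim=-1)
    assert group_index_tensor[0, 5].tolist() == [0, 2]  # row block 0, column block 2

    # The gather step then reads codebook[g0, g1, code, :] in one shot; codes
    # are cast to long because advanced indexing wants integer index tensors
    codebook = torch.randn(1, 4, 16, 1)
    flat_groups = group_index_tensor.reshape(-1, 2)
    flat_codes = codes.reshape(-1).long()
    dequant = codebook[(*flat_groups.T, flat_codes)].reshape(4, 8)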
From cceaab2efca1fe92f598893213798aff06e281d2 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Mon, 4 Aug 2025 11:41:40 -0700
Subject: [PATCH 4/4] up

---
 test/prototype/test_codebook_coreml.py                          | 2 --
 torchao/prototype/quantization/codebook_coreml/codebook_ops.py  | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/prototype/test_codebook_coreml.py b/test/prototype/test_codebook_coreml.py
index 33f7648e70..69956c7729 100644
--- a/test/prototype/test_codebook_coreml.py
+++ b/test/prototype/test_codebook_coreml.py
@@ -14,7 +14,6 @@
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
-from torchao.testing.utils import skip_if_no_cuda
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, is_package_at_least


@@ -76,7 +75,6 @@ def test_quantize_api(self):
         )
         assert type(m[0].weight) == CodebookQuantizedTensor

-    @skip_if_no_cuda()
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "requires 2.6+.")
     def test_export(self):
         m = torch.nn.Sequential(torch.nn.Linear(128, 64)).to(torch.float32)

diff --git a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
index b9297084fe..c945b07edf 100644
--- a/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
+++ b/torchao/prototype/quantization/codebook_coreml/codebook_ops.py
@@ -172,7 +172,7 @@

     # Compute shape of lookup group indices from codes shape and block size
     code_shape = codes.shape
-    ndim = code_shape.ndim
+    ndim = codes.ndim
     assert len(block_size) == ndim, "block_size must match dimensionality of codes"

     # Compute which codebook slice to use for each element
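End-to-end sketch (illustrative, not part of the series) of the API after
PATCH 4: quantize, then dequantize by passing nbits (an int) where the old
signature took a code dtype. Assumes the patched torchao tree plus
coremltools; if your torch build rejects uint8 fancy indexing inside the op,
cast codes to long first.

    import torch
    from torchao.prototype.quantization.codebook_coreml.codebook_ops import (
        choose_qparams_and_quantize_codebook_coreml,
        dequantize_codebook,
    )

    weight = torch.randn(100, 256)
    block_size = [-1, 4]

    codebook, codes = choose_qparams_and_quantize_codebook_coreml(
        weight, torch.uint4, block_size
    )
    # nbits=4 matches torch.uint4; block_size may again use -1 for "whole dim"
    recon = dequantize_codebook(codes, codebook, 4, block_size)
    assert recon.shape == weight.shape and recon.dtype == torch.float32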