Update coreml codebook #2648

Merged: 4 commits merged into main on Aug 4, 2025
Conversation

metascroy (Contributor) commented:

This PR updates the dequantize_codebook quant primitive to be more compatible with CoreML. More specifically:


pytorch-bot bot commented Jul 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2648

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Pending

As of commit cceaab2 with merge base 22f9d31:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@metascroy metascroy requested a review from jerryzh168 July 31, 2025 21:49
@meta-cla meta-cla bot added the CLA Signed label Jul 31, 2025
@metascroy metascroy added the topic: not user facing and topic: improvement labels Aug 1, 2025
weight_tensor = weight_tensor.dequantize()
return func(indices, weight_tensor, **kwargs)


@implements([aten.detach.default, aten.alias.default])
jerryzh168 (Contributor) commented Aug 1, 2025:
nit (can be a separate PR): you can try defining CodebookTensor.tensor_data_names and CodebookTensor.tensor_attribute_names, removing these things now, and see if it still works

#2597 and #2598 added more utils for TorchAOBaseTensor

metascroy (Contributor, Author) replied:
This I will save for a future PR

@@ -48,43 +49,45 @@ def choose_qparams_and_quantize_codebook_coreml(

Returns:
Tuple[torch.Tensor, torch.Tensor] The codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8)
The LUT table has dimension input_tensor.dim() + 2, where:
jerryzh168 (Contributor) commented Aug 1, 2025:
nit: maybe spell out the dimensions with variables/expressions?
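
For illustration, one way the shapes could be spelled out (a sketch only; the exact group-dimension layout is an assumption inferred from how the dequant path below indexes the codebook):

# Hypothetical shapes for a 2D input_tensor of shape (n, k) with block_size (bn, bk):
#   codes:    (n, k), torch.uint8, each value an index in [0, 2**nbits)
#   codebook: (n // bn, k // bk, 2**nbits, vec_dim)
# i.e. one leading group dimension per input dimension, plus the LUT-entry and
# vector dimensions, giving input_tensor.dim() + 2 dimensions in total.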

dequant = torch.zeros_like(codes, dtype=output_dtype)
# Compute shape of lookup group indices from codes shape and block size
code_shape = codes.shape
ndim = len(code_shape)
jerryzh168 (Contributor) commented:
nit: can do codes.ndim
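
In other words (a sketch of the suggested simplification):

ndim = codes.ndim  # equivalent to len(codes.shape)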

dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
# Compute which codebook slice to use for each element
group_indices = []
for dim, bsz in zip(code_shape, block_size):
jerryzh168 (Contributor) commented Aug 1, 2025:
nit: dim might be confusing? this seems to be code_size_i and block_size_i
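
For context, a sketch of what this loop plausibly builds, using the reviewer's suggested names (an assumption, not the PR's exact code):

# Map every position along each dimension to the index of the codebook
# group (block) it falls into.
group_indices = []
for code_size_i, block_size_i in zip(codes.shape, block_size):
    group_indices.append(torch.arange(code_size_i) // block_size_i)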

Comment on lines 188 to 197
mesh = torch.meshgrid(*group_indices, indexing="ij")
group_index_tensor = torch.stack(mesh, dim=-1) # shape (..., N), where N = ndim

# Flatten everything to index efficiently
flat_codes = codes.reshape(-1)
flat_groups = group_index_tensor.reshape(-1, ndim) # (numel, ndim)

# Compute dequantized values via indexing
# index into codebook with (*group_index, code_index, :)
gathered = codebook[(*flat_groups.T, flat_codes)] # shape (numel, vec_dim)
jerryzh168 (Contributor) commented Aug 1, 2025:
nit: can you add comments with some examples for these to make them easier to understand?
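
For illustration, a small worked example of what the gather above computes (a sketch under assumed shapes, not the PR's exact code):

import torch

# codes of shape (4, 6) with block_size (2, 3): the tensor is split into a
# 2x2 grid of blocks, each block with its own 16-entry lookup table.
codes = torch.randint(0, 16, (4, 6), dtype=torch.uint8)
block_size = (2, 3)
codebook = torch.randn(2, 2, 16, 1)  # (groups_dim0, groups_dim1, 2**nbits, vec_dim)

# Per-dimension group index of each position: [0, 0, 1, 1] and [0, 0, 0, 1, 1, 1]
group_indices = [torch.arange(s) // b for s, b in zip(codes.shape, block_size)]
mesh = torch.meshgrid(*group_indices, indexing="ij")
group_index_tensor = torch.stack(mesh, dim=-1)   # shape (4, 6, 2)

flat_codes = codes.reshape(-1).long()            # shape (24,)
flat_groups = group_index_tensor.reshape(-1, 2)  # shape (24, 2)

# Element (i, j) reads codebook[i // 2, j // 3, codes[i, j], :]
gathered = codebook[(*flat_groups.T, flat_codes)]  # shape (24, 1)
dequant = gathered.reshape(4, 6)

In this example every element of the top-left 2x3 block reads from codebook[0, 0], the top-right block from codebook[0, 1], and so on.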

jerryzh168 (Contributor) left a review comment:
looks good, just some comments about making it easier to understand

@metascroy metascroy merged commit ca5f788 into main Aug 4, 2025
17 of 20 checks passed
metascroy added a commit to pytorch/executorch that referenced this pull request Aug 6, 2025
This adds palettization support for embedding/linear layers in CoreML
using TorchAO's quantize_ API.

Note: this needs to wait for pytorch/ao#2648 to
land in ao, plus a pin bump in ET, before landing.