Update coreml codebook #2648
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2648
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Pending as of commit cceaab2 with merge base 22f9d31. NEW FAILURES: the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
weight_tensor = weight_tensor.dequantize()
return func(indices, weight_tensor, **kwargs)

@implements([aten.detach.default, aten.alias.default])
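For context, the fragment above follows the common TorchAO fallback pattern: when an op has no quantized kernel, dequantize the weight and re-dispatch to the dense op. A minimal, self-contained sketch of that pattern (the `fallback_embedding` helper and `weights` tensor are hypothetical names, not from this PR):

```python
import torch

# Hypothetical sketch of the dequantize-then-dispatch fallback pattern:
# if the weight is quantized, dequantize it, then call the dense op.
def fallback_embedding(func, indices, weight_tensor, **kwargs):
    if weight_tensor.is_quantized:
        weight_tensor = weight_tensor.dequantize()
    return func(indices, weight_tensor, **kwargs)

weights = torch.randn(10, 4)  # stand-in for a (possibly quantized) weight
out = fallback_embedding(torch.nn.functional.embedding, torch.tensor([1, 3]), weights)
```

In the PR itself the dispatch goes through `@implements(...)`-registered overrides rather than a plain function, but the dequantize-then-call shape is the same.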
This I will save for a future PR
@@ -48,43 +49,45 @@ def choose_qparams_and_quantize_codebook_coreml(

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the codebook (lookup table) Tensor and the quantized Tensor (codes, torch.uint8).
        The LUT has dimension input_tensor.dim() + 2, where:
nit: maybe spell out the dimensions with variables/expressions?
dequant = torch.zeros_like(codes, dtype=output_dtype)
# Compute shape of lookup group indices from codes shape and block size
code_shape = codes.shape
ndim = len(code_shape)
nit: can do codes.ndim
dequant[:, i, :] = codebook[i][codes[:, i, :]].squeeze()
# Compute which codebook slice to use for each element
group_indices = []
for dim, bsz in zip(code_shape, block_size):
nit: dim might be confusing? this seems to be code_size_i and block_size_i
mesh = torch.meshgrid(*group_indices, indexing="ij")
group_index_tensor = torch.stack(mesh, dim=-1)  # shape (..., N), where N = ndim

# Flatten everything to index efficiently
flat_codes = codes.reshape(-1)
flat_groups = group_index_tensor.reshape(-1, ndim)  # (..., ndim)

# Compute dequantized values via indexing
# index into codebook with (*group_index, code_index, :)
gathered = codebook[(*flat_groups.T, flat_codes)]  # shape (numel, vec_dim)
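In the spirit of the reviewer's request for illustrative comments, here is a hedged, self-contained sketch of the meshgrid-and-gather dequantization above on concrete shapes. The sizes (a 4x4 `codes` tensor, 2x2 blocks, 4 LUT entries, `vec_dim` of 1) are assumed for illustration, not taken from the PR:

```python
import torch

# 4x4 codes tensor split into 2x2 blocks => a 2x2 grid of codebook groups,
# each group holding 4 entries of vector length 1 (shapes assumed).
codes = torch.randint(0, 4, (4, 4))
block_size = (2, 2)
codebook = torch.randn(2, 2, 4, 1)  # (groups_d0, groups_d1, 2**nbits, vec_dim)

# Which codebook group each element of `codes` belongs to, per dimension
group_indices = [torch.arange(d) // b for d, b in zip(codes.shape, block_size)]
mesh = torch.meshgrid(*group_indices, indexing="ij")
group_index_tensor = torch.stack(mesh, dim=-1)  # (4, 4, 2)

# Flatten so each element carries (group_d0, group_d1, code)
flat_codes = codes.reshape(-1)                            # (16,)
flat_groups = group_index_tensor.reshape(-1, codes.ndim)  # (16, 2)

# Advanced indexing: look each code up in its own group's table
gathered = codebook[(*flat_groups.T, flat_codes)]  # (16, vec_dim)
dequant = gathered.reshape(4, 4)
```

For example, element (3, 3) falls in block (3 // 2, 3 // 2) = (1, 1), so its value comes from `codebook[1, 1, codes[3, 3]]`.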
nit: can you add comments of some examples for these to make it easier to understand
looks good, just some comments about making it easier to understand
This adds palettization support for embedding/linear layers in CoreML using TorchAO's quantize_ API. Note: this needs to wait for pytorch/ao#2648 to land in ao plus a pin bump in ET before landing.
This PR updates the dequantize_codebook quant primitive to be more compatible with CoreML. More specifically:
- code_dtype is changed to nbits because CoreML cannot process a function that has non-standard dtypes in the signature (e.g., torch.uint3)
- The codebook rank is changed to codes.dim() + 2 to follow the convention here: https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS18.compression.constexpr_lut_to_dense
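A small sketch of what those two conventions mean in practice. The concrete shapes and blocking below are assumed for illustration only; the point is that an integer `nbits` replaces a sub-byte dtype, and the LUT carries rank `codes.dim() + 2` as in CoreML's `constexpr_lut_to_dense`:

```python
import torch

# 1) Pass nbits (a plain int) instead of a non-standard dtype like torch.uint3,
#    which CoreML cannot accept in a function signature.
nbits = 3
codes = torch.randint(0, 2**nbits, (8, 16), dtype=torch.uint8)

# 2) The LUT has rank codes.dim() + 2: a grid of per-block tables followed by
#    (2**nbits, vec_dim). Here: one group per 4 rows, all columns in one group
#    (blocking assumed for the example).
lut = torch.randn(8 // 4, 1, 2**nbits, 1)
assert lut.dim() == codes.dim() + 2  # rank-4 LUT for rank-2 codes
```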