From d2a8f049fa539963f8a34a7a51ef65ae3ea41d84 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 30 Jul 2025 18:47:30 -0700 Subject: [PATCH 1/8] Add palletization support --- backends/apple/coreml/compiler/torch_ops.py | 188 +++++++++++++++++++ backends/apple/coreml/test/test_torch_ops.py | 37 +++- 2 files changed, 220 insertions(+), 5 deletions(-) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 11294a69a3d..26ce7be63fa 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -21,6 +21,150 @@ transpose, unbind, ) +import numpy as np +from coremltools.converters.mil.frontend.torch.torch_op_registry import ( + register_torch_op, +) +from coremltools.converters.mil.mil import types +from executorch.exir.dim_order_utils import get_memory_format + + +# https://github.com/apple/coremltools/pull/2556 +@register_torch_op(override=False) +def transpose_copy(context, node): + transpose(context, node) + + +# https://github.com/apple/coremltools/pull/2557 +@register_torch_op(override=False) +def unbind_copy(context, node): + unbind(context, node) + + +# https://github.com/apple/coremltools/pull/2563 +@register_torch_op(override=False) +def split_copy(context, node): + split(context, node) + + +@register_torch_op( + torch_alias=[ + "dim_order_ops::_to_dim_order_copy", + "dim_order_ops._to_dim_order_copy", + ], + override=False, +) +def _to_dim_order_copy(context, node): + dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0] + node.kwinputs.pop("dim_order") + + # In CoreML, dim_order.val will be an ndarray, so we convert it to a list + dim_order = [int(d) for d in dim_order.val] + memory_format = get_memory_format(dim_order) + assert ( + memory_format == _torch.contiguous_format + ), "Only contiguous memory format is supported in CoreML" + to(context, node) + + +# https://github.com/apple/coremltools/pull/2558 +@register_torch_op( + torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], + override=False, +) +def dequantize_affine(context, node): + inputs = _get_inputs(context, node, expected=[7, 8]) + int_data = inputs[0].val + block_size = inputs[1].val + scale = inputs[2].val + zero_point = ( + inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None + ) + # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8 + # For now, we just check that it is int8 or int32 + input_dtype = inputs[4].val # noqa: F841 + assert NUM_TO_TORCH_DTYPE[input_dtype] in [ + _torch.int8, + _torch.int32, + ], "input_dtype should be int8 or int32" + + quant_min = inputs[5].val + quant_max = inputs[6].val + + assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs" + + assert len(int_data.shape) == len( + block_size + ), "block_size must have the same length as int_data.shape" + assert block_size[0] == 1, "block_size[0] must be 1" + group_size = block_size[1] + k = int_data.shape[1] + assert k % group_size == 0, "k must be divisible by group_size" + scales_per_row = k // group_size + scale = scale.reshape(-1, scales_per_row) + if zero_point is not None: + zero_point = zero_point.reshape(-1, scales_per_row) + + # TODO: I don't know if CoreML can make use of this + # We could add a cast op to the output, but I'm pretty CoreML will remove this during a later pass + # For now, we just log a warning + out_np_dtype = None + if len(inputs) > 7: + out_np_dtype = 
NUM_TO_NUMPY_DTYPE[inputs[7].val] + _logger.warning( + f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision." + ) + + if quant_min == -8 and quant_max == 7: + quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4")) + elif quant_min == -128 and quant_max == 127: + quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8")) + else: + raise ValueError( + f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization." + ) + + output = _utils._construct_constexpr_dequant_op( + int_data.astype(quantized_np_dtype), + zero_point, + scale, + axis=-1, + name=node.name, + ) + context.add(output, node.name) + + + +# codes: torch.Tensor, +# codebook: torch.Tensor, +# code_dtype: torch.dtype, +# block_size: List[int], +# output_dtype: torch.dtype = torch.float32, + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file registers torch ops that are not yet in coremltools, or are in a more recent version of +# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds +# the op to the coremltools library. + +import torch as _torch +from coremltools import _logger +from coremltools.converters.mil.frontend import _utils +from coremltools.converters.mil.frontend.torch.ops import ( + _get_inputs, + _get_kwinputs, + NUM_TO_NUMPY_DTYPE, + NUM_TO_TORCH_DTYPE, + split, + to, + transpose, + unbind, +) +import numpy as np from coremltools.converters.mil.frontend.torch.torch_op_registry import ( register_torch_op, @@ -132,3 +276,47 @@ def dequantize_affine(context, node): name=node.name, ) context.add(output, node.name) + + +@register_torch_op( + torch_alias=["quant::dequantize_codebook", "quant.dequantize_codebook"], + override=False, +) +def dequantize_codebook(context, node): + print("IN DEQUANTIZE CODEBOOK") + inputs = _get_inputs(context, node, expected=[4, 5]) + codes = inputs[0].val + codebook = inputs[1].val + code_dtype = inputs[2].val + block_size = inputs[3].val + + + assert len(codes.shape) == 2, "Only rank 2 inputs are supported" + + # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns. + # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar + # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g., + # lut_block_size=(1, 8) means there is 1 LUT per 8 columns + assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook" + assert codebook.shape[-1] == 1, "we only support scalar palletization" + codebook = np.expand_dims(codebook, 0) + + + if len(inputs) > 4: + output_dtype = inputs[4].val + out_np_dtype = NUM_TO_NUMPY_DTYPE[output_dtype] + _logger.warning( + f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision." 
+ ) + + print("codes", codes.shape) + print("codebook", codebook.shape) + print("code_dtype", code_dtype) + print("block_size", block_size) + + output = _utils._construct_constexpr_lut_op( + codes.astype(np.int8), + codebook, + name=node.name, + ) + context.add(output, node.name) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index 323f76afd1b..22e87d2790d 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -16,6 +16,10 @@ from executorch.backends.apple.coreml.partition import CoreMLPartitioner from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ +from torchao.prototype.quantization.codebook_coreml import ( + CodebookWeightOnlyConfig, +) + def is_fbcode(): return not hasattr(torch.version, "git_version") @@ -163,12 +167,35 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self): ], f"Got unexpected node target after delegation: {node.target.__name__}" et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) + + def test_dequantize_codebook_linear(self): + model, example_inputs = self._get_test_model() + quantize_( + model, + CodebookWeightOnlyConfig(dtype=torch.uint8, block_size=[-1, 16]), + ) + ep = torch.export.export(model, example_inputs) + print("ORIGINAL MODEL", ep) + delegated_program = executorch.exir.to_edge_transform_and_lower( + ep, + partitioner=[self._coreml_partitioner()], + ) + for node in delegated_program.exported_program().graph.nodes: + if node.op == "call_function": + assert node.target.__name__ in [ + "executorch_call_delegate", + "getitem", + ], f"Got unexpected node target after delegation: {node.target.__name__}" + et_prog = delegated_program.to_executorch() + print(et_prog.exported_program()) + self._compare_outputs(et_prog, model, example_inputs) if __name__ == "__main__": test_runner = TestTorchOps() - test_runner.test_dequantize_affine_b4w_embedding() - test_runner.test_dequantize_affine_b4w_linear() - test_runner.test_dequantize_affine_c4w_embedding() - test_runner.test_dequantize_affine_c4w_linear() - test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() + # test_runner.test_dequantize_affine_b4w_embedding() + # test_runner.test_dequantize_affine_b4w_linear() + # test_runner.test_dequantize_affine_c4w_embedding() + # test_runner.test_dequantize_affine_c4w_linear() + # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() + test_runner.test_dequantize_codebook_linear() From 44acb33904798487e7bc45e755fcbf528a410c03 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:33:13 -0700 Subject: [PATCH 2/8] Add codebook support to CoreML using quantize_ --- backends/apple/coreml/compiler/torch_ops.py | 174 ++----------------- backends/apple/coreml/test/test_torch_ops.py | 43 +++-- 2 files changed, 45 insertions(+), 172 deletions(-) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 26ce7be63fa..81306c9a2fd 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -8,149 +8,7 @@ # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds # the op to the coremltools library. 
-import torch as _torch -from coremltools import _logger -from coremltools.converters.mil.frontend import _utils -from coremltools.converters.mil.frontend.torch.ops import ( - _get_inputs, - _get_kwinputs, - NUM_TO_NUMPY_DTYPE, - NUM_TO_TORCH_DTYPE, - split, - to, - transpose, - unbind, -) import numpy as np -from coremltools.converters.mil.frontend.torch.torch_op_registry import ( - register_torch_op, -) -from coremltools.converters.mil.mil import types -from executorch.exir.dim_order_utils import get_memory_format - - -# https://github.com/apple/coremltools/pull/2556 -@register_torch_op(override=False) -def transpose_copy(context, node): - transpose(context, node) - - -# https://github.com/apple/coremltools/pull/2557 -@register_torch_op(override=False) -def unbind_copy(context, node): - unbind(context, node) - - -# https://github.com/apple/coremltools/pull/2563 -@register_torch_op(override=False) -def split_copy(context, node): - split(context, node) - - -@register_torch_op( - torch_alias=[ - "dim_order_ops::_to_dim_order_copy", - "dim_order_ops._to_dim_order_copy", - ], - override=False, -) -def _to_dim_order_copy(context, node): - dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0] - node.kwinputs.pop("dim_order") - - # In CoreML, dim_order.val will be an ndarray, so we convert it to a list - dim_order = [int(d) for d in dim_order.val] - memory_format = get_memory_format(dim_order) - assert ( - memory_format == _torch.contiguous_format - ), "Only contiguous memory format is supported in CoreML" - to(context, node) - - -# https://github.com/apple/coremltools/pull/2558 -@register_torch_op( - torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], - override=False, -) -def dequantize_affine(context, node): - inputs = _get_inputs(context, node, expected=[7, 8]) - int_data = inputs[0].val - block_size = inputs[1].val - scale = inputs[2].val - zero_point = ( - inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None - ) - # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8 - # For now, we just check that it is int8 or int32 - input_dtype = inputs[4].val # noqa: F841 - assert NUM_TO_TORCH_DTYPE[input_dtype] in [ - _torch.int8, - _torch.int32, - ], "input_dtype should be int8 or int32" - - quant_min = inputs[5].val - quant_max = inputs[6].val - - assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs" - - assert len(int_data.shape) == len( - block_size - ), "block_size must have the same length as int_data.shape" - assert block_size[0] == 1, "block_size[0] must be 1" - group_size = block_size[1] - k = int_data.shape[1] - assert k % group_size == 0, "k must be divisible by group_size" - scales_per_row = k // group_size - scale = scale.reshape(-1, scales_per_row) - if zero_point is not None: - zero_point = zero_point.reshape(-1, scales_per_row) - - # TODO: I don't know if CoreML can make use of this - # We could add a cast op to the output, but I'm pretty CoreML will remove this during a later pass - # For now, we just log a warning - out_np_dtype = None - if len(inputs) > 7: - out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val] - _logger.warning( - f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision." 
- ) - - if quant_min == -8 and quant_max == 7: - quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4")) - elif quant_min == -128 and quant_max == 127: - quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8")) - else: - raise ValueError( - f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization." - ) - - output = _utils._construct_constexpr_dequant_op( - int_data.astype(quantized_np_dtype), - zero_point, - scale, - axis=-1, - name=node.name, - ) - context.add(output, node.name) - - - -# codes: torch.Tensor, -# codebook: torch.Tensor, -# code_dtype: torch.dtype, -# block_size: List[int], -# output_dtype: torch.dtype = torch.float32, - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# This file registers torch ops that are not yet in coremltools, or are in a more recent version of -# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds -# the op to the coremltools library. - import torch as _torch from coremltools import _logger from coremltools.converters.mil.frontend import _utils @@ -164,8 +22,6 @@ def dequantize_affine(context, node): transpose, unbind, ) -import numpy as np - from coremltools.converters.mil.frontend.torch.torch_op_registry import ( register_torch_op, ) @@ -283,24 +139,25 @@ def dequantize_affine(context, node): override=False, ) def dequantize_codebook(context, node): - print("IN DEQUANTIZE CODEBOOK") inputs = _get_inputs(context, node, expected=[4, 5]) codes = inputs[0].val codebook = inputs[1].val - code_dtype = inputs[2].val - block_size = inputs[3].val + nbits = inputs[2].val - - assert len(codes.shape) == 2, "Only rank 2 inputs are supported" + # information in block_size is redundant with codebook.shape + block_size = inputs[3].val # noqa: F841 - # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns. - # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar - # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g., - # lut_block_size=(1, 8) means there is 1 LUT per 8 columns - assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook" - assert codebook.shape[-1] == 1, "we only support scalar palletization" - codebook = np.expand_dims(codebook, 0) + assert len(codes.shape) == 2, "Only rank 2 inputs are supported" + # Assert codebook is as expected. codebook.dim() = codes.dim() + 2 + assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook" + assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported" + n_luts = codebook.shape[1] + assert ( + codes.shape[1] % n_luts == 0 + ), "codes.shape[1] must be divisible by codebook.shape[1]" + assert codebook.shape[2] == 2**nbits + assert codebook.shape[3] == 1, "Only scalar look up values are supported" if len(inputs) > 4: output_dtype = inputs[4].val @@ -309,11 +166,6 @@ def dequantize_codebook(context, node): f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision." 
) - print("codes", codes.shape) - print("codebook", codebook.shape) - print("code_dtype", code_dtype) - print("block_size", block_size) - output = _utils._construct_constexpr_lut_op( codes.astype(np.int8), codebook, diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index 22e87d2790d..ee196c68e19 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -14,11 +14,9 @@ from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner -from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ -from torchao.prototype.quantization.codebook_coreml import ( - CodebookWeightOnlyConfig, -) +from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig +from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ def is_fbcode(): @@ -167,12 +165,12 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self): ], f"Got unexpected node target after delegation: {node.target.__name__}" et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) - + def test_dequantize_codebook_linear(self): model, example_inputs = self._get_test_model() quantize_( model, - CodebookWeightOnlyConfig(dtype=torch.uint8, block_size=[-1, 16]), + CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]), ) ep = torch.export.export(model, example_inputs) print("ORIGINAL MODEL", ep) @@ -190,12 +188,35 @@ def test_dequantize_codebook_linear(self): print(et_prog.exported_program()) self._compare_outputs(et_prog, model, example_inputs) + def test_dequantize_codebook_embedding(self): + model, example_inputs = self._get_test_model() + quantize_( + model, + CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + ep = torch.export.export(model, example_inputs) + delegated_program = executorch.exir.to_edge_transform_and_lower( + ep, + partitioner=[self._coreml_partitioner()], + ) + for node in delegated_program.exported_program().graph.nodes: + if node.op == "call_function": + assert node.target.__name__ in [ + "executorch_call_delegate", + "getitem", + ], f"Got unexpected node target after delegation: {node.target.__name__}" + et_prog = delegated_program.to_executorch() + print(et_prog.exported_program()) + self._compare_outputs(et_prog, model, example_inputs) + if __name__ == "__main__": test_runner = TestTorchOps() - # test_runner.test_dequantize_affine_b4w_embedding() - # test_runner.test_dequantize_affine_b4w_linear() - # test_runner.test_dequantize_affine_c4w_embedding() - # test_runner.test_dequantize_affine_c4w_linear() - # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() + test_runner.test_dequantize_affine_b4w_embedding() + test_runner.test_dequantize_affine_b4w_linear() + test_runner.test_dequantize_affine_c4w_embedding() + test_runner.test_dequantize_affine_c4w_linear() + test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() test_runner.test_dequantize_codebook_linear() + test_runner.test_dequantize_codebook_embedding() From ae46ad934200236cc1cad1c2a30754102ea32566 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 4 Aug 2025 15:40:13 -0700 Subject: [PATCH 3/8] Update torchao pin --- third-party/ao | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third-party/ao b/third-party/ao 
index 2eb4f9762d5..ca5f7887c24 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit 2eb4f9762d5f995ba44342c34039adc45d3577c2 +Subproject commit ca5f7887c24aaeab21bc4b3519ea9802f754d710 From e2a5fb5527686d8f4ed0c0838704b715fa1ffe8f Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 4 Aug 2025 16:47:53 -0700 Subject: [PATCH 4/8] up --- third-party/ao | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third-party/ao b/third-party/ao index ca5f7887c24..6bb2baf0512 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit ca5f7887c24aaeab21bc4b3519ea9802f754d710 +Subproject commit 6bb2baf05122fe5b2a0f982a63140d5832e33cf5 From 88b7f77f9592e35907907562499b0ce66698f28b Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 4 Aug 2025 17:44:28 -0700 Subject: [PATCH 5/8] up --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 40ff4eb0465..98cf935c191 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,8 @@ dependencies=[ "typing-extensions>=4.10.0", # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh "coremltools==8.3; platform_system == 'Darwin' or platform_system == 'Linux'", + # scikit-learn is used to support palettization in the coreml backend + "scikit-learn==1.7.1", "hydra-core>=1.3.0", "omegaconf>=2.3.0", ] From 5be96ca7b578f321b947d18b957cd53d50e6c95e Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 6 Aug 2025 11:29:44 -0700 Subject: [PATCH 6/8] up --- backends/apple/coreml/test/test_torch_ops.py | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index ee196c68e19..c07a936936a 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -17,7 +17,7 @@ from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ - +from executorch.exir.backend.utils import format_delegated_graph def is_fbcode(): return not hasattr(torch.version, "git_version") @@ -173,7 +173,7 @@ def test_dequantize_codebook_linear(self): CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]), ) ep = torch.export.export(model, example_inputs) - print("ORIGINAL MODEL", ep) + assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code delegated_program = executorch.exir.to_edge_transform_and_lower( ep, partitioner=[self._coreml_partitioner()], @@ -184,8 +184,11 @@ def test_dequantize_codebook_linear(self): "executorch_call_delegate", "getitem", ], f"Got unexpected node target after delegation: {node.target.__name__}" + + print(format_delegated_graph(delegated_program.exported_program().graph_module)) + + et_prog = delegated_program.to_executorch() - print(et_prog.exported_program()) self._compare_outputs(et_prog, model, example_inputs) def test_dequantize_codebook_embedding(self): @@ -196,6 +199,7 @@ def test_dequantize_codebook_embedding(self): lambda m, fqn: isinstance(m, torch.nn.Embedding), ) ep = torch.export.export(model, example_inputs) + assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code delegated_program = executorch.exir.to_edge_transform_and_lower( ep, 
partitioner=[self._coreml_partitioner()], @@ -207,16 +211,15 @@ def test_dequantize_codebook_embedding(self): "getitem", ], f"Got unexpected node target after delegation: {node.target.__name__}" et_prog = delegated_program.to_executorch() - print(et_prog.exported_program()) self._compare_outputs(et_prog, model, example_inputs) if __name__ == "__main__": test_runner = TestTorchOps() - test_runner.test_dequantize_affine_b4w_embedding() - test_runner.test_dequantize_affine_b4w_linear() - test_runner.test_dequantize_affine_c4w_embedding() - test_runner.test_dequantize_affine_c4w_linear() - test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() + # test_runner.test_dequantize_affine_b4w_embedding() + # test_runner.test_dequantize_affine_b4w_linear() + # test_runner.test_dequantize_affine_c4w_embedding() + # test_runner.test_dequantize_affine_c4w_linear() + # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() test_runner.test_dequantize_codebook_linear() test_runner.test_dequantize_codebook_embedding() From efbe98d664a79b8b3cf27c3180dfd5c375b8e3d1 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 6 Aug 2025 11:52:17 -0700 Subject: [PATCH 7/8] up --- backends/apple/coreml/test/test_torch_ops.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index c07a936936a..a90d292d49a 100644 --- a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -14,10 +14,11 @@ from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.exir.backend.utils import format_delegated_graph from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_ -from executorch.exir.backend.utils import format_delegated_graph + def is_fbcode(): return not hasattr(torch.version, "git_version") @@ -185,8 +186,10 @@ def test_dequantize_codebook_linear(self): "getitem", ], f"Got unexpected node target after delegation: {node.target.__name__}" - print(format_delegated_graph(delegated_program.exported_program().graph_module)) - + assert ( + "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default" + in format_delegated_graph(delegated_program.exported_program().graph_module) + ) et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) @@ -210,6 +213,12 @@ def test_dequantize_codebook_embedding(self): "executorch_call_delegate", "getitem", ], f"Got unexpected node target after delegation: {node.target.__name__}" + + assert ( + "executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default" + in format_delegated_graph(delegated_program.exported_program().graph_module) + ) + et_prog = delegated_program.to_executorch() self._compare_outputs(et_prog, model, example_inputs) From 5ee46af69625de4c1c3699f86c3b9968a6b097cd Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 6 Aug 2025 11:53:29 -0700 Subject: [PATCH 8/8] up --- backends/apple/coreml/test/test_torch_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py index a90d292d49a..89eab1a8b00 100644 --- 
a/backends/apple/coreml/test/test_torch_ops.py +++ b/backends/apple/coreml/test/test_torch_ops.py @@ -225,10 +225,10 @@ def test_dequantize_codebook_embedding(self): if __name__ == "__main__": test_runner = TestTorchOps() - # test_runner.test_dequantize_affine_b4w_embedding() - # test_runner.test_dequantize_affine_b4w_linear() - # test_runner.test_dequantize_affine_c4w_embedding() - # test_runner.test_dequantize_affine_c4w_linear() - # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() + test_runner.test_dequantize_affine_b4w_embedding() + test_runner.test_dequantize_affine_b4w_linear() + test_runner.test_dequantize_affine_c4w_embedding() + test_runner.test_dequantize_affine_c4w_linear() + test_runner.test_dequantize_affine_c8w_embedding_b4w_linear() test_runner.test_dequantize_codebook_linear() test_runner.test_dequantize_codebook_embedding()
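
Taken together, the series teaches the CoreML backend to lower torchao's codebook (palettization) quantization: quant::dequantize_codebook is converted to a constexpr LUT op in backends/apple/coreml/compiler/torch_ops.py, and the new tests drive it end to end through quantize_ with CodebookWeightOnlyConfig. The sketch below is a minimal illustration assembled from those tests (test_dequantize_codebook_linear / test_dequantize_codebook_embedding); the toy model, example inputs, and the bare CoreMLPartitioner() construction are placeholders, since the tests' _get_test_model and _coreml_partitioner helpers are not part of this diff and the real partitioner may be built with compile specs (e.g., a newer minimum deployment target).

    # Hedged sketch mirroring the new codebook tests; model shape and partitioner
    # arguments are assumptions, not taken from this patch series.
    import torch
    import executorch.exir

    from executorch.backends.apple.coreml.partition import CoreMLPartitioner
    from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
    from torchao.quantization import quantize_


    class TinyModel(torch.nn.Module):  # stand-in for self._get_test_model()
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.Embedding(64, 32)
            self.lin = torch.nn.Linear(32, 16)

        def forward(self, ids):
            return self.lin(self.emb(ids))


    model = TinyModel().eval()
    example_inputs = (torch.randint(0, 64, (1, 8)),)

    # Palettize linear weights with a 2-bit codebook, one LUT per 16 columns,
    # as in test_dequantize_codebook_linear.
    quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]))

    # Embeddings can be palettized separately with a filter function, as in
    # test_dequantize_codebook_embedding (which uses torch.uint3):
    # quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]),
    #           lambda m, fqn: isinstance(m, torch.nn.Embedding))

    ep = torch.export.export(model, example_inputs)
    assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code

    # Lower to CoreML; a bare partitioner is used here for brevity.
    delegated = executorch.exir.to_edge_transform_and_lower(
        ep, partitioner=[CoreMLPartitioner()]
    )
    et_prog = delegated.to_executorch()

As in the updated tests, the delegated graph can then be inspected with format_delegated_graph to confirm that the edge-dialect quant.dequantize_codebook node was consumed by the CoreML delegate rather than left in the outer graph.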