From e2b6750fc747f81857fc708e9fac330ce2d7b477 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Tue, 15 Jul 2025 13:57:59 -0700 Subject: [PATCH] Fix delegate node metadata (#12504) Summary: The delegate node's `meta["val"]` was copied from the submodule output node's own `meta["val"]` instead of being built as a list with one entry per output argument of that node, causing deserialization to fail. Reviewed By: mcr229 Differential Revision: D78350040 --- backends/arm/test/tester/arm_tester.py | 2 +- exir/backend/backend_api.py | 6 +- .../test/demos/test_xnnpack_qnnpack.py | 58 ++++++++++++++++++- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 99d95a2bc8e..7c160a7e0db 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -726,7 +726,7 @@ def _get_dtype_distribution( if node.op == "placeholder": placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": - if "val" in node.meta: + if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 724bbf3fcf6..6bb2df3dfdb 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -235,7 +235,9 @@ def generate_debug_handle(ep: ExportedProgram) -> int: call_submodule_node.kwargs, ) call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program) - call_delegate_node.meta["val"] = submodule_output_node.meta["val"] + call_delegate_node.meta["val"] = [ + out_arg.meta["val"] for out_arg in submodule_output_node.args[0] + ] call_submodule_node.replace_all_uses_with(call_delegate_node) owning_graph_module.graph.erase_node(call_submodule_node) if is_submodule: @@ -472,11 +474,9 @@ def _create_partitions_in_graph_module( tagged_graph_module, node_list, tag ) - tagged_graph_module_output_node = 
tagged_graph_module.graph.output_node() submodule_output_node = submodule.graph.output_node() # Copy the output node meta from the original output node, because # create_submodule_from_nodes doesn't cover the meta field - submodule_output_node.meta = tagged_graph_module_output_node.meta logging.debug(f"Partitioned graph module: {tagged_graph_module}") ( submodule_program, diff --git a/exir/backend/test/demos/test_xnnpack_qnnpack.py b/exir/backend/test/demos/test_xnnpack_qnnpack.py index 5cbd7f7f659..7600988e19d 100644 --- a/exir/backend/test/demos/test_xnnpack_qnnpack.py +++ b/exir/backend/test/demos/test_xnnpack_qnnpack.py @@ -4,8 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import tempfile import unittest +from typing import Tuple + import executorch.exir as exir import torch @@ -20,7 +23,13 @@ # import the xnnpack backend implementation from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend -from executorch.exir import CaptureConfig +from executorch.exir import ( + CaptureConfig, + EdgeCompileConfig, + EdgeProgramManager, + to_edge_transform_and_lower, +) + from executorch.exir.backend.backend_api import to_backend, validation_disabled from executorch.exir.passes.spec_prop_pass import SpecPropPass @@ -132,3 +141,50 @@ def forward(self, x, y): self.assertTrue( torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03) ) + + def test_serde(self): + # The module with blank_logprobs() function + class BlankLogProbsModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(768, 1) + self.log_sigmoid = torch.nn.LogSigmoid() + + def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor: + tanh_out = torch.tanh(joint_encodings) + linear_out = self.linear(tanh_out) + blank_output = self.log_sigmoid(linear_out) + return blank_output + + def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, 
...]: + """ + Get the input to the blank_logprobs() and nonblank_logprobs() functions. + """ + return (torch.randn(1, 1, 1, 768),) + + model = BlankLogProbsModule() + # Get the inputs for the logprobs function + logprobs_fake_inputs = get_blank_logprobs_inputs_fn() + + # Export and partition + aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True) + partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower( + aten_prog, + partitioner=[XnnpackFloatingPointPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _use_edge_ops=True, + ), + ) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + exir.save(partitioned_prog.exported_program(), f.name) + f.seek(0) + loaded_model = exir.load(f.name) + + self.assertTrue( + torch.allclose( + model(*logprobs_fake_inputs), + loaded_model.module()(*logprobs_fake_inputs), + ) + )