Fix delegate node metadata #12504

Merged · 1 commit · Jul 16, 2025
backends/arm/test/tester/arm_tester.py (1 addition, 1 deletion)
@@ -726,7 +726,7 @@ def _get_dtype_distribution(
         if node.op == "placeholder":
             placeholder_dtypes.append(str(node.meta["val"].dtype))
         if node.op == "call_function":
-            if "val" in node.meta:
+            if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
                 dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec)
                 call_function_dtypes.append(ts.DTypeNames[dtype])
     return Counter(placeholder_dtypes), Counter(call_function_dtypes)
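Why the extra check helps: with the backend_api.py change below, a delegated call_function node's meta["val"] becomes a list of fake tensors (one per delegate output) rather than a single tensor, and extract_tensor_meta reads tensor attributes. A minimal sketch of the guard, using hypothetical stand-in metadata dicts rather than real FX nodes:

    import torch

    # Stand-in metadata dicts, for illustration only.
    op_meta = {"val": torch.empty(2, 3)}                          # ordinary op: one tensor
    delegate_meta = {"val": [torch.empty(2, 3), torch.empty(2)]}  # delegate: one entry per output

    for meta in (op_meta, delegate_meta):
        if "val" in meta and isinstance(meta["val"], torch.Tensor):
            print("tensor meta:", meta["val"].dtype)  # safe to read tensor attributes
        else:
            print("skipping non-tensor meta of type", type(meta["val"]).__name__)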
exir/backend/backend_api.py (3 additions, 3 deletions)
@@ -235,7 +235,9 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
             call_submodule_node.kwargs,
         )
         call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program)
-        call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
+        call_delegate_node.meta["val"] = [
+            out_arg.meta["val"] for out_arg in submodule_output_node.args[0]
+        ]
         call_submodule_node.replace_all_uses_with(call_delegate_node)
         owning_graph_module.graph.erase_node(call_submodule_node)
         if is_submodule:
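For context, an FX output node stores the graph's returned values in its first argument, so submodule_output_node.args[0] is a sequence of value nodes, one per delegate output; collecting each node's meta["val"] gives the delegate call a list of fake tensors describing every output. A small self-contained sketch with a hypothetical two-output module (not code from this PR):

    import torch
    from torch.export import export

    class TwoOutputs(torch.nn.Module):
        def forward(self, x):
            return x + 1, x * 2

    ep = export(TwoOutputs(), (torch.randn(2, 4),))
    output_node = ep.graph_module.graph.output_node()
    # args[0] is the tuple of returned value nodes; each carries its own fake tensor.
    vals = [arg.meta["val"] for arg in output_node.args[0]]
    print([(tuple(v.shape), v.dtype) for v in vals])  # one entry per output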
@@ -472,11 +474,9 @@ def _create_partitions_in_graph_module(
             tagged_graph_module, node_list, tag
         )

-        tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
-        submodule_output_node = submodule.graph.output_node()
-        # Copy the output node meta from the original output node, because
-        # create_submodule_from_nodes doesn't cover the meta field
-        submodule_output_node.meta = tagged_graph_module_output_node.meta
         logging.debug(f"Partitioned graph module: {tagged_graph_module}")
         (
             submodule_program,
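The deleted lines worked around missing submodule output metadata by assigning the parent graph's whole output-node meta onto the submodule's output node. With the delegate's meta["val"] now built per output above, that copy is no longer needed; note also that node.meta is a plain dict, so the assignment aliased the parent node's meta rather than copying it. A tiny sketch of that aliasing, with made-up values:

    # Plain dict assignment aliases; this mirrors `a.meta = b.meta` on FX nodes.
    src_meta = {"val": "parent graph outputs", "debug_handle": 7}
    dst_meta = src_meta       # alias, not a copy
    dst_meta["val"] = "mutated by a later pass"
    print(src_meta["val"])    # prints the mutated value: both names share one dict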
exir/backend/test/demos/test_xnnpack_qnnpack.py (57 additions, 1 deletion)
@@ -4,8 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import tempfile
 import unittest

+from typing import Tuple
+
 import executorch.exir as exir

 import torch
@@ -20,7 +23,13 @@
 # import the xnnpack backend implementation
 from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend

-from executorch.exir import CaptureConfig
+from executorch.exir import (
+    CaptureConfig,
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    to_edge_transform_and_lower,
+)
+
 from executorch.exir.backend.backend_api import to_backend, validation_disabled
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
@@ -132,3 +141,50 @@ def forward(self, x, y):
         self.assertTrue(
             torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
         )
+
+    def test_serde(self):
+        # The module with the blank_logprobs() function
+        class BlankLogProbsModule(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(768, 1)
+                self.log_sigmoid = torch.nn.LogSigmoid()
+
+            def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor:
+                tanh_out = torch.tanh(joint_encodings)
+                linear_out = self.linear(tanh_out)
+                blank_output = self.log_sigmoid(linear_out)
+                return blank_output
+
+        def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, ...]:
+            """
+            Get the input to the blank_logprobs() and nonblank_logprobs() functions.
+            """
+            return (torch.randn(1, 1, 1, 768),)
+
+        model = BlankLogProbsModule()
+        # Get the inputs for the logprobs function
+        logprobs_fake_inputs = get_blank_logprobs_inputs_fn()
+
+        # Export and partition
+        aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True)
+        partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower(
+            aten_prog,
+            partitioner=[XnnpackFloatingPointPartitioner()],
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+                _use_edge_ops=True,
+            ),
+        )
+
+        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
+            exir.save(partitioned_prog.exported_program(), f.name)
+            f.seek(0)
+            loaded_model = exir.load(f.name)
+
+        self.assertTrue(
+            torch.allclose(
+                model(*logprobs_fake_inputs),
+                loaded_model.module()(*logprobs_fake_inputs),
+            )
+        )
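The round-trip above checks that a lowered program with the list-valued delegate meta survives exir.save and exir.load with numerical parity. As a hedged follow-up sketch (not part of this PR), the same EdgeProgramManager could then be compiled for the ExecuTorch runtime:

    # Assumes `partitioned_prog` from the test above is in scope.
    exec_prog = partitioned_prog.to_executorch()   # produces an ExecutorchProgramManager
    with open("/tmp/blank_logprobs.pte", "wb") as f:
        f.write(exec_prog.buffer)                  # flatbuffer loadable by the runtime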