Skip to content

Commit 01e1805

Browse files
angelayi authored and facebook-github-bot committed
Fix delegate node metadata (#12504)
Summary: The delegate node's metadata was set incorrectly, causing deserialization to fail. Reviewed By: mcr229. Differential Revision: D78350040.
1 parent a8070ec commit 01e1805

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

exir/backend/backend_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
235235
call_submodule_node.kwargs,
236236
)
237237
call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program)
238-
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
238+
call_delegate_node.meta["val"] = [
239+
out_arg.meta["val"] for out_arg in submodule_output_node.args[0]
240+
]
239241
call_submodule_node.replace_all_uses_with(call_delegate_node)
240242
owning_graph_module.graph.erase_node(call_submodule_node)
241243
if is_submodule:
@@ -472,11 +474,9 @@ def _create_partitions_in_graph_module(
472474
tagged_graph_module, node_list, tag
473475
)
474476

475-
tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
476477
submodule_output_node = submodule.graph.output_node()
477478
# Copy the output node meta from the original output node, because
478479
# create_submodule_from_nodes doesn't cover the meta field
479-
submodule_output_node.meta = tagged_graph_module_output_node.meta
480480
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
481481
(
482482
submodule_program,

exir/backend/test/demos/test_xnnpack_qnnpack.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import tempfile
78
import unittest
89

10+
from typing import Tuple
11+
912
import executorch.exir as exir
1013

1114
import torch
@@ -20,7 +23,13 @@
2023
# import the xnnpack backend implementation
2124
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
2225

23-
from executorch.exir import CaptureConfig
26+
from executorch.exir import (
27+
CaptureConfig,
28+
EdgeCompileConfig,
29+
EdgeProgramManager,
30+
to_edge_transform_and_lower,
31+
)
32+
2433
from executorch.exir.backend.backend_api import to_backend, validation_disabled
2534
from executorch.exir.passes.spec_prop_pass import SpecPropPass
2635

@@ -41,7 +50,6 @@
4150
prepare_fx,
4251
)
4352

44-
4553
class TestXnnQnnBackends(unittest.TestCase):
4654
def test_add_xnnpack_and_dqlinear_qnn(self):
4755
qconfig_mapping = QConfigMapping().set_object_type(
@@ -132,3 +140,49 @@ def forward(self, x, y):
132140
self.assertTrue(
133141
torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
134142
)
143+
144+
def test_serde(self):
145+
# The module with blank_logprobs() function
146+
class BlankLogProbsModule(torch.nn.Module):
147+
def __init__(self) -> None:
148+
super().__init__()
149+
self.linear = torch.nn.Linear(768, 1)
150+
self.log_sigmoid = torch.nn.LogSigmoid()
151+
152+
def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor:
153+
tanh_out = torch.tanh(joint_encodings)
154+
linear_out = self.linear(tanh_out)
155+
blank_output = self.log_sigmoid(linear_out)
156+
return blank_output
157+
158+
def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, ...]:
159+
"""
160+
Get the input to the blank_logprobs() and nonblank_logprobs() functions.
161+
"""
162+
return (torch.randn(1, 1, 1, 768),)
163+
164+
model = BlankLogProbsModule()
165+
# Get the inputs for the logprobs function
166+
logprobs_fake_inputs = get_blank_logprobs_inputs_fn()
167+
168+
# Export and partition
169+
aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True)
170+
partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower(
171+
aten_prog,
172+
partitioner=[XnnpackFloatingPointPartitioner()],
173+
compile_config=EdgeCompileConfig(
174+
_check_ir_validity=False, _use_edge_ops=True,
175+
),
176+
)
177+
178+
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
179+
exir.save(partitioned_prog.exported_program(), f.name)
180+
f.seek(0)
181+
loaded_model = exir.load(f.name)
182+
183+
self.assertTrue(
184+
torch.allclose(
185+
model(*logprobs_fake_inputs),
186+
loaded_model.module()(*logprobs_fake_inputs),
187+
)
188+
)

0 commit comments

Comments (0)