From 4d7f6b1e0b2bf961a01f2d9aa35786111185e3ab Mon Sep 17 00:00:00 2001 From: Michael Adragna Date: Tue, 15 Jul 2025 15:02:45 -0700 Subject: [PATCH] Fix Transpose Optimization Bug With non-4D Tensor Input (#12520) Summary: Fixed a bug in the channels last tagged reshape pass, where non-4D inputs were being tagged as contiguous/channels last memory formats, which is not expected, as these formats only apply to 4D tensors. The repro is in N7569847. The fix was completed by checking the tensor's rank (number of dimensions) before tagging input nodes. Reviewed By: mcr229 Differential Revision: D78357428 --- .../xnnpack/_passes/channels_last_tagged_reshape_pass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 36307cbf4e7..85e9889ca36 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -457,7 +457,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 for node in original_nodes: if len(node.all_input_nodes) == 0: # This node has no inputs so we don't need to change anything, but still need to tag input nodes - if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): + if ( + "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + and len(node.meta["val"].shape) == 4 + ) if node.meta["val"].is_contiguous(): self.mark_as_nchw_node(node) else: