diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 36307cbf4e7..85e9889ca36 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -457,7 +457,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 for node in original_nodes: if len(node.all_input_nodes) == 0: # This node has no inputs so we don't need to change anything, but still need to tag input nodes - if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): + if ( + "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + and len(node.meta["val"].shape) == 4 + ): if node.meta["val"].is_contiguous(): self.mark_as_nchw_node(node) else: