diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 0cdd0422b61..8052c8fd2ce 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -47,7 +47,7 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ node.name ] - buffer = self.exported_program.state_dict[node.name] + buffer = self.exported_program.state_dict[buffer_name] self._assert_within_int32(buffer, node) logger.warning( f"Casting buffer {node.name} from torch.int64 to torch.int32"