Skip to content

Commit 07c8f0f

Browse files
Arm backend: Change node.name to buffer name when indexing state_dict (#12485)
Using node.name for indexing state_dict can result in the following: `KeyError: 'b__tensor_constant1'`. Using the following to fetch the buffer name is the correct way to do it: ``` buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ node.name ] ``` Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 5128307 commit 07c8f0f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

backends/arm/_passes/cast_int64_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
4747
buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
4848
node.name
4949
]
50-
buffer = self.exported_program.state_dict[node.name]
50+
buffer = self.exported_program.state_dict[buffer_name]
5151
self._assert_within_int32(buffer, node)
5252
logger.warning(
5353
f"Casting buffer {node.name} from torch.int64 to torch.int32"

0 commit comments

Comments
 (0)