Skip to content

Commit cb7812a

Browse files
fix dynamic shape handling
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 4e93a74 commit cb7812a

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

keras/src/export/openvino.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,14 @@ def parameterize_inputs(inputs, prefix=""):
8080
parameters = [p.output.get_node() for p in tree.flatten(params)]
8181
results = [ov_opset.result(r.output) for r in tree.flatten(outputs)]
8282
ov_model = ov.Model(results=results, parameters=parameters)
83-
for param in ov_model.inputs:
84-
rank = len(param.get_partial_shape())
85-
dynamic_shape = ov.PartialShape([-1] * rank)
86-
param.get_node().set_partial_shape(dynamic_shape)
83+
flat_specs = tree.flatten(input_signature)
84+
for ov_input, spec in zip(ov_model.inputs, flat_specs):
85+
# Respect the dynamic axes from the original input signature.
86+
dynamic_shape_dims = [
87+
-1 if dim is None else dim for dim in spec.shape
88+
]
89+
dynamic_shape = ov.PartialShape(dynamic_shape_dims)
90+
ov_input.get_node().set_partial_shape(dynamic_shape)
8791

8892
elif backend.backend() == "tensorflow":
8993
import tempfile

0 commit comments

Comments
 (0)