Skip to content

Commit 6e0aecd

Browse files
[OpenVINO backend] handle scalar updates for slice_update
1 parent 1f8f2ea commit 6e0aecd

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

keras/src/backend/openvino/core.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,22 @@ def slice_update(inputs, start_indices, updates):
843843

844844
updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
845845
rank = updates_tensor.get_partial_shape().rank.get_length()
846+
if rank == 0:
847+
# Handle scalar update
848+
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(
849+
0
850+
)
851+
# For scatter_nd_update,
852+
# indices should be of shape [num_updates, rank_of_inputs]
853+
# and updates should be of shape [num_updates]. Here num_updates is 1.
854+
absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(
855+
0
856+
)
857+
updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)
858+
result = ov_opset.scatter_nd_update(
859+
inputs, absolute_indices, updates_flat
860+
).output(0)
861+
return OpenVINOKerasTensor(result)
846862

847863
# Compute the total number of elements in the updates tensor.
848864
# Example:

0 commit comments

Comments
 (0)