Skip to content

Commit 882e806

Browse files
Merge branch 'support_slice_update' into gsoc2025
2 parents 4e4cead + 6e0aecd commit 882e806

File tree

2 files changed

+160
-5
lines changed

2 files changed

+160
-5
lines changed

keras/src/backend/openvino/core.py

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,10 +842,168 @@ def prepare_slice_index(val):
842842

843843

844844
def slice_update(inputs, start_indices, updates):
845-
raise NotImplementedError(
846-
"`slice_update` is not supported with openvino backend"
845+
inputs = get_ov_output(inputs)
846+
updates_tensor = get_ov_output(updates)
847+
848+
if isinstance(start_indices, (list, np.ndarray)):
849+
start_indices = tuple(start_indices)
850+
assert isinstance(start_indices, tuple), (
851+
"`slice_update` is not supported by openvino backend"
852+
" for `start_indices` of type {}".format(type(start_indices))
847853
)
848854

855+
zero_scalar = ov_opset.constant(0, Type.i32)
856+
one_scalar = ov_opset.constant(1, Type.i32)
857+
zero_tensor = ov_opset.constant([0], Type.i32)
858+
one_tensor = ov_opset.constant([1], Type.i32)
859+
860+
processed_start_indices = []
861+
for idx in start_indices:
862+
val = get_ov_output(idx)
863+
if not val.get_element_type().is_integral():
864+
raise ValueError("`slice_update` requires integral start_indices")
865+
if val.get_element_type() != Type.i32:
866+
val = ov_opset.convert(val, Type.i32).output(0)
867+
if val.get_partial_shape().rank.get_length() == 0:
868+
val = ov_opset.unsqueeze(val, zero_scalar).output(0)
869+
processed_start_indices.append(val)
870+
871+
updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
872+
rank = updates_tensor.get_partial_shape().rank.get_length()
873+
if rank == 0:
874+
# Handle scalar update
875+
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(
876+
0
877+
)
878+
# For scatter_nd_update,
879+
# indices should be of shape [num_updates, rank_of_inputs]
880+
# and updates should be of shape [num_updates]. Here num_updates is 1.
881+
absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(
882+
0
883+
)
884+
updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)
885+
result = ov_opset.scatter_nd_update(
886+
inputs, absolute_indices, updates_flat
887+
).output(0)
888+
return OpenVINOKerasTensor(result)
889+
890+
# Compute the total number of elements in the updates tensor.
891+
# Example:
892+
# if updates.shape = [2, 3], total_elements = 6.
893+
total_elements = ov_opset.reduce_prod(
894+
updates_shape, zero_tensor, keep_dims=False
895+
).output(0)
896+
897+
# Generate a flat range [0, 1, ..., total_elements-1].
898+
# This will be used to enumerate all positions in the updates tensor.
899+
flat_indices = ov_opset.range(
900+
zero_scalar, total_elements, one_scalar, output_type=Type.i32
901+
).output(0)
902+
903+
dim_sizes = []
904+
strides = []
905+
906+
# For each dimension, compute its size and the stride.
907+
# (number of elements to skip to move to the next index in this dimension).
908+
# Example:
909+
# for shape [2, 3], strides = [3, 1].
910+
for dim in range(rank):
911+
dim_size = ov_opset.gather(
912+
updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
913+
).output(0)
914+
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
915+
dim_sizes.append(dim_size_scalar)
916+
917+
# Strides to convert a flat index into a multi-dimensional index.
918+
# This allows us to map each element in the flattened updates tensor
919+
# to its correct N-dimensional position, so we can compute the absolute
920+
# index in the input tensor for the scatter update.
921+
# Stride for a dimension is the product of all dimensions after it.
922+
# For the last dimension, stride is 1.
923+
# Example:
924+
# For a 3D tensor with shape [2, 3, 4]:
925+
# - stride for dim=0 (first axis) is 3*4=12
926+
# (to move to the next "block" along axis 0)
927+
# - stride for dim=1 is 4 (to move to the next row along axis 1)
928+
# - stride for dim=2 is 1 (to move to the next element along axis 2)
929+
# This is equivalent to how numpy flattens multi-dimensional arrays.
930+
if dim < rank - 1:
931+
remaining_dims = ov_opset.slice(
932+
updates_shape,
933+
ov_opset.constant([dim + 1], Type.i32),
934+
ov_opset.constant([rank], Type.i32),
935+
one_tensor,
936+
zero_tensor,
937+
).output(0)
938+
stride = ov_opset.reduce_prod(
939+
remaining_dims, zero_tensor, keep_dims=False
940+
).output(0)
941+
else:
942+
stride = one_scalar
943+
strides.append(stride)
944+
945+
coord_tensors = []
946+
# For each dimension, compute the coordinate for every flat index.
947+
# Example:
948+
# for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
949+
for dim in range(rank):
950+
coords = ov_opset.mod(
951+
ov_opset.divide(flat_indices, strides[dim]).output(0),
952+
dim_sizes[dim],
953+
).output(0)
954+
coord_tensors.append(coords)
955+
956+
coord_tensors_unsqueezed = []
957+
for coord in coord_tensors:
958+
# Unsqueeze to make each coordinate a column vector for concatenation.
959+
coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
960+
coord_tensors_unsqueezed.append(coord_unsqueezed)
961+
962+
# Concatenate all coordinate columns to form [total_elements, rank] matrix.
963+
# Each row is a multi-dimensional index into the updates tensor.
964+
# Example:
965+
# for shape [2, 3], row 4 = [1, 1].
966+
indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)
967+
968+
# Broadcast start indices to match the number of updates.
969+
# Example:
970+
# start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
971+
# start_broadcast = [[2,3],[2,3],...]
972+
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
973+
start_reshaped = ov_opset.reshape(
974+
start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
975+
).output(0)
976+
977+
broadcast_shape = ov_opset.concat(
978+
[
979+
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
980+
one_tensor,
981+
],
982+
axis=0,
983+
).output(0)
984+
985+
start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)
986+
987+
# Add the broadcasted start indices to the relative indices
988+
# to get absolute indices in the input tensor.
989+
# Example:
990+
# if start=(2,3), update index [1,1] -> absolute index [3,4].
991+
absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)
992+
993+
# Flatten the updates tensor to match the flat indices.
994+
updates_flat = ov_opset.reshape(
995+
updates_tensor,
996+
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
997+
special_zero=False,
998+
).output(0)
999+
1000+
# Perform the scatter update: for each absolute index,
1001+
# set the corresponding value from updates_flat.
1002+
result = ov_opset.scatter_nd_update(
1003+
inputs, absolute_indices, updates_flat
1004+
).output(0)
1005+
return OpenVINOKerasTensor(result)
1006+
8491007

8501008
def while_loop(
8511009
cond,

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,14 @@ CoreOpsCallsTests::test_map_basic_call
171171
CoreOpsCallsTests::test_scan_basic_call
172172
CoreOpsCallsTests::test_scatter_basic_call
173173
CoreOpsCallsTests::test_scatter_update_basic_call
174-
CoreOpsCallsTests::test_slice_update_basic_call
175174
CoreOpsCallsTests::test_switch_basic_call
176175
CoreOpsCallsTests::test_unstack_basic_functionality
177176
CoreOpsCorrectnessTest::test_associative_scan
178177
CoreOpsCorrectnessTest::test_cond
179-
CoreOpsCorrectnessTest::test_dynamic_slice
180178
CoreOpsCorrectnessTest::test_fori_loop
181179
CoreOpsCorrectnessTest::test_map
182180
CoreOpsCorrectnessTest::test_scan
183181
CoreOpsCorrectnessTest::test_scatter
184-
CoreOpsCorrectnessTest::test_slice_update
185182
CoreOpsCorrectnessTest::test_switch
186183
CoreOpsCorrectnessTest::test_unstack
187184
CoreOpsCorrectnessTest::test_vectorized_map

0 commit comments

Comments
 (0)