Skip to content

Commit 1f8f2ea

Browse files
add more detailed comments
1 parent 8093d51 commit 1f8f2ea

File tree

1 file changed

+42
-7
lines changed

1 file changed

+42
-7
lines changed

keras/src/backend/openvino/core.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -844,26 +844,46 @@ def slice_update(inputs, start_indices, updates):
844844
updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
845845
rank = updates_tensor.get_partial_shape().rank.get_length()
846846

847+
# Compute the total number of elements in the updates tensor.
848+
# Example:
849+
# if updates.shape = [2, 3], total_elements = 6.
847850
total_elements = ov_opset.reduce_prod(
848851
updates_shape, zero_tensor, keep_dims=False
849852
).output(0)
850853

851-
# Create a single range for all indices
854+
# Generate a flat range [0, 1, ..., total_elements-1].
855+
# This will be used to enumerate all positions in the updates tensor.
852856
flat_indices = ov_opset.range(
853857
zero_scalar, total_elements, one_scalar, output_type=Type.i32
854858
).output(0)
855859

856860
dim_sizes = []
857861
strides = []
858862

863+
# For each dimension, compute its size and the stride.
864+
# (number of elements to skip to move to the next index in this dimension).
865+
# Example:
866+
# for shape [2, 3], strides = [3, 1].
859867
for dim in range(rank):
860868
dim_size = ov_opset.gather(
861869
updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
862870
).output(0)
863871
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
864872
dim_sizes.append(dim_size_scalar)
865873

866-
# Compute stride (product of dimensions after current)
874+
# Strides to convert a flat index into a multi-dimensional index.
875+
# This allows us to map each element in the flattened updates tensor
876+
# to its correct N-dimensional position, so we can compute the absolute
877+
# index in the input tensor for the scatter update.
878+
# Stride for a dimension is the product of all dimensions after it.
879+
# For the last dimension, stride is 1.
880+
# Example:
881+
# For a 3D tensor with shape [2, 3, 4]:
882+
# - stride for dim=0 (first axis) is 3*4=12
883+
# (to move to the next "block" along axis 0)
884+
# - stride for dim=1 is 4 (to move to the next row along axis 1)
885+
# - stride for dim=2 is 1 (to move to the next element along axis 2)
886+
# This is equivalent to how numpy flattens multi-dimensional arrays.
867887
if dim < rank - 1:
868888
remaining_dims = ov_opset.slice(
869889
updates_shape,
@@ -880,8 +900,10 @@ def slice_update(inputs, start_indices, updates):
880900
strides.append(stride)
881901

882902
coord_tensors = []
903+
# For each dimension, compute the coordinate for every flat index.
904+
# Example:
905+
# for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
883906
for dim in range(rank):
884-
# Calculate coordinates for this dimension
885907
coords = ov_opset.mod(
886908
ov_opset.divide(flat_indices, strides[dim]).output(0),
887909
dim_sizes[dim],
@@ -890,19 +912,25 @@ def slice_update(inputs, start_indices, updates):
890912

891913
coord_tensors_unsqueezed = []
892914
for coord in coord_tensors:
915+
# Unsqueeze to make each coordinate a column vector for concatenation.
893916
coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
894917
coord_tensors_unsqueezed.append(coord_unsqueezed)
895918

896-
# Create index matrix [total_elements, rank] in one operation
919+
# Concatenate all coordinate columns to form [total_elements, rank] matrix.
920+
# Each row is a multi-dimensional index into the updates tensor.
921+
# Example:
922+
# for shape [2, 3], row 4 = [1, 1].
897923
indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)
898924

899-
# Broadcast start indices
925+
# Broadcast start indices to match the number of updates.
926+
# Example:
927+
# start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
928+
# start_broadcast = [[2,3],[2,3],...]
900929
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
901930
start_reshaped = ov_opset.reshape(
902931
start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
903932
).output(0)
904933

905-
# Broadcast to match indices matrix shape
906934
broadcast_shape = ov_opset.concat(
907935
[
908936
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
@@ -913,14 +941,21 @@ def slice_update(inputs, start_indices, updates):
913941

914942
start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)
915943

916-
# Add offset to get absolute indices
944+
# Add the broadcasted start indices to the relative indices
945+
# to get absolute indices in the input tensor.
946+
# Example:
947+
# if start=(2,3), update index [1,1] -> absolute index [3,4].
917948
absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)
918949

950+
# Flatten the updates tensor to match the flat indices.
919951
updates_flat = ov_opset.reshape(
920952
updates_tensor,
921953
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
922954
special_zero=False,
923955
).output(0)
956+
957+
# Perform the scatter update: for each absolute index,
958+
# set the corresponding value from updates_flat.
924959
result = ov_opset.scatter_nd_update(
925960
inputs, absolute_indices, updates_flat
926961
).output(0)

0 commit comments

Comments
 (0)