@@ -844,26 +844,46 @@ def slice_update(inputs, start_indices, updates):
     updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
     rank = updates_tensor.get_partial_shape().rank.get_length()
 
+    # Compute the total number of elements in the updates tensor.
+    # Example:
+    # if updates.shape = [2, 3], total_elements = 6.
     total_elements = ov_opset.reduce_prod(
         updates_shape, zero_tensor, keep_dims=False
     ).output(0)
 
-    # Create a single range for all indices
+    # Generate a flat range [0, 1, ..., total_elements - 1].
+    # This will be used to enumerate all positions in the updates tensor.
     flat_indices = ov_opset.range(
         zero_scalar, total_elements, one_scalar, output_type=Type.i32
     ).output(0)
 
     dim_sizes = []
     strides = []
 
+    # For each dimension, compute its size and its stride
+    # (number of elements to skip to move to the next index in this dimension).
+    # Example:
+    # for shape [2, 3], strides = [3, 1].
     for dim in range(rank):
         dim_size = ov_opset.gather(
             updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
         ).output(0)
         dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
         dim_sizes.append(dim_size_scalar)
 
-        # Compute stride (product of dimensions after current)
+        # Strides convert a flat index into a multi-dimensional index.
+        # This allows us to map each element in the flattened updates tensor
+        # to its correct N-dimensional position, so we can compute the absolute
+        # index in the input tensor for the scatter update.
+        # The stride for a dimension is the product of all dimensions after it.
+        # For the last dimension, the stride is 1.
+        # Example:
+        # For a 3D tensor with shape [2, 3, 4]:
+        # - stride for dim=0 (first axis) is 3*4=12
+        #   (to move to the next "block" along axis 0)
+        # - stride for dim=1 is 4 (to move to the next row along axis 1)
+        # - stride for dim=2 is 1 (to move to the next element along axis 2)
+        # This is equivalent to how numpy flattens multi-dimensional arrays.
         if dim < rank - 1:
             remaining_dims = ov_opset.slice(
                 updates_shape,
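A minimal plain-Python sketch of the stride rule the new comments describe (the helper name row_major_strides and the example shapes are only illustrative, not part of the change):

# Illustrative helper: row-major strides for a given shape.
# The stride of a dimension is the product of all dimension sizes after it;
# the last dimension has stride 1.
def row_major_strides(shape):
    strides = []
    for dim in range(len(shape)):
        stride = 1
        for later_size in shape[dim + 1:]:
            stride *= later_size
        strides.append(stride)
    return strides

assert row_major_strides([2, 3]) == [3, 1]
assert row_major_strides([2, 3, 4]) == [12, 4, 1]  # matches the comment example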
@@ -880,8 +900,10 @@ def slice_update(inputs, start_indices, updates):
         strides.append(stride)
 
     coord_tensors = []
+    # For each dimension, compute the coordinate for every flat index.
+    # Example:
+    # for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
     for dim in range(rank):
-        # Calculate coordinates for this dimension
        coords = ov_opset.mod(
            ov_opset.divide(flat_indices, strides[dim]).output(0),
            dim_sizes[dim],
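The divide/mod pattern above is the standard flat-index decomposition; a small sketch under the same [2, 3] example (the unravel helper is hypothetical, shown only to mirror the graph ops):

# coord[dim] = (flat_index // stride[dim]) % dim_size[dim],
# mirroring the ov_opset.divide / ov_opset.mod calls in the loop above.
def unravel(flat_index, shape, strides):
    return [(flat_index // strides[d]) % shape[d] for d in range(len(shape))]

assert unravel(4, [2, 3], [3, 1]) == [1, 1]  # flat index 4 -> row 1, col 1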
@@ -890,19 +912,25 @@ def slice_update(inputs, start_indices, updates):
 
     coord_tensors_unsqueezed = []
     for coord in coord_tensors:
+        # Unsqueeze to make each coordinate a column vector for concatenation.
         coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
         coord_tensors_unsqueezed.append(coord_unsqueezed)
 
-    # Create index matrix [total_elements, rank] in one operation
+    # Concatenate the coordinate columns into a [total_elements, rank] matrix.
+    # Each row is a multi-dimensional index into the updates tensor.
+    # Example:
+    # for shape [2, 3], row 4 = [1, 1].
     indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)
 
-    # Broadcast start indices
+    # Broadcast start indices to match the number of updates.
+    # Example:
+    # start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
+    # start_broadcast = [[2,3],[2,3],...]
     start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
     start_reshaped = ov_opset.reshape(
         start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
     ).output(0)
 
-    # Broadcast to match indices matrix shape
     broadcast_shape = ov_opset.concat(
         [
             ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
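For reference, the same index matrix and start broadcast can be sketched with NumPy (an update of shape [2, 3] starting at (2, 3) is assumed; this only illustrates what the graph computes):

import numpy as np

# Relative coordinates for every element of a [2, 3] update, one row per
# element, plus the start offset tiled to the same [total_elements, rank] shape.
shape, start = (2, 3), (2, 3)
total = int(np.prod(shape))
indices_matrix = np.stack(np.unravel_index(np.arange(total), shape), axis=1)
start_broadcast = np.tile(np.reshape(start, (1, len(shape))), (total, 1))

assert indices_matrix[4].tolist() == [1, 1]
assert start_broadcast[4].tolist() == [2, 3]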
@@ -913,14 +941,21 @@ def slice_update(inputs, start_indices, updates):
 
     start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)
 
-    # Add offset to get absolute indices
+    # Add the broadcast start indices to the relative indices
+    # to get absolute indices in the input tensor.
+    # Example:
+    # if start=(2,3), update index [1,1] -> absolute index [3,4].
     absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)
 
+    # Flatten the updates tensor to match the flat indices.
     updates_flat = ov_opset.reshape(
         updates_tensor,
         ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
         special_zero=False,
     ).output(0)
+
+    # Perform the scatter update: for each absolute index,
+    # set the corresponding value from updates_flat.
     result = ov_opset.scatter_nd_update(
         inputs, absolute_indices, updates_flat
     ).output(0)
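Taken together, the graph should behave like a NumPy slice assignment. A rough end-to-end sketch with assumed example shapes (illustrative only, not part of the change):

import numpy as np

# Scatter-style update (relative coords + start offset) versus a direct
# slice assignment; both should produce the same result.
inputs = np.zeros((5, 8), dtype=np.int32)
updates = np.arange(6, dtype=np.int32).reshape(2, 3)
start = (2, 3)

expected = inputs.copy()
expected[2:4, 3:6] = updates

scattered = inputs.copy()
rel = np.stack(np.unravel_index(np.arange(updates.size), updates.shape), axis=1)
abs_idx = rel + np.array(start)
scattered[tuple(abs_idx.T)] = updates.reshape(-1)

assert np.array_equal(scattered, expected)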