@@ -842,10 +842,168 @@ def prepare_slice_index(val):
842
842
843
843
844
844
def slice_update (inputs , start_indices , updates ):
845
- raise NotImplementedError (
846
- "`slice_update` is not supported with openvino backend"
845
+ inputs = get_ov_output (inputs )
846
+ updates_tensor = get_ov_output (updates )
847
+
848
+ if isinstance (start_indices , (list , np .ndarray )):
849
+ start_indices = tuple (start_indices )
850
+ assert isinstance (start_indices , tuple ), (
851
+ "`slice_update` is not supported by openvino backend"
852
+ " for `start_indices` of type {}" .format (type (start_indices ))
847
853
)
848
854
855
+ zero_scalar = ov_opset .constant (0 , Type .i32 )
856
+ one_scalar = ov_opset .constant (1 , Type .i32 )
857
+ zero_tensor = ov_opset .constant ([0 ], Type .i32 )
858
+ one_tensor = ov_opset .constant ([1 ], Type .i32 )
859
+
860
+ processed_start_indices = []
861
+ for idx in start_indices :
862
+ val = get_ov_output (idx )
863
+ if not val .get_element_type ().is_integral ():
864
+ raise ValueError ("`slice_update` requires integral start_indices" )
865
+ if val .get_element_type () != Type .i32 :
866
+ val = ov_opset .convert (val , Type .i32 ).output (0 )
867
+ if val .get_partial_shape ().rank .get_length () == 0 :
868
+ val = ov_opset .unsqueeze (val , zero_scalar ).output (0 )
869
+ processed_start_indices .append (val )
870
+
871
+ updates_shape = ov_opset .shape_of (updates_tensor , Type .i32 ).output (0 )
872
+ rank = updates_tensor .get_partial_shape ().rank .get_length ()
873
+ if rank == 0 :
874
+ # Handle scalar update
875
+ start_tensor = ov_opset .concat (processed_start_indices , axis = 0 ).output (
876
+ 0
877
+ )
878
+ # For scatter_nd_update,
879
+ # indices should be of shape [num_updates, rank_of_inputs]
880
+ # and updates should be of shape [num_updates]. Here num_updates is 1.
881
+ absolute_indices = ov_opset .unsqueeze (start_tensor , zero_scalar ).output (
882
+ 0
883
+ )
884
+ updates_flat = ov_opset .unsqueeze (updates_tensor , zero_scalar ).output (0 )
885
+ result = ov_opset .scatter_nd_update (
886
+ inputs , absolute_indices , updates_flat
887
+ ).output (0 )
888
+ return OpenVINOKerasTensor (result )
889
+
890
+ # Compute the total number of elements in the updates tensor.
891
+ # Example:
892
+ # if updates.shape = [2, 3], total_elements = 6.
893
+ total_elements = ov_opset .reduce_prod (
894
+ updates_shape , zero_tensor , keep_dims = False
895
+ ).output (0 )
896
+
897
+ # Generate a flat range [0, 1, ..., total_elements-1].
898
+ # This will be used to enumerate all positions in the updates tensor.
899
+ flat_indices = ov_opset .range (
900
+ zero_scalar , total_elements , one_scalar , output_type = Type .i32
901
+ ).output (0 )
902
+
903
+ dim_sizes = []
904
+ strides = []
905
+
906
+ # For each dimension, compute its size and the stride.
907
+ # (number of elements to skip to move to the next index in this dimension).
908
+ # Example:
909
+ # for shape [2, 3], strides = [3, 1].
910
+ for dim in range (rank ):
911
+ dim_size = ov_opset .gather (
912
+ updates_shape , ov_opset .constant ([dim ], Type .i32 ), zero_scalar
913
+ ).output (0 )
914
+ dim_size_scalar = ov_opset .squeeze (dim_size , zero_tensor ).output (0 )
915
+ dim_sizes .append (dim_size_scalar )
916
+
917
+ # Strides to convert a flat index into a multi-dimensional index.
918
+ # This allows us to map each element in the flattened updates tensor
919
+ # to its correct N-dimensional position, so we can compute the absolute
920
+ # index in the input tensor for the scatter update.
921
+ # Stride for a dimension is the product of all dimensions after it.
922
+ # For the last dimension, stride is 1.
923
+ # Example:
924
+ # For a 3D tensor with shape [2, 3, 4]:
925
+ # - stride for dim=0 (first axis) is 3*4=12
926
+ # (to move to the next "block" along axis 0)
927
+ # - stride for dim=1 is 4 (to move to the next row along axis 1)
928
+ # - stride for dim=2 is 1 (to move to the next element along axis 2)
929
+ # This is equivalent to how numpy flattens multi-dimensional arrays.
930
+ if dim < rank - 1 :
931
+ remaining_dims = ov_opset .slice (
932
+ updates_shape ,
933
+ ov_opset .constant ([dim + 1 ], Type .i32 ),
934
+ ov_opset .constant ([rank ], Type .i32 ),
935
+ one_tensor ,
936
+ zero_tensor ,
937
+ ).output (0 )
938
+ stride = ov_opset .reduce_prod (
939
+ remaining_dims , zero_tensor , keep_dims = False
940
+ ).output (0 )
941
+ else :
942
+ stride = one_scalar
943
+ strides .append (stride )
944
+
945
+ coord_tensors = []
946
+ # For each dimension, compute the coordinate for every flat index.
947
+ # Example:
948
+ # for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
949
+ for dim in range (rank ):
950
+ coords = ov_opset .mod (
951
+ ov_opset .divide (flat_indices , strides [dim ]).output (0 ),
952
+ dim_sizes [dim ],
953
+ ).output (0 )
954
+ coord_tensors .append (coords )
955
+
956
+ coord_tensors_unsqueezed = []
957
+ for coord in coord_tensors :
958
+ # Unsqueeze to make each coordinate a column vector for concatenation.
959
+ coord_unsqueezed = ov_opset .unsqueeze (coord , one_tensor ).output (0 )
960
+ coord_tensors_unsqueezed .append (coord_unsqueezed )
961
+
962
+ # Concatenate all coordinate columns to form [total_elements, rank] matrix.
963
+ # Each row is a multi-dimensional index into the updates tensor.
964
+ # Example:
965
+ # for shape [2, 3], row 4 = [1, 1].
966
+ indices_matrix = ov_opset .concat (coord_tensors_unsqueezed , axis = 1 ).output (0 )
967
+
968
+ # Broadcast start indices to match the number of updates.
969
+ # Example:
970
+ # start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
971
+ # start_broadcast = [[2,3],[2,3],...]
972
+ start_tensor = ov_opset .concat (processed_start_indices , axis = 0 ).output (0 )
973
+ start_reshaped = ov_opset .reshape (
974
+ start_tensor , ov_opset .constant ([1 , rank ], Type .i32 ), special_zero = False
975
+ ).output (0 )
976
+
977
+ broadcast_shape = ov_opset .concat (
978
+ [
979
+ ov_opset .unsqueeze (total_elements , zero_tensor ).output (0 ),
980
+ one_tensor ,
981
+ ],
982
+ axis = 0 ,
983
+ ).output (0 )
984
+
985
+ start_broadcast = ov_opset .tile (start_reshaped , broadcast_shape ).output (0 )
986
+
987
+ # Add the broadcasted start indices to the relative indices
988
+ # to get absolute indices in the input tensor.
989
+ # Example:
990
+ # if start=(2,3), update index [1,1] -> absolute index [3,4].
991
+ absolute_indices = ov_opset .add (indices_matrix , start_broadcast ).output (0 )
992
+
993
+ # Flatten the updates tensor to match the flat indices.
994
+ updates_flat = ov_opset .reshape (
995
+ updates_tensor ,
996
+ ov_opset .unsqueeze (total_elements , zero_tensor ).output (0 ),
997
+ special_zero = False ,
998
+ ).output (0 )
999
+
1000
+ # Perform the scatter update: for each absolute index,
1001
+ # set the corresponding value from updates_flat.
1002
+ result = ov_opset .scatter_nd_update (
1003
+ inputs , absolute_indices , updates_flat
1004
+ ).output (0 )
1005
+ return OpenVINOKerasTensor (result )
1006
+
849
1007
850
1008
def while_loop (
851
1009
cond ,
0 commit comments