
Simulated OpenVINO Backend for Testing Unmerged PR Features with Memory Profiling #21491


Closed
wants to merge 26 commits into from
Changes from all commits
26 commits
aefa32f
[OpenVINO backend] support repeat
Mohamed-Ashraf273 Jul 1, 2025
4b66e60
[OpenVINO backend] support tri, triu, and tril
Mohamed-Ashraf273 Jul 14, 2025
8093d51
[OpenVINO backend] support slice_update
Mohamed-Ashraf273 Jul 14, 2025
75647be
[OpenVINO backend] add __array__ method
Mohamed-Ashraf273 Jul 14, 2025
950bdeb
[OpenVINO backend] support categorical
Mohamed-Ashraf273 Jul 15, 2025
850d003
[OpenVINO backend] suppor export model using openvino format
Mohamed-Ashraf273 Jul 16, 2025
a7e07f3
add export for openvino backend
Mohamed-Ashraf273 Jul 16, 2025
4e93a74
adding tests for openvino export format
Mohamed-Ashraf273 Jul 16, 2025
cb7812a
fix dynamic shape handling
Mohamed-Ashraf273 Jul 16, 2025
51ca2cd
[OpenVINO backend] add supporting for lists and tuples
Mohamed-Ashraf273 Jul 21, 2025
f3a2468
[OpenVINO backend] fix_transpose
Mohamed-Ashraf273 Jul 21, 2025
1d713ba
add more detailed comments
Mohamed-Ashraf273 Jul 21, 2025
f93d188
Merge branch 'support_slice_update' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
87002b6
Merge branch 'support_triu' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
94c27ad
Merge branch 'support_export' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
b034524
Merge branch 'support_categorical' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
cc087e7
Merge branch 'support_repeat' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
b96aaef
Merge branch 'update_get_ov_output' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
6139ec9
Merge branch 'add_array_method' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
f0ca994
Merge branch 'fix_transpose' into gsoc2025
Mohamed-Ashraf273 Jul 21, 2025
ee5f734
[OpenVINO backend] fix numpy conversions
Mohamed-Ashraf273 Jul 21, 2025
d44de26
fix typo
Mohamed-Ashraf273 Jul 21, 2025
980c09d
add testing files
Mohamed-Ashraf273 Jul 22, 2025
1f8f2ea
add more detailed comments
Mohamed-Ashraf273 Jul 22, 2025
cc86259
Merge branch 'support_slice_update' into gsoc2025
Mohamed-Ashraf273 Jul 22, 2025
8bc66a1
add support for jax backend and avoid to load models on disc for open…
Mohamed-Ashraf273 Jul 22, 2025
177 changes: 173 additions & 4 deletions keras/src/backend/openvino/core.py
@@ -110,6 +110,13 @@ def get_ov_output(x, ov_type=None):
x = ov_opset.constant(x, OPENVINO_DTYPES["bfloat16"]).output(0)
else:
x = ov_opset.constant(x).output(0)
elif isinstance(x, (list, tuple)):
if isinstance(x, tuple):
x = list(x)
if ov_type is None:
x = ov_opset.constant(x).output(0)
else:
x = ov_opset.constant(x, ov_type).output(0)
elif np.isscalar(x):
x = ov_opset.constant(x).output(0)
elif isinstance(x, KerasVariable):
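A minimal sketch of the list/tuple branch added above (illustrative only; `get_ov_output` and `OPENVINO_DTYPES` are internal helpers of this backend, so names may change):

# Illustrative sketch, assuming the OpenVINO backend's internal helpers
# keep their current names (not a public API).
from keras.src.backend.openvino.core import OPENVINO_DTYPES, get_ov_output

out = get_ov_output([1, 2, 3])  # list -> ov_opset.constant, dtype inferred
out_f32 = get_ov_output((1.0, 2.0), OPENVINO_DTYPES["float32"])  # tuple -> list, forced f32
print(out.get_element_type(), out_f32.get_element_type())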
@@ -492,6 +499,21 @@ def __mod__(self, other):
)
return OpenVINOKerasTensor(ov_opset.mod(first, other).output(0))

def __array__(self, dtype=None):
try:
tensor = cast(self, dtype=dtype) if dtype is not None else self
return convert_to_numpy(tensor)
except Exception as e:
raise RuntimeError(
"An OpenVINOKerasTensor is symbolic: it's a placeholder "
"for a shape and a dtype.\n"
"It doesn't have any actual numerical value.\n"
"You cannot convert it to a NumPy array."
) from e

def numpy(self):
return self.__array__()


def ov_to_keras_type(ov_type):
for _keras_type, _ov_type in OPENVINO_DTYPES.items():
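The `__array__` hook added above lets NumPy materialize tensors whose graph is constant-foldable via `convert_to_numpy`, while genuinely symbolic tensors surface the RuntimeError instead. A hypothetical usage sketch, assuming `KERAS_BACKEND=openvino`:

import numpy as np
from keras import ops

t = ops.convert_to_tensor([1, 2, 3])  # OpenVINOKerasTensor backed by a constant
print(np.asarray(t))                  # __array__ -> convert_to_numpy -> [1 2 3]
print(t.numpy())                      # same conversion via the numpy() helper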
@@ -625,6 +647,8 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
dtype = standardize_dtype(type(x))
ov_type = OPENVINO_DTYPES[dtype]
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
elif isinstance(x, ov.Output):
return OpenVINOKerasTensor(x)
if isinstance(x, Variable):
x = x.value
if dtype and dtype != x.dtype:
@@ -670,8 +694,10 @@ def convert_to_numpy(x):
ov_model = Model(results=[ov_result], parameters=[])
ov_compiled_model = compile_model(ov_model, get_device())
result = ov_compiled_model({})[0]
except:
raise "`convert_to_numpy` cannot convert to numpy"
except Exception as inner_exception:
raise RuntimeError(
"`convert_to_numpy` failed to convert the tensor."
) from inner_exception
return result


@@ -688,6 +714,7 @@ def shape(x):


def cast(x, dtype):
dtype = standardize_dtype(dtype)
ov_type = OPENVINO_DTYPES[dtype]
x = get_ov_output(x)
return OpenVINOKerasTensor(ov_opset.convert(x, ov_type).output(0))
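The added `standardize_dtype` call normalizes non-canonical dtype specs (NumPy types, dtype objects) to Keras dtype strings before the `OPENVINO_DTYPES` lookup, so a call like `cast(x, np.float32)` presumably no longer misses the dictionary key. A small sketch of the normalization (the import path is an assumption):

import numpy as np
from keras.src.backend.common.variables import standardize_dtype  # assumed location

print(standardize_dtype("float32"))          # 'float32'
print(standardize_dtype(np.float32))         # 'float32' -- NumPy scalar type normalized
print(standardize_dtype(np.dtype("int32")))  # 'int32'  -- dtype object normalized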
@@ -815,10 +842,152 @@ def prepare_slice_index(val):


def slice_update(inputs, start_indices, updates):
raise NotImplementedError(
"`slice_update` is not supported with openvino backend"
inputs = get_ov_output(inputs)
updates_tensor = get_ov_output(updates)

if isinstance(start_indices, (list, np.ndarray)):
start_indices = tuple(start_indices)
assert isinstance(start_indices, tuple), (
"`slice_update` is not supported by openvino backend"
" for `start_indices` of type {}".format(type(start_indices))
)

zero_scalar = ov_opset.constant(0, Type.i32)
one_scalar = ov_opset.constant(1, Type.i32)
zero_tensor = ov_opset.constant([0], Type.i32)
one_tensor = ov_opset.constant([1], Type.i32)

processed_start_indices = []
for idx in start_indices:
val = get_ov_output(idx)
if not val.get_element_type().is_integral():
raise ValueError("`slice_update` requires integral start_indices")
if val.get_element_type() != Type.i32:
val = ov_opset.convert(val, Type.i32).output(0)
if val.get_partial_shape().rank.get_length() == 0:
val = ov_opset.unsqueeze(val, zero_scalar).output(0)
processed_start_indices.append(val)
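# At this point each start index is a 1-D i32 tensor, e.g. start_indices=(2, 3)
# has become [2] and [3]; they are concatenated into [2, 3] further below.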

updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
rank = updates_tensor.get_partial_shape().rank.get_length()

# Compute the total number of elements in the updates tensor.
# Example:
# if updates.shape = [2, 3], total_elements = 6.
total_elements = ov_opset.reduce_prod(
updates_shape, zero_tensor, keep_dims=False
).output(0)

# Generate a flat range [0, 1, ..., total_elements-1].
# This will be used to enumerate all positions in the updates tensor.
flat_indices = ov_opset.range(
zero_scalar, total_elements, one_scalar, output_type=Type.i32
).output(0)

dim_sizes = []
strides = []

# For each dimension, compute its size and the stride.
# (number of elements to skip to move to the next index in this dimension).
# Example:
# for shape [2, 3], strides = [3, 1].
for dim in range(rank):
dim_size = ov_opset.gather(
updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
).output(0)
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
dim_sizes.append(dim_size_scalar)

# Strides to convert a flat index into a multi-dimensional index.
# This allows us to map each element in the flattened updates tensor
# to its correct N-dimensional position, so we can compute the absolute
# index in the input tensor for the scatter update.
# Stride for a dimension is the product of all dimensions after it.
# For the last dimension, stride is 1.
# Example:
# For a 3D tensor with shape [2, 3, 4]:
# - stride for dim=0 (first axis) is 3*4=12
# (to move to the next "block" along axis 0)
# - stride for dim=1 is 4 (to move to the next row along axis 1)
# - stride for dim=2 is 1 (to move to the next element along axis 2)
# This is equivalent to how numpy flattens multi-dimensional arrays.
if dim < rank - 1:
remaining_dims = ov_opset.slice(
updates_shape,
ov_opset.constant([dim + 1], Type.i32),
ov_opset.constant([rank], Type.i32),
one_tensor,
zero_tensor,
).output(0)
stride = ov_opset.reduce_prod(
remaining_dims, zero_tensor, keep_dims=False
).output(0)
else:
stride = one_scalar
strides.append(stride)

coord_tensors = []
# For each dimension, compute the coordinate for every flat index.
# Example:
# for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
for dim in range(rank):
coords = ov_opset.mod(
ov_opset.divide(flat_indices, strides[dim]).output(0),
dim_sizes[dim],
).output(0)
coord_tensors.append(coords)

coord_tensors_unsqueezed = []
for coord in coord_tensors:
# Unsqueeze to make each coordinate a column vector for concatenation.
coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
coord_tensors_unsqueezed.append(coord_unsqueezed)

# Concatenate all coordinate columns to form [total_elements, rank] matrix.
# Each row is a multi-dimensional index into the updates tensor.
# Example:
# for shape [2, 3], row 4 = [1, 1].
indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)

# Broadcast start indices to match the number of updates.
# Example:
# start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
# start_broadcast = [[2,3],[2,3],...]
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
start_reshaped = ov_opset.reshape(
start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
).output(0)

broadcast_shape = ov_opset.concat(
[
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
one_tensor,
],
axis=0,
).output(0)

start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)

# Add the broadcasted start indices to the relative indices
# to get absolute indices in the input tensor.
# Example:
# if start=(2,3), update index [1,1] -> absolute index [3,4].
absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)

# Flatten the updates tensor to match the flat indices.
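# Example:
# updates of shape [2, 3] flatten to 6 values, one per row of absolute_indices.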
updates_flat = ov_opset.reshape(
updates_tensor,
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
special_zero=False,
).output(0)

# Perform the scatter update: for each absolute index,
# set the corresponding value from updates_flat.
result = ov_opset.scatter_nd_update(
inputs, absolute_indices, updates_flat
).output(0)
return OpenVINOKerasTensor(result)


def while_loop(
cond,
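Taken together, the graph above implements `slice_update` as a `scatter_nd_update` driven by C-order flat-index arithmetic. Below is a NumPy reference sketch of the same index math (a hypothetical helper, not part of this diff), checked against plain slice assignment:

import numpy as np

def slice_update_reference(inputs, start_indices, updates):
    # Mirror the OpenVINO graph: enumerate flat positions in `updates`,
    # recover per-dimension coordinates via C-order strides, offset them by
    # `start_indices`, then scatter the flattened updates into a copy.
    rank = updates.ndim
    strides = [int(np.prod(updates.shape[d + 1:])) for d in range(rank)]
    flat = np.arange(updates.size)
    coords = np.stack(
        [(flat // strides[d]) % updates.shape[d] for d in range(rank)], axis=1
    )
    absolute = coords + np.asarray(start_indices)  # broadcast start over rows
    out = inputs.copy()
    out[tuple(absolute.T)] = updates.reshape(-1)
    return out

x = np.zeros((5, 7), dtype=np.int32)
u = np.arange(6, dtype=np.int32).reshape(2, 3)
expected = x.copy()
expected[2:4, 3:6] = u
assert np.array_equal(slice_update_reference(x, (2, 3), u), expected)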
65 changes: 55 additions & 10 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
@@ -1,6 +1,5 @@
NumPyTestRot90
NumpyArrayCreateOpsCorrectnessTest::test_eye
NumpyArrayCreateOpsCorrectnessTest::test_tri
NumpyDtypeTest::test_absolute_bool
NumpyDtypeTest::test_add_
NumpyDtypeTest::test_all
@@ -48,7 +47,6 @@ NumpyDtypeTest::test_multiply
NumpyDtypeTest::test_power
NumpyDtypeTest::test_prod
NumpyDtypeTest::test_quantile
NumpyDtypeTest::test_repeat
NumpyDtypeTest::test_roll
NumpyDtypeTest::test_round
NumpyDtypeTest::test_searchsorted
@@ -62,7 +60,6 @@ NumpyDtypeTest::test_swapaxes
NumpyDtypeTest::test_tensordot_
NumpyDtypeTest::test_tile
NumpyDtypeTest::test_trace
NumpyDtypeTest::test_tri
NumpyDtypeTest::test_trunc
NumpyDtypeTest::test_unravel
NumpyDtypeTest::test_var
@@ -110,7 +107,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
NumpyOneInputOpsCorrectnessTest::test_prod
NumpyOneInputOpsCorrectnessTest::test_real
NumpyOneInputOpsCorrectnessTest::test_repeat
NumpyOneInputOpsCorrectnessTest::test_reshape
NumpyOneInputOpsCorrectnessTest::test_roll
NumpyOneInputOpsCorrectnessTest::test_round
@@ -128,9 +124,6 @@ NumpyOneInputOpsCorrectnessTest::test_swapaxes
NumpyOneInputOpsCorrectnessTest::test_tile
NumpyOneInputOpsCorrectnessTest::test_trace
NumpyOneInputOpsCorrectnessTest::test_transpose
NumpyOneInputOpsCorrectnessTest::test_tril
NumpyOneInputOpsCorrectnessTest::test_tril_in_layer
NumpyOneInputOpsCorrectnessTest::test_triu
NumpyOneInputOpsCorrectnessTest::test_trunc
NumpyOneInputOpsCorrectnessTest::test_unravel_index
NumpyOneInputOpsCorrectnessTest::test_var
@@ -174,17 +167,14 @@ CoreOpsCallsTests::test_map_basic_call
CoreOpsCallsTests::test_scan_basic_call
CoreOpsCallsTests::test_scatter_basic_call
CoreOpsCallsTests::test_scatter_update_basic_call
CoreOpsCallsTests::test_slice_update_basic_call
CoreOpsCallsTests::test_switch_basic_call
CoreOpsCallsTests::test_unstack_basic_functionality
CoreOpsCorrectnessTest::test_associative_scan
CoreOpsCorrectnessTest::test_cond
CoreOpsCorrectnessTest::test_dynamic_slice
CoreOpsCorrectnessTest::test_fori_loop
CoreOpsCorrectnessTest::test_map
CoreOpsCorrectnessTest::test_scan
CoreOpsCorrectnessTest::test_scatter
CoreOpsCorrectnessTest::test_slice_update
CoreOpsCorrectnessTest::test_switch
CoreOpsCorrectnessTest::test_unstack
CoreOpsCorrectnessTest::test_vectorized_map
@@ -229,6 +219,61 @@ MathOpsCorrectnessTest::test_stft3
MathOpsCorrectnessTest::test_stft4
MathOpsCorrectnessTest::test_stft5
MathOpsCorrectnessTest::test_stft6
RandomCorrectnessTest::test_beta0
RandomCorrectnessTest::test_beta1
RandomCorrectnessTest::test_beta2
RandomCorrectnessTest::test_binomial0
RandomCorrectnessTest::test_binomial1
RandomCorrectnessTest::test_binomial2
RandomCorrectnessTest::test_dropout
RandomCorrectnessTest::test_dropout_noise_shape
RandomCorrectnessTest::test_gamma0
RandomCorrectnessTest::test_gamma1
RandomCorrectnessTest::test_gamma2
RandomCorrectnessTest::test_randint0
RandomCorrectnessTest::test_randint1
RandomCorrectnessTest::test_randint2
RandomCorrectnessTest::test_randint3
RandomCorrectnessTest::test_randint4
RandomCorrectnessTest::test_shuffle
RandomCorrectnessTest::test_truncated_normal0
RandomCorrectnessTest::test_truncated_normal1
RandomCorrectnessTest::test_truncated_normal2
RandomCorrectnessTest::test_truncated_normal3
RandomCorrectnessTest::test_truncated_normal4
RandomCorrectnessTest::test_truncated_normal5
RandomCorrectnessTest::test_uniform0
RandomCorrectnessTest::test_uniform1
RandomCorrectnessTest::test_uniform2
RandomCorrectnessTest::test_uniform3
RandomCorrectnessTest::test_uniform4
RandomBehaviorTest::test_beta_tf_data_compatibility
RandomDTypeTest::test_beta_bfloat16
RandomDTypeTest::test_beta_float16
RandomDTypeTest::test_beta_float32
RandomDTypeTest::test_beta_float64
RandomDTypeTest::test_binomial_bfloat16
RandomDTypeTest::test_binomial_float16
RandomDTypeTest::test_binomial_float32
RandomDTypeTest::test_binomial_float64
RandomDTypeTest::test_dropout_bfloat16
RandomDTypeTest::test_dropout_float16
RandomDTypeTest::test_dropout_float32
RandomDTypeTest::test_dropout_float64
RandomDTypeTest::test_gamma_bfloat16
RandomDTypeTest::test_gamma_float16
RandomDTypeTest::test_gamma_float32
RandomDTypeTest::test_gamma_float64
RandomDTypeTest::test_normal_bfloat16
RandomDTypeTest::test_randint_int16
RandomDTypeTest::test_randint_int32
RandomDTypeTest::test_randint_int64
RandomDTypeTest::test_randint_int8
RandomDTypeTest::test_randint_uint16
RandomDTypeTest::test_randint_uint32
RandomDTypeTest::test_randint_uint8
RandomDTypeTest::test_truncated_normal_bfloat16
RandomDTypeTest::test_uniform_bfloat16
SegmentSumTest::test_segment_sum_call
SegmentMaxTest::test_segment_max_call
TestMathErrors::test_invalid_fft_length
2 changes: 1 addition & 1 deletion keras/src/backend/openvino/excluded_tests.txt
@@ -32,7 +32,7 @@ keras/src/ops/linalg_test.py
keras/src/ops/nn_test.py
keras/src/optimizers
keras/src/quantizers
keras/src/random
keras/src/random/seed_generator_test.py
keras/src/regularizers
keras/src/saving
keras/src/trainers