Skip to content
185 changes: 175 additions & 10 deletions keras/src/backend/openvino/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def count_unsqueeze_before(dim):
if not (0 <= actual_dim < rank):
raise IndexError(
f"Index {index} is out of bounds for "
"axis {dim} with rank {rank}"
f"axis {dim} with rank {rank}"
)
length = ov_opset.gather(
partial_shape,
Expand Down Expand Up @@ -403,7 +403,7 @@ def count_unsqueeze_before(dim):
if index_type == Type.boolean or not index_type.is_integral():
raise ValueError(
"OpenVINO backend does not "
"support {index_type} indexing"
f"support {index_type} indexing"
)
axes.append(dim)
if len(index_shape) > 1:
Expand Down Expand Up @@ -654,13 +654,20 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if dtype and dtype != x.dtype:
x = cast(x, dtype)
return x
if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
return ov.Tensor(np.asarray(x).astype(dtype))
if dtype is None:
dtype = result_type(
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
original_type = type(x)
try:
if dtype is None:
dtype = getattr(x, "dtype", original_type)
ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]
else:
ov_type = OPENVINO_DTYPES[dtype]
x = np.array(x)
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
except Exception as e:
raise TypeError(
f"Cannot convert object of type {original_type} "
f"to OpenVINOKerasTensor: {e}"
)
return ov.Tensor(np.array(x, dtype=dtype))


def convert_to_numpy(x):
Expand Down Expand Up @@ -842,10 +849,168 @@ def prepare_slice_index(val):


def slice_update(inputs, start_indices, updates):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The slice_update function is lengthy and complex. Refactoring it into smaller, more focused helper functions could improve readability and maintainability. Consider helpers for processing start_indices, generating the multi-dimensional indices matrix, and calculating absolute indices.

raise NotImplementedError(
"`slice_update` is not supported with openvino backend"
inputs = get_ov_output(inputs)
updates_tensor = get_ov_output(updates)

if isinstance(start_indices, (list, np.ndarray)):
start_indices = tuple(start_indices)
assert isinstance(start_indices, tuple), (
"`slice_update` is not supported by openvino backend"
" for `start_indices` of type {}".format(type(start_indices))
)

zero_scalar = ov_opset.constant(0, Type.i32)
one_scalar = ov_opset.constant(1, Type.i32)
zero_tensor = ov_opset.constant([0], Type.i32)
one_tensor = ov_opset.constant([1], Type.i32)

processed_start_indices = []
for idx in start_indices:
val = get_ov_output(idx)
if not val.get_element_type().is_integral():
raise ValueError("`slice_update` requires integral start_indices")
if val.get_element_type() != Type.i32:
val = ov_opset.convert(val, Type.i32).output(0)
if val.get_partial_shape().rank.get_length() == 0:
val = ov_opset.unsqueeze(val, zero_scalar).output(0)
processed_start_indices.append(val)

updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
rank = updates_tensor.get_partial_shape().rank.get_length()
if rank == 0:
# Handle scalar update
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(
0
)
# For scatter_nd_update,
# indices should be of shape [num_updates, rank_of_inputs]
# and updates should be of shape [num_updates]. Here num_updates is 1.
absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(
0
)
updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)
result = ov_opset.scatter_nd_update(
inputs, absolute_indices, updates_flat
).output(0)
return OpenVINOKerasTensor(result)

# Compute the total number of elements in the updates tensor.
# Example:
# if updates.shape = [2, 3], total_elements = 6.
total_elements = ov_opset.reduce_prod(
updates_shape, zero_tensor, keep_dims=False
).output(0)

# Generate a flat range [0, 1, ..., total_elements-1].
# This will be used to enumerate all positions in the updates tensor.
flat_indices = ov_opset.range(
zero_scalar, total_elements, one_scalar, output_type=Type.i32
).output(0)

dim_sizes = []
strides = []

# For each dimension, compute its size and the stride.
# (number of elements to skip to move to the next index in this dimension).
# Example:
# for shape [2, 3], strides = [3, 1].
for dim in range(rank):
dim_size = ov_opset.gather(
updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
).output(0)
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
dim_sizes.append(dim_size_scalar)

# Strides to convert a flat index into a multi-dimensional index.
# This allows us to map each element in the flattened updates tensor
# to its correct N-dimensional position, so we can compute the absolute
# index in the input tensor for the scatter update.
# Stride for a dimension is the product of all dimensions after it.
# For the last dimension, stride is 1.
# Example:
# For a 3D tensor with shape [2, 3, 4]:
# - stride for dim=0 (first axis) is 3*4=12
# (to move to the next "block" along axis 0)
# - stride for dim=1 is 4 (to move to the next row along axis 1)
# - stride for dim=2 is 1 (to move to the next element along axis 2)
# This is equivalent to how numpy flattens multi-dimensional arrays.
if dim < rank - 1:
remaining_dims = ov_opset.slice(
updates_shape,
ov_opset.constant([dim + 1], Type.i32),
ov_opset.constant([rank], Type.i32),
one_tensor,
zero_tensor,
).output(0)
stride = ov_opset.reduce_prod(
remaining_dims, zero_tensor, keep_dims=False
).output(0)
else:
stride = one_scalar
strides.append(stride)

coord_tensors = []
# For each dimension, compute the coordinate for every flat index.
# Example:
# for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
for dim in range(rank):
coords = ov_opset.mod(
ov_opset.divide(flat_indices, strides[dim]).output(0),
dim_sizes[dim],
).output(0)
coord_tensors.append(coords)

coord_tensors_unsqueezed = []
for coord in coord_tensors:
# Unsqueeze to make each coordinate a column vector for concatenation.
coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
coord_tensors_unsqueezed.append(coord_unsqueezed)

# Concatenate all coordinate columns to form [total_elements, rank] matrix.
# Each row is a multi-dimensional index into the updates tensor.
# Example:
# for shape [2, 3], row 4 = [1, 1].
indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)

# Broadcast start indices to match the number of updates.
# Example:
# start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
# start_broadcast = [[2,3],[2,3],...]
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
start_reshaped = ov_opset.reshape(
start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
).output(0)

broadcast_shape = ov_opset.concat(
[
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
one_tensor,
],
axis=0,
).output(0)

start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)

# Add the broadcasted start indices to the relative indices
# to get absolute indices in the input tensor.
# Example:
# if start=(2,3), update index [1,1] -> absolute index [3,4].
absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)

# Flatten the updates tensor to match the flat indices.
updates_flat = ov_opset.reshape(
updates_tensor,
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
special_zero=False,
).output(0)

# Perform the scatter update: for each absolute index,
# set the corresponding value from updates_flat.
result = ov_opset.scatter_nd_update(
inputs, absolute_indices, updates_flat
).output(0)
return OpenVINOKerasTensor(result)


def while_loop(
cond,
Expand Down
3 changes: 0 additions & 3 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,14 @@ CoreOpsCallsTests::test_map_basic_call
CoreOpsCallsTests::test_scan_basic_call
CoreOpsCallsTests::test_scatter_basic_call
CoreOpsCallsTests::test_scatter_update_basic_call
CoreOpsCallsTests::test_slice_update_basic_call
CoreOpsCallsTests::test_switch_basic_call
CoreOpsCallsTests::test_unstack_basic_functionality
CoreOpsCorrectnessTest::test_associative_scan
CoreOpsCorrectnessTest::test_cond
CoreOpsCorrectnessTest::test_dynamic_slice
CoreOpsCorrectnessTest::test_fori_loop
CoreOpsCorrectnessTest::test_map
CoreOpsCorrectnessTest::test_scan
CoreOpsCorrectnessTest::test_scatter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CoreOpsCallsTests::test_slice_update_basic_call test is excluded. This exclusion indicates that slice_update may not be fully supported or may have issues. Verify the implementation of slice_update.

CoreOpsCorrectnessTest::test_slice_update
CoreOpsCorrectnessTest::test_switch
CoreOpsCorrectnessTest::test_unstack
CoreOpsCorrectnessTest::test_vectorized_map
Expand Down
45 changes: 45 additions & 0 deletions testing_files/gemma_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-8.4.0, pluggy-1.6.0 -- /home/mohamed-ashraf/Desktop/GSoC2025/env/bin/python
cachedir: .pytest_cache
rootdir: /home/mohamed-ashraf/Desktop/GSoC2025/keras-hub
configfile: pytest.ini
plugins: cov-6.1.1
collecting ... collected 15 items

keras_hub/src/models/gemma/gemma_causal_lm_test.py::TestCase::test_session SKIPPED [ 6%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_all_presets SKIPPED [ 13%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_cache_correctness PASSED [ 20%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_causal_lm_basics SKIPPED [ 26%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_early_stopping PASSED [ 33%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_flash_attention_call SKIPPED [ 40%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_generate PASSED [ 46%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_generate_compilation PASSED [ 53%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_generate_with_bfloat16 PASSED [ 60%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_multitoken_stopping PASSED [ 66%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_saved_model SKIPPED [ 73%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_score_layer_intercept_fn_exfiltration PASSED [ 80%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_score_logits PASSED [ 86%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_score_loss SKIPPED [ 93%]
keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_session PASSED [100%]

=============================== warnings summary ===============================
../../../../../usr/lib/python3.12/multiprocessing/popen_fork.py:66
../../../../../usr/lib/python3.12/multiprocessing/popen_fork.py:66
/usr/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=8846) is multi-threaded, use of fork() may lead to deadlocks in the child.
self.pid = os.fork()

../openvino/build_python3.12/site-packages/python/openvino/runtime/__init__.py:10
/home/mohamed-ashraf/Desktop/GSoC2025/openvino/build_python3.12/site-packages/python/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`.
warnings.warn(

../env/lib/python3.12/site-packages/_pytest/config/__init__.py:1474
/home/mohamed-ashraf/Desktop/GSoC2025/env/lib/python3.12/site-packages/_pytest/config/__init__.py:1474: PytestConfigWarning: Unknown config option: env

self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

keras_hub/src/models/gemma/gemma_causal_lm_test.py::GemmaCausalLMTest::test_session
/usr/lib/python3.12/unittest/case.py:690: DeprecationWarning: It is deprecated to return a value that is not None from a test case (<bound method TensorFlowTestCase.test_session of <keras_hub.src.models.gemma.gemma_causal_lm_test.GemmaCausalLMTest testMethod=test_session>>)
return self.run(*args, **kwds)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 9 passed, 6 skipped, 5 warnings in 18.03s ===================
41 changes: 41 additions & 0 deletions testing_files/gpt2_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-8.4.0, pluggy-1.6.0 -- /home/mohamed-ashraf/Desktop/GSoC2025/env/bin/python
cachedir: .pytest_cache
rootdir: /home/mohamed-ashraf/Desktop/GSoC2025/keras-hub
configfile: pytest.ini
plugins: cov-6.1.1
collecting ... collected 11 items

keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::TestCase::test_session SKIPPED [ 9%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_all_presets SKIPPED [ 18%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_causal_lm_basics SKIPPED [ 27%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_early_stopping PASSED [ 36%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_generate PASSED [ 45%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_generate_compilation PASSED [ 54%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_saved_model SKIPPED [ 63%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_score_layer_intercept_fn_exfiltration PASSED [ 72%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_score_logits PASSED [ 81%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_score_loss SKIPPED [ 90%]
keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_session PASSED [100%]

=============================== warnings summary ===============================
../../../../../usr/lib/python3.12/multiprocessing/popen_fork.py:66
../../../../../usr/lib/python3.12/multiprocessing/popen_fork.py:66
/usr/lib/python3.12/multiprocessing/popen_fork.py:66: DeprecationWarning: This process (pid=10106) is multi-threaded, use of fork() may lead to deadlocks in the child.
self.pid = os.fork()

../openvino/build_python3.12/site-packages/python/openvino/runtime/__init__.py:10
/home/mohamed-ashraf/Desktop/GSoC2025/openvino/build_python3.12/site-packages/python/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`.
warnings.warn(

../env/lib/python3.12/site-packages/_pytest/config/__init__.py:1474
/home/mohamed-ashraf/Desktop/GSoC2025/env/lib/python3.12/site-packages/_pytest/config/__init__.py:1474: PytestConfigWarning: Unknown config option: env

self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

keras_hub/src/models/gpt2/gpt2_causal_lm_test.py::GPT2CausalLMTest::test_session
/usr/lib/python3.12/unittest/case.py:690: DeprecationWarning: It is deprecated to return a value that is not None from a test case (<bound method TensorFlowTestCase.test_session of <keras_hub.src.models.gpt2.gpt2_causal_lm_test.GPT2CausalLMTest testMethod=test_session>>)
return self.run(*args, **kwds)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 6 passed, 5 skipped, 5 warnings in 13.25s ===================
Loading